怎么使用Go+Redis实现常见限流算法
使用Redis实现固定窗口比较简单,主要是由于固定窗口同时只会存在一个窗口,所以我们可以在第一次进入窗口时使用pexpire命令设置过期时间为窗口时间大小,这样窗口会随过期时间而失效,同时我们使用incr命令增加窗口计数。
因为我们需要在counter==1的时候设置窗口的过期时间,为了保证原子性,我们使用简单的Lua脚本实现。
const fixedWindowLimiterTryAcquireRedisScript = `-- ARGV[1]: 窗口时间大小
-- ARGV[2]: 窗口请求上限
local window = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
-- 获取原始值
local counter = tonumber(redis.call("
get"
, KEYS[1]))
if counter == nil then
counter = 0
end
-- 若到达窗口请求上限,请求失败
if counter >
= limit then
return 0
end
-- 窗口值+1
redis.call("
incr"
, KEYS[1])
if counter == 0 then
redis.call("
pexpire"
, KEYS[1], window)
end
return 1
` package redis
import (
"
context"
"
errors"
"
github.com/go-redis/redis/v8"
"
time"
)
// FixedWindowLimiter 固定窗口限流器
type FixedWindowLimiter struct {
limit int // 窗口请求上限
window int // 窗口时间大小
client *redis.Client // Redis客户端
script *redis.Script // TryAcquire脚本
}
func NewFixedWindowLimiter(client *redis.Client, limit int, window time.Duration) (*FixedWindowLimiter, error) {
// redis过期时间精度最大到毫秒,因此窗口必须能被毫秒整除
if window%time.Millisecond != 0 {
return nil, errors.New("
the window uint must not be less than millisecond"
)
}
return &
FixedWindowLimiter{
limit: limit,
window: int(window / time.Millisecond),
client: client,
script: redis.NewScript(fixedWindowLimiterTryAcquireRedisScript),
}, nil
}
func (l *FixedWindowLimiter) TryAcquire(ctx context.Context, resource string) error {
success, err := l.script.Run(ctx, l.client, []string{resource}, l.window, l.limit).Bool()
if err != nil {
return err
}
// 若到达窗口请求上限,请求失败
if !success {
return ErrAcquireFailed
}
return nil
} 滑动窗口hash实现
我们使用Redis的hash存储每个小窗口的计数,每次请求会把所有有效窗口的计数累加到count,使用hdel删除失效窗口,最后判断窗口的总计数是否大于上限。
我们基本上把所有的逻辑都放到Lua脚本里面,其中大头是对hash的遍历,时间复杂度是O(N),N是小窗口数量,所以小窗口数量最好不要太多。
const slidingWindowLimiterTryAcquireRedisScriptHashImpl = `-- ARGV[1]: 窗口时间大小
-- ARGV[2]: 窗口请求上限
-- ARGV[3]: 当前小窗口值
-- ARGV[4]: 起始小窗口值
local window = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local currentSmallWindow = tonumber(ARGV[3])
local startSmallWindow = tonumber(ARGV[4])
-- 计算当前窗口的请求总数
local counters = redis.call("
hgetall"
, KEYS[1])
local count = 0
for i = 1, #(counters) / 2 do
local smallWindow = tonumber(counters[i * 2 - 1])
local counter = tonumber(counters[i * 2])
if smallWindow <
startSmallWindow then
redis.call("
hdel"
, KEYS[1], smallWindow)
else
count = count + counter
end
end
-- 若到达窗口请求上限,请求失败
if count >
= limit then
return 0
end
-- 若没到窗口请求上限,当前小窗口计数器+1,请求成功
redis.call("
hincrby"
, KEYS[1], currentSmallWindow, 1)
redis.call("
pexpire"
, KEYS[1], window)
return 1
` package redis
import (
"
context"
"
errors"
"
github.com/go-redis/redis/v8"
"
time"
)
// SlidingWindowLimiter 滑动窗口限流器
type SlidingWindowLimiter struct {
limit int // 窗口请求上限
window int64 // 窗口时间大小
smallWindow int64 // 小窗口时间大小
smallWindows int64 // 小窗口数量
client *redis.Client // Redis客户端
script *redis.Script // TryAcquire脚本
}
func NewSlidingWindowLimiter(client *redis.Client, limit int, window, smallWindow time.Duration) (
*SlidingWindowLimiter, error) {
// redis过期时间精度最大到毫秒,因此窗口必须能被毫秒整除
if window%time.Millisecond != 0 || smallWindow%time.Millisecond != 0 {
return nil, errors.New("
the window uint must not be less than millisecond"
)
}
// 窗口时间必须能够被小窗口时间整除
if window%smallWindow != 0 {
return nil, errors.New("
window cannot be split by integers"
)
}
return &
SlidingWindowLimiter{
limit: limit,
window: int64(window / time.Millisecond),
smallWindow: int64(smallWindow / time.Millisecond),
smallWindows: int64(window / smallWindow),
client: client,
script: redis.NewScript(slidingWindowLimiterTryAcquireRedisScriptHashImpl),
}, nil
}
func (l *SlidingWindowLimiter) TryAcquire(ctx context.Context, resource string) error {
// 获取当前小窗口值
currentSmallWindow := time.Now().UnixMilli() / l.smallWindow * l.smallWindow
// 获取起始小窗口值
startSmallWindow := currentSmallWindow - l.smallWindow*(l.smallWindows-1)
success, err := l.script.Run(
ctx, l.client, []string{resource}, l.window, l.limit, currentSmallWindow, startSmallWindow).Bool()
if err != nil {
return err
}
// 若到达窗口请求上限,请求失败
if !success {
return ErrAcquireFailed
}
return nil
} list实现
如果小窗口数量特别多,可以使用list优化时间复杂度,list的结构是:
[counter, smallWindow1, count1, smallWindow2, count2, smallWindow3, count3...]
也就是我们使用list的第一个元素存储计数器,每个窗口用两个元素表示,第一个元素表示小窗口值,第二个元素表示这个小窗口的计数。由于Redis Lua脚本不支持字符串分割函数,因此不能将小窗口的值和计数放在同一元素中。
具体操作流程:
1.获取list长度
2.如果长度是0,设置counter,长度+1
3.如果长度大于1,获取第二第三个元素
如果该值小于起始小窗口值,counter-第三个元素的值,删除第二第三个元素,长度-2
4.如果counter大于等于limit,请求失败
5.如果长度大于1,获取倒数第二第一个元素
如果倒数第二个元素小窗口值大于等于当前小窗口值,表示当前请求因为网络延迟的问题,到达服务器的时候,窗口已经过时了,把倒数第二个元素当成当前小窗口(因为它更新),倒数第一个元素值+1
否则,添加新的窗口值,添加新的计数(1),更新过期时间
6.否则,添加新的窗口值,添加新的计数(1),更新过期时间
7.counter + 1
8.返回成功
const slidingWindowLimiterTryAcquireRedisScriptListImpl = `-- ARGV[1]: 窗口时间大小
-- ARGV[2]: 窗口请求上限
-- ARGV[3]: 当前小窗口值
-- ARGV[4]: 起始小窗口值
local window = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local currentSmallWindow = tonumber(ARGV[3])
local startSmallWindow = tonumber(ARGV[4])
-- 获取list长度
local len = redis.call("
llen"
, KEYS[1])
-- 如果长度是0,设置counter,长度+1
local counter = 0
if len == 0 then
redis.call("
rpush"
, KEYS[1], 0)
redis.call("
pexpire"
, KEYS[1], window)
len = len + 1
else
-- 如果长度大于1,获取第二第个元素
local smallWindow1 = tonumber(redis.call("
lindex"
, KEYS[1], 1))
counter = tonumber(redis.call("
lindex"
, KEYS[1], 0))
-- 如果该值小于起始小窗口值
if smallWindow1 <
startSmallWindow then
local count1 = redis.call("
lindex"
, KEYS[1], 2)
-- counter-第三个元素的值
counter = counter - count1
-- 长度-2
len = len - 2
-- 删除第二第三个元素
redis.call("
lrem"
, KEYS[1], 1, smallWindow1)
redis.call("
lrem"
, KEYS[1], 1, count1)
end
end
-- 若到达窗口请求上限,请求失败
if counter >
= limit then
return 0
end
-- 如果长度大于1,获取倒数第二第一个元素
if len >
1 then
local smallWindown = tonumber(redis.call("
lindex"
, KEYS[1], -2))
-- 如果倒数第二个元素小窗口值大于等于当前小窗口值
if smallWindown >
= currentSmallWindow then
-- 把倒数第二个元素当成当前小窗口(因为它更新),倒数第一个元素值+1
local countn = redis.call("
lindex"
, KEYS[1], -1)
redis.call("
lset"
, KEYS[1], -1, countn + 1)
else
-- 否则,添加新的窗口值,添加新的计数(1),更新过期时间
redis.call("
rpush"
, KEYS[1], currentSmallWindow, 1)
redis.call("
pexpire"
, KEYS[1], window)
end
else
-- 否则,添加新的窗口值,添加新的计数(1),更新过期时间
redis.call("
rpush"
, KEYS[1], currentSmallWindow, 1)
redis.call("
pexpire"
, KEYS[1], window)
end
-- counter + 1并更新
redis.call("
lset"
, KEYS[1], 0, counter + 1)
return 1
`
算法都是操作list头部或者尾部,所以时间复杂度接近O(1)
漏桶算法漏桶需要保存当前水位和上次放水时间,因此我们使用hash来保存这两个值。
const leakyBucketLimiterTryAcquireRedisScript = `-- ARGV[1]: 最高水位
-- ARGV[2]: 水流速度/秒
-- ARGV[3]: 当前时间(秒)
local peakLevel = tonumber(ARGV[1])
local currentVelocity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local lastTime = tonumber(redis.call("
hget"
, KEYS[1], "
lastTime"
))
local currentLevel = tonumber(redis.call("
hget"
, KEYS[1], "
currentLevel"
))
-- 初始化
if lastTime == nil then
lastTime = now
currentLevel = 0
redis.call("
hmset"
, KEYS[1], "
currentLevel"
, currentLevel, "
lastTime"
, lastTime)
end
-- 尝试放水
-- 距离上次放水的时间
local interval = now - lastTime
if interval >
0 then
-- 当前水位-距离上次放水的时间(秒)*水流速度
local newLevel = currentLevel - interval * currentVelocity
if newLevel <
0 then
newLevel = 0
end
currentLevel = newLevel
redis.call("
hmset"
, KEYS[1], "
currentLevel"
, newLevel, "
lastTime"
, now)
end
-- 若到达最高水位,请求失败
if currentLevel >
= peakLevel then
return 0
end
-- 若没有到达最高水位,当前水位+1,请求成功
redis.call("
hincrby"
, KEYS[1], "
currentLevel"
, 1)
redis.call("
expire"
, KEYS[1], peakLevel / currentVelocity)
return 1
` package redis
import (
"
context"
"
github.com/go-redis/redis/v8"
"
time"
)
// LeakyBucketLimiter 漏桶限流器
type LeakyBucketLimiter struct {
peakLevel int // 最高水位
currentVelocity int // 水流速度/秒
client *redis.Client // Redis客户端
script *redis.Script // TryAcquire脚本
}
func NewLeakyBucketLimiter(client *redis.Client, peakLevel, currentVelocity int) *LeakyBucketLimiter {
return &
LeakyBucketLimiter{
peakLevel: peakLevel,
currentVelocity: currentVelocity,
client: client,
script: redis.NewScript(leakyBucketLimiterTryAcquireRedisScript),
}
}
func (l *LeakyBucketLimiter) TryAcquire(ctx context.Context, resource string) error {
// 当前时间
now := time.Now().Unix()
success, err := l.script.Run(ctx, l.client, []string{resource}, l.peakLevel, l.currentVelocity, now).Bool()
if err != nil {
return err
}
// 若到达窗口请求上限,请求失败
if !success {
return ErrAcquireFailed
}
return nil
} 令牌桶
令牌桶可以看作是漏桶的相反算法,它们一个是把水倒进桶里,一个是从桶里获取令牌。
const tokenBucketLimiterTryAcquireRedisScript = `-- ARGV[1]: 容量
-- ARGV[2]: 发放令牌速率/秒
-- ARGV[3]: 当前时间(秒)
local capacity = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local lastTime = tonumber(redis.call("
hget"
, KEYS[1], "
lastTime"
))
local currentTokens = tonumber(redis.call("
hget"
, KEYS[1], "
currentTokens"
))
-- 初始化
if lastTime == nil then
lastTime = now
currentTokens = capacity
redis.call("
hmset"
, KEYS[1], "
currentTokens"
, currentTokens, "
lastTime"
, lastTime)
end
-- 尝试发放令牌
-- 距离上次发放令牌的时间
local interval = now - lastTime
if interval >
0 then
-- 当前令牌数量+距离上次发放令牌的时间(秒)*发放令牌速率
local newTokens = currentTokens + interval * rate
if newTokens >
capacity then
newTokens = capacity
end
currentTokens = newTokens
redis.call("
hmset"
, KEYS[1], "
currentTokens"
, newTokens, "
lastTime"
, now)
end
-- 如果没有令牌,请求失败
if currentTokens == 0 then
return 0
end
-- 果有令牌,当前令牌-1,请求成功
redis.call("
hincrby"
, KEYS[1], "
currentTokens"
, -1)
redis.call("
expire"
, KEYS[1], capacity / rate)
return 1
` package redis
import (
"
context"
"
github.com/go-redis/redis/v8"
"
time"
)
// TokenBucketLimiter 令牌桶限流器
type TokenBucketLimiter struct {
capacity int // 容量
rate int // 发放令牌速率/秒
client *redis.Client // Redis客户端
script *redis.Script // TryAcquire脚本
}
func NewTokenBucketLimiter(client *redis.Client, capacity, rate int) *TokenBucketLimiter {
return &
TokenBucketLimiter{
capacity: capacity,
rate: rate,
client: client,
script: redis.NewScript(tokenBucketLimiterTryAcquireRedisScript),
}
}
func (l *TokenBucketLimiter) TryAcquire(ctx context.Context, resource string) error {
// 当前时间
now := time.Now().Unix()
success, err := l.script.Run(ctx, l.client, []string{resource}, l.capacity, l.rate, now).Bool()
if err != nil {
return err
}
// 若到达窗口请求上限,请求失败
if !success {
return ErrAcquireFailed
}
return nil
} 滑动日志
算法流程与滑动窗口相同,只是它可以指定多个策略,同时在请求失败的时候,需要通知调用方是被哪个策略所拦截。
const slidingLogLimiterTryAcquireRedisScriptHashImpl = `-- ARGV[1]: 当前小窗口值
-- ARGV[2]: 第一个策略的窗口时间大小
-- ARGV[i * 2 + 1]: 每个策略的起始小窗口值
-- ARGV[i * 2 + 2]: 每个策略的窗口请求上限
local currentSmallWindow = tonumber(ARGV[1])
-- 第一个策略的窗口时间大小
local window = tonumber(ARGV[2])
-- 第一个策略的起始小窗口值
local startSmallWindow = tonumber(ARGV[3])
local strategiesLen = #(ARGV) / 2 - 1
-- 计算每个策略当前窗口的请求总数
local counters = redis.call("
hgetall"
, KEYS[1])
local counts = {}
-- 初始化counts
for j = 1, strategiesLen do
counts[j] = 0
end
for i = 1, #(counters) / 2 do
local smallWindow = tonumber(counters[i * 2 - 1])
local counter = tonumber(counters[i * 2])
if smallWindow <
startSmallWindow then
redis.call("
hdel"
, KEYS[1], smallWindow)
else
for j = 1, strategiesLen do
if smallWindow >
= tonumber(ARGV[j * 2 + 1]) then
counts[j] = counts[j] + counter
end
end
end
end
-- 若到达对应策略窗口请求上限,请求失败,返回违背的策略下标
for i = 1, strategiesLen do
if counts[i] >
= tonumber(ARGV[i * 2 + 2]) then
return i - 1
end
end
-- 若没到窗口请求上限,当前小窗口计数器+1,请求成功
redis.call("
hincrby"
, KEYS[1], currentSmallWindow, 1)
redis.call("
pexpire"
, KEYS[1], window)
return -1
` package redis
import (
"
context"
"
errors"
"
fmt"
"
github.com/go-redis/redis/v8"
"
sort"
"
time"
)
// ViolationStrategyError 违背策略错误
type ViolationStrategyError struct {
Limit int // 窗口请求上限
Window time.Duration // 窗口时间大小
}
func (e *ViolationStrategyError) Error() string {
return fmt.Sprintf("
violation strategy that limit = %d and window = %d"
, e.Limit, e.Window)
}
// SlidingLogLimiterStrategy 滑动日志限流器的策略
type SlidingLogLimiterStrategy struct {
limit int // 窗口请求上限
window int64 // 窗口时间大小
smallWindows int64 // 小窗口数量
}
func NewSlidingLogLimiterStrategy(limit int, window time.Duration) *SlidingLogLimiterStrategy {
return &
SlidingLogLimiterStrategy{
limit: limit,
window: int64(window),
}
}
// SlidingLogLimiter 滑动日志限流器
type SlidingLogLimiter struct {
strategies []*SlidingLogLimiterStrategy // 滑动日志限流器策略列表
smallWindow int64 // 小窗口时间大小
client *redis.Client // Redis客户端
script *redis.Script // TryAcquire脚本
}
func NewSlidingLogLimiter(client *redis.Client, smallWindow time.Duration, strategies ...*SlidingLogLimiterStrategy) (
*SlidingLogLimiter, error) {
// 复制策略避免被修改
strategies = append(make([]*SlidingLogLimiterStrategy, 0, len(strategies)), strategies...)
// 不能不设置策略
if len(strategies) == 0 {
return nil, errors.New("
must be set strategies"
)
}
// redis过期时间精度最大到毫秒,因此窗口必须能被毫秒整除
if smallWindow%time.Millisecond != 0 {
return nil, errors.New("
the window uint must not be less than millisecond"
)
}
smallWindow = smallWindow / time.Millisecond
for _, strategy := range strategies {
if strategy.window%int64(time.Millisecond) != 0 {
return nil, errors.New("
the window uint must not be less than millisecond"
)
}
strategy.window = strategy.window / int64(time.Millisecond)
}
// 排序策略,窗口时间大的排前面,相同窗口上限大的排前面
sort.Slice(strategies, func(i, j int) bool {
a, b := strategies[i], strategies[j]
if a.window == b.window {
return a.limit >
b.limit
}
return a.window >
b.window
})
for i, strategy := range strategies {
// 随着窗口时间变小,窗口上限也应该变小
if i >
0 {
if strategy.limit >
= strategies[i-1].limit {
return nil, errors.New("
the smaller window should be the smaller limit"
)
}
}
// 窗口时间必须能够被小窗口时间整除
if strategy.window%int64(smallWindow) != 0 {
return nil, errors.New("
window cannot be split by integers"
)
}
strategy.smallWindows = strategy.window / int64(smallWindow)
}
return &
SlidingLogLimiter{
strategies: strategies,
smallWindow: int64(smallWindow),
client: client,
script: redis.NewScript(slidingLogLimiterTryAcquireRedisScriptHashImpl),
}, nil
}
func (l *SlidingLogLimiter) TryAcquire(ctx context.Context, resource string) error {
// 获取当前小窗口值
currentSmallWindow := time.Now().UnixMilli() / l.smallWindow * l.smallWindow
args := make([]interface{}, len(l.strategies)*2+2)
args[0] = currentSmallWindow
args[1] = l.strategies[0].window
// 获取每个策略的起始小窗口值
for i, strategy := range l.strategies {
args[i*2+2] = currentSmallWindow - l.smallWindow*(strategy.smallWindows-1)
args[i*2+3] = strategy.limit
}
index, err := l.script.Run(
ctx, l.client, []string{resource}, args...).Int()
if err != nil {
return err
}
// 若到达窗口请求上限,请求失败
if index != -1 {
return &
ViolationStrategyError{
Limit: l.strategies[index].limit,
Window: time.Duration(l.strategies[index].window),
}
}
return nil
}
限流算法是一种常用的防止应用程序被恶意攻击或故障瘫痪的手段。本文介绍如何使用Go+Redis实现常见的限流算法,让您的应用程序更加安全可靠。
一、什么是限流算法?
限流算法是指系统对访问进行限制,通过限制访问数量或速度来保护系统免受过载和崩溃的影响。常见的限流算法有令牌桶算法、漏桶算法、计数器算法等。
二、令牌桶算法
令牌桶算法是一种基于令牌的限流算法,用于控制资源的访问速率和流量,保证服务的可用性和稳定性。在令牌桶算法中,访问者需要获取一个令牌才能进行访问,而令牌的数量是有限的。当所有的令牌被使用完了,访问者就需要等待资源的下一批令牌。
三、漏桶算法
漏桶算法是一种基于漏桶的限流算法,用于控制数据的流量。在漏桶算法中,数据会以固定的速率流出漏桶,而当漏桶中数据量达到一定阈值时,会被直接丢弃,从而保证了流量的稳定性和可控性。
四、计数器算法
计数器算法是一种基于计数器的限流算法,用于制定一定的频率限制,从而控制系统的访问量。在计数器算法中,系统会对每个用户的访问次数进行计数,当达到设定的限制阈值时,系统会拒绝访问。
五、使用Go+Redis实现限流算法
在Go+Redis中,实现限流算法的方法非常简单。我们只需要使用Redis提供的计数器、时间戳、哈希表、列表等功能,就可以轻松实现令牌桶、漏桶和计数器等限流算法。
六、使用Redis实现令牌桶算法示例
以下是使用Redis实现令牌桶算法的示例代码:
func getToken() bool {
if r := redis.NewClient(&redis.Options{
Addr: \"localhost:6379\",
Password: \"\", // no password set
DB: 0, // use default DB
}); r != nil {
defer r.Close()
pipe := r.Pipeline()
now := time.Now().UnixNano()
// 尝试放令牌
pipe.ZAdd(\"tokens\", &redis.Z{
Score: now,
Member: now,
})
// 删除指定范围之外的令牌
pipe.ZRemRangeByScore(\"tokens\", \"0\", strconv.FormatInt(now-int64(tokenInterval), 10))
// 获取剩余的令牌数
pipe.ZCard(\"tokens\")
results, _ := pipe.Exec()
// 获取令牌数
count := int(results[2].(*redis.IntCmd).Val())
return count <= tokenMax
}
return false
}
七、总结
限流算法是一种非常强大的保护系统的手段,在高并发访问量、用户攻击和故障处理等场景下都能有效地保证应用的可用性和稳定性。本文介绍了常见的令牌桶、漏桶和计数器算法,并提供了使用Go+Redis实现令牌桶算法的示例代码,希望对大家的开发工作有所帮助。