2025-02-24 16:20:55 +08:00
package middleware
import (
"context"
"fmt"
"net/http"
"one-api/common"
2025-04-16 10:33:43 +08:00
"one-api/common/limiter"
2025-02-24 16:20:55 +08:00
"one-api/setting"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
// Key markers used by the model request rate limiter when building
// per-user storage keys.
const (
	// ModelRequestRateLimitCountMark tags the key that tracks every request
	// of a user (the in-memory limiter prefixes the user id with it).
	ModelRequestRateLimitCountMark = "MRRL"

	// ModelRequestRateLimitSuccessCountMark tags the key that tracks only the
	// successful requests of a user.
	ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
// 检查Redis中的请求限制
func checkRedisRateLimit ( ctx context . Context , rdb * redis . Client , key string , maxCount int , duration int64 ) ( bool , error ) {
// 如果maxCount为0, 表示不限制
if maxCount == 0 {
return true , nil
}
// 获取当前计数
length , err := rdb . LLen ( ctx , key ) . Result ( )
if err != nil {
return false , err
}
// 如果未达到限制,允许请求
if length < int64 ( maxCount ) {
return true , nil
}
// 检查时间窗口
oldTimeStr , _ := rdb . LIndex ( ctx , key , - 1 ) . Result ( )
oldTime , err := time . Parse ( timeFormat , oldTimeStr )
if err != nil {
return false , err
}
nowTimeStr := time . Now ( ) . Format ( timeFormat )
nowTime , err := time . Parse ( timeFormat , nowTimeStr )
if err != nil {
return false , err
}
// 如果在时间窗口内已达到限制,拒绝请求
subTime := nowTime . Sub ( oldTime ) . Seconds ( )
if int64 ( subTime ) < duration {
2025-03-02 17:34:39 +08:00
rdb . Expire ( ctx , key , time . Duration ( setting . ModelRequestRateLimitDurationMinutes ) * time . Minute )
2025-02-24 16:20:55 +08:00
return false , nil
}
return true , nil
}
// 记录Redis请求
func recordRedisRequest ( ctx context . Context , rdb * redis . Client , key string , maxCount int ) {
// 如果maxCount为0, 不记录请求
if maxCount == 0 {
return
}
now := time . Now ( ) . Format ( timeFormat )
rdb . LPush ( ctx , key , now )
rdb . LTrim ( ctx , key , 0 , int64 ( maxCount - 1 ) )
2025-03-02 17:34:39 +08:00
rdb . Expire ( ctx , key , time . Duration ( setting . ModelRequestRateLimitDurationMinutes ) * time . Minute )
2025-02-24 16:20:55 +08:00
}
// Redis限流处理器
func redisRateLimitHandler ( duration int64 , totalMaxCount , successMaxCount int ) gin . HandlerFunc {
return func ( c * gin . Context ) {
userId := strconv . Itoa ( c . GetInt ( "id" ) )
ctx := context . Background ( )
rdb := common . RDB
2025-04-16 10:33:43 +08:00
// 1. 检查成功请求数限制
successKey := fmt . Sprintf ( "rateLimit:%s:%s" , ModelRequestRateLimitSuccessCountMark , userId )
allowed , err := checkRedisRateLimit ( ctx , rdb , successKey , successMaxCount , duration )
2025-02-24 16:20:55 +08:00
if err != nil {
2025-04-16 10:33:43 +08:00
fmt . Println ( "检查成功请求数限制失败:" , err . Error ( ) )
2025-02-24 16:20:55 +08:00
abortWithOpenAiMessage ( c , http . StatusInternalServerError , "rate_limit_check_failed" )
return
}
if ! allowed {
2025-04-16 10:33:43 +08:00
abortWithOpenAiMessage ( c , http . StatusTooManyRequests , fmt . Sprintf ( "您已达到请求数限制:%d分钟内最多请求%d次" , setting . ModelRequestRateLimitDurationMinutes , successMaxCount ) )
return
2025-02-24 16:20:55 +08:00
}
2025-04-16 16:36:07 +08:00
//2.检查总请求数限制并记录总请求( 当totalMaxCount为0时会自动跳过, 使用令牌桶限流器
2025-04-16 10:33:43 +08:00
totalKey := fmt . Sprintf ( "rateLimit:%s" , userId )
// 初始化
tb := limiter . New ( ctx , rdb )
allowed , err = tb . Allow (
ctx ,
totalKey ,
limiter . WithCapacity ( int64 ( totalMaxCount ) * duration ) ,
limiter . WithRate ( int64 ( totalMaxCount ) ) ,
limiter . WithRequested ( duration ) ,
)
2025-02-24 16:20:55 +08:00
if err != nil {
2025-04-16 10:33:43 +08:00
fmt . Println ( "检查总请求数限制失败:" , err . Error ( ) )
2025-02-24 16:20:55 +08:00
abortWithOpenAiMessage ( c , http . StatusInternalServerError , "rate_limit_check_failed" )
return
}
2025-04-16 10:33:43 +08:00
2025-02-24 16:20:55 +08:00
if ! allowed {
2025-04-16 10:33:43 +08:00
abortWithOpenAiMessage ( c , http . StatusTooManyRequests , fmt . Sprintf ( "您已达到总请求数限制:%d分钟内最多请求%d次, 包括失败次数, 请检查您的请求是否正确" , setting . ModelRequestRateLimitDurationMinutes , totalMaxCount ) )
2025-02-24 16:20:55 +08:00
}
// 4. 处理请求
c . Next ( )
// 5. 如果请求成功,记录成功请求
if c . Writer . Status ( ) < 400 {
recordRedisRequest ( ctx , rdb , successKey , successMaxCount )
}
}
}
// 内存限流处理器
func memoryRateLimitHandler ( duration int64 , totalMaxCount , successMaxCount int ) gin . HandlerFunc {
2025-03-02 17:34:39 +08:00
inMemoryRateLimiter . Init ( time . Duration ( setting . ModelRequestRateLimitDurationMinutes ) * time . Minute )
2025-02-24 16:20:55 +08:00
return func ( c * gin . Context ) {
userId := strconv . Itoa ( c . GetInt ( "id" ) )
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
// 1. 检查总请求数限制( 当totalMaxCount为0时跳过)
if totalMaxCount > 0 && ! inMemoryRateLimiter . Request ( totalKey , totalMaxCount , duration ) {
c . Status ( http . StatusTooManyRequests )
c . Abort ( )
return
}
// 2. 检查成功请求数限制
// 使用一个临时key来检查限制, 这样可以避免实际记录
checkKey := successKey + "_check"
if ! inMemoryRateLimiter . Request ( checkKey , successMaxCount , duration ) {
c . Status ( http . StatusTooManyRequests )
c . Abort ( )
return
}
// 3. 处理请求
c . Next ( )
// 4. 如果请求成功,记录到实际的成功请求计数中
if c . Writer . Status ( ) < 400 {
inMemoryRateLimiter . Request ( successKey , successMaxCount , duration )
}
}
}
// ModelRequestRateLimit 模型请求限流中间件
func ModelRequestRateLimit ( ) func ( c * gin . Context ) {
2025-03-06 16:32:11 +08:00
return func ( c * gin . Context ) {
// 在每个请求时检查是否启用限流
if ! setting . ModelRequestRateLimitEnabled {
c . Next ( )
return
}
2025-03-02 17:34:39 +08:00
2025-05-05 07:31:54 +08:00
// 计算通用限流参数
2025-03-06 16:32:11 +08:00
duration := int64 ( setting . ModelRequestRateLimitDurationMinutes * 60 )
2025-02-24 16:20:55 +08:00
2025-05-05 07:31:54 +08:00
// 获取用户组
group := c . GetString ( "token_group" )
if group == "" {
group = c . GetString ( "group" )
}
if group == "" {
group = "default" // 默认组
}
// 尝试获取用户组特定的限制
groupTotalCount , groupSuccessCount , found := setting . GetGroupRateLimit ( group )
// 确定最终的限制值
finalTotalCount := setting . ModelRequestRateLimitCount // 默认使用全局总次数限制
finalSuccessCount := setting . ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制
if found {
// 如果找到用户组特定限制,则使用它们
finalTotalCount = groupTotalCount
finalSuccessCount = groupSuccessCount
common . LogWarn ( c . Request . Context ( ) , fmt . Sprintf ( "Using rate limit for group '%s': total=%d, success=%d" , group , finalTotalCount , finalSuccessCount ) )
} else {
common . LogInfo ( c . Request . Context ( ) , fmt . Sprintf ( "No specific rate limit found for group '%s', using global limits: total=%d, success=%d" , group , finalTotalCount , finalSuccessCount ) )
}
// 根据存储类型选择并执行限流处理器,传入最终确定的限制值
2025-03-06 16:32:11 +08:00
if common . RedisEnabled {
2025-05-05 07:31:54 +08:00
redisRateLimitHandler ( duration , finalTotalCount , finalSuccessCount ) ( c )
2025-03-06 16:32:11 +08:00
} else {
2025-05-05 07:31:54 +08:00
memoryRateLimitHandler ( duration , finalTotalCount , finalSuccessCount ) ( c )
2025-03-06 16:32:11 +08:00
}
2025-02-24 16:20:55 +08:00
}
}