2023-06-20 19:09:49 +08:00
|
|
|
|
package model
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"encoding/json"
|
2023-06-21 17:04:18 +08:00
|
|
|
|
"errors"
|
2023-06-20 19:09:49 +08:00
|
|
|
|
"fmt"
|
2023-06-21 17:04:18 +08:00
|
|
|
|
"math/rand"
|
2023-06-20 19:09:49 +08:00
|
|
|
|
"one-api/common"
|
2023-09-17 19:18:16 +08:00
|
|
|
|
"sort"
|
2023-06-21 17:26:26 +08:00
|
|
|
|
"strconv"
|
2023-06-21 17:04:18 +08:00
|
|
|
|
"strings"
|
2023-06-20 19:09:49 +08:00
|
|
|
|
"sync"
|
|
|
|
|
|
"time"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2023-07-23 19:26:37 +08:00
|
|
|
|
// Cache TTLs, in seconds, for the Redis-backed lookups below. They all track
// the global sync frequency so a cached value never outlives one sync cycle.
var (
	TokenCacheSeconds         = common.SyncFrequency
	UserId2GroupCacheSeconds  = common.SyncFrequency
	UserId2QuotaCacheSeconds  = common.SyncFrequency
	UserId2StatusCacheSeconds = common.SyncFrequency
)
|
|
|
|
|
|
|
2024-01-25 20:09:06 +08:00
|
|
|
|
// token2UserId records the key -> user id of every token written to the cache.
// Used only by the periodic cache sync job (SyncTokenCache) to know which
// Redis entries to refresh or delete.
var token2UserId = make(map[string]int)

// token2UserIdLock guards token2UserId.
var token2UserIdLock sync.RWMutex
|
|
|
|
|
|
|
|
|
|
|
|
func cacheSetToken(token *Token) error {
|
|
|
|
|
|
if !common.RedisEnabled {
|
|
|
|
|
|
return token.SelectUpdate()
|
|
|
|
|
|
}
|
|
|
|
|
|
jsonBytes, err := json.Marshal(token)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
2023-10-22 18:38:29 +08:00
|
|
|
|
}
|
2024-01-25 20:09:06 +08:00
|
|
|
|
err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
token2UserIdLock.Lock()
|
|
|
|
|
|
defer token2UserIdLock.Unlock()
|
|
|
|
|
|
token2UserId[token.Key] = token.UserId
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
|
|
|
|
|
|
func CacheGetTokenByKey(key string) (*Token, error) {
|
2023-06-20 19:09:49 +08:00
|
|
|
|
if !common.RedisEnabled {
|
2024-01-25 20:09:06 +08:00
|
|
|
|
return GetTokenByKey(key)
|
2023-06-20 19:09:49 +08:00
|
|
|
|
}
|
2024-01-25 20:09:06 +08:00
|
|
|
|
var token *Token
|
2024-01-26 15:52:23 +08:00
|
|
|
|
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
2023-06-20 19:09:49 +08:00
|
|
|
|
if err != nil {
|
2024-01-25 20:09:06 +08:00
|
|
|
|
// 如果缓存中不存在,则从数据库中获取
|
|
|
|
|
|
token, err = GetTokenByKey(key)
|
2023-06-20 19:09:49 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
2024-01-25 20:09:06 +08:00
|
|
|
|
err = cacheSetToken(token)
|
|
|
|
|
|
return token, nil
|
|
|
|
|
|
}
|
2024-01-26 15:52:23 +08:00
|
|
|
|
// 如果缓存中存在,则续期时间
|
|
|
|
|
|
err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second)
|
2024-01-25 20:09:06 +08:00
|
|
|
|
err = json.Unmarshal([]byte(tokenObjectString), &token)
|
|
|
|
|
|
return token, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// SyncTokenCache periodically re-syncs every cached token with the database,
// sleeping frequency seconds between passes. Meant to run in its own
// goroutine. Tokens deleted from the database are evicted from Redis; tokens
// still present in both are rewritten to Redis with fresh data.
func SyncTokenCache(frequency int) {
	for {
		time.Sleep(time.Duration(frequency) * time.Second)
		common.SysLog("syncing tokens from database")
		token2UserIdLock.Lock()
		// Snapshot all keys from token2UserId, then reset it so only tokens
		// touched after this point accumulate for the next pass.
		var copyToken2UserId = make(map[string]int)
		for s, i := range token2UserId {
			copyToken2UserId[s] = i
		}
		token2UserId = make(map[string]int)
		token2UserIdLock.Unlock()
		for key := range copyToken2UserId {
			token, err := GetTokenByKey(key)
			if err != nil {
				// Not in the database any more: evict the cached copy.
				common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
				// delete from redis
				err := common.RedisDel(fmt.Sprintf("token:%s", key))
				if err != nil {
					common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
				}
			} else {
				// Present in the database: check Redis first.
				_, err := common.RedisGet(fmt.Sprintf("token:%s", key))
				if err != nil {
					// Not in Redis (e.g. expired): skip instead of re-caching.
					continue
				}
				err = cacheSetToken(token)
				if err != nil {
					common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
				}
			}
		}
	}
}
|
|
|
|
|
|
|
|
|
|
|
|
func CacheGetUserGroup(id int) (group string, err error) {
|
|
|
|
|
|
if !common.RedisEnabled {
|
|
|
|
|
|
return GetUserGroup(id)
|
|
|
|
|
|
}
|
|
|
|
|
|
group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
group, err = GetUserGroup(id)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return "", err
|
|
|
|
|
|
}
|
2023-07-23 19:26:37 +08:00
|
|
|
|
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
2023-06-20 19:09:49 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
common.SysError("Redis set user group error: " + err.Error())
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return group, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2024-01-11 18:39:21 +08:00
|
|
|
|
func CacheGetUsername(id int) (username string, err error) {
|
|
|
|
|
|
if !common.RedisEnabled {
|
|
|
|
|
|
return GetUsernameById(id)
|
|
|
|
|
|
}
|
|
|
|
|
|
username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
|
|
|
|
|
|
if err != nil {
|
2024-01-12 13:15:43 +08:00
|
|
|
|
username, err = GetUsernameById(id)
|
2024-01-11 18:39:21 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return "", err
|
|
|
|
|
|
}
|
|
|
|
|
|
err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
common.SysError("Redis set user group error: " + err.Error())
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return username, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2023-06-21 17:26:26 +08:00
|
|
|
|
func CacheGetUserQuota(id int) (quota int, err error) {
|
|
|
|
|
|
if !common.RedisEnabled {
|
|
|
|
|
|
return GetUserQuota(id)
|
|
|
|
|
|
}
|
|
|
|
|
|
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
quota, err = GetUserQuota(id)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return 0, err
|
|
|
|
|
|
}
|
2023-07-23 19:26:37 +08:00
|
|
|
|
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
2023-06-21 17:26:26 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
common.SysError("Redis set user quota error: " + err.Error())
|
|
|
|
|
|
}
|
|
|
|
|
|
return quota, err
|
|
|
|
|
|
}
|
|
|
|
|
|
quota, err = strconv.Atoi(quotaString)
|
|
|
|
|
|
return quota, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2023-06-27 19:22:58 +08:00
|
|
|
|
func CacheUpdateUserQuota(id int) error {
|
|
|
|
|
|
if !common.RedisEnabled {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
quota, err := GetUserQuota(id)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
2023-07-23 19:26:37 +08:00
|
|
|
|
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
2023-06-27 19:22:58 +08:00
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2023-08-16 23:40:24 +08:00
|
|
|
|
func CacheDecreaseUserQuota(id int, quota int) error {
|
|
|
|
|
|
if !common.RedisEnabled {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2023-09-03 21:31:58 +08:00
|
|
|
|
func CacheIsUserEnabled(userId int) (bool, error) {
|
2023-06-21 17:26:26 +08:00
|
|
|
|
if !common.RedisEnabled {
|
|
|
|
|
|
return IsUserEnabled(userId)
|
|
|
|
|
|
}
|
|
|
|
|
|
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
2023-09-03 21:31:58 +08:00
|
|
|
|
if err == nil {
|
|
|
|
|
|
return enabled == "1", nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
userEnabled, err := IsUserEnabled(userId)
|
2023-06-21 17:26:26 +08:00
|
|
|
|
if err != nil {
|
2023-09-03 21:31:58 +08:00
|
|
|
|
return false, err
|
|
|
|
|
|
}
|
|
|
|
|
|
enabled = "0"
|
|
|
|
|
|
if userEnabled {
|
|
|
|
|
|
enabled = "1"
|
|
|
|
|
|
}
|
|
|
|
|
|
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
common.SysError("Redis set user enabled error: " + err.Error())
|
2023-06-21 17:26:26 +08:00
|
|
|
|
}
|
2023-09-03 21:31:58 +08:00
|
|
|
|
return userEnabled, err
|
2023-06-21 17:26:26 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2023-06-20 19:09:49 +08:00
|
|
|
|
// group2model2channels maps group -> model -> enabled channels, sorted by
// priority (descending). Rebuilt wholesale by InitChannelCache.
var group2model2channels map[string]map[string][]*Channel

// channelsIDM maps channel id -> channel for direct lookup.
var channelsIDM map[int]*Channel

// channelSyncLock guards group2model2channels and channelsIDM.
var channelSyncLock sync.RWMutex
|
2023-06-20 19:09:49 +08:00
|
|
|
|
|
|
|
|
|
|
func InitChannelCache() {
|
2023-06-21 17:04:18 +08:00
|
|
|
|
newChannelId2channel := make(map[int]*Channel)
|
2023-06-20 19:09:49 +08:00
|
|
|
|
var channels []*Channel
|
2023-06-25 23:14:15 +08:00
|
|
|
|
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
2023-06-20 19:09:49 +08:00
|
|
|
|
for _, channel := range channels {
|
2023-06-21 17:04:18 +08:00
|
|
|
|
newChannelId2channel[channel.Id] = channel
|
2023-06-20 19:09:49 +08:00
|
|
|
|
}
|
|
|
|
|
|
var abilities []*Ability
|
|
|
|
|
|
DB.Find(&abilities)
|
|
|
|
|
|
groups := make(map[string]bool)
|
|
|
|
|
|
for _, ability := range abilities {
|
|
|
|
|
|
groups[ability.Group] = true
|
|
|
|
|
|
}
|
2023-06-21 17:04:18 +08:00
|
|
|
|
newGroup2model2channels := make(map[string]map[string][]*Channel)
|
2023-12-23 23:14:58 +08:00
|
|
|
|
newChannelsIDM := make(map[int]*Channel)
|
2023-06-20 19:09:49 +08:00
|
|
|
|
for group := range groups {
|
2023-06-21 17:04:18 +08:00
|
|
|
|
newGroup2model2channels[group] = make(map[string][]*Channel)
|
|
|
|
|
|
}
|
|
|
|
|
|
for _, channel := range channels {
|
2023-12-23 23:14:58 +08:00
|
|
|
|
newChannelsIDM[channel.Id] = channel
|
2023-06-21 17:04:18 +08:00
|
|
|
|
groups := strings.Split(channel.Group, ",")
|
|
|
|
|
|
for _, group := range groups {
|
|
|
|
|
|
models := strings.Split(channel.Models, ",")
|
|
|
|
|
|
for _, model := range models {
|
|
|
|
|
|
if _, ok := newGroup2model2channels[group][model]; !ok {
|
|
|
|
|
|
newGroup2model2channels[group][model] = make([]*Channel, 0)
|
|
|
|
|
|
}
|
|
|
|
|
|
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2023-06-20 19:09:49 +08:00
|
|
|
|
}
|
2023-09-17 19:18:16 +08:00
|
|
|
|
|
|
|
|
|
|
// sort by priority
|
|
|
|
|
|
for group, model2channels := range newGroup2model2channels {
|
|
|
|
|
|
for model, channels := range model2channels {
|
|
|
|
|
|
sort.Slice(channels, func(i, j int) bool {
|
2023-09-18 21:43:45 +08:00
|
|
|
|
return channels[i].GetPriority() > channels[j].GetPriority()
|
2023-09-17 19:18:16 +08:00
|
|
|
|
})
|
|
|
|
|
|
newGroup2model2channels[group][model] = channels
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2023-06-21 17:04:18 +08:00
|
|
|
|
channelSyncLock.Lock()
|
|
|
|
|
|
group2model2channels = newGroup2model2channels
|
2023-12-23 23:14:58 +08:00
|
|
|
|
channelsIDM = newChannelsIDM
|
2023-06-21 17:04:18 +08:00
|
|
|
|
channelSyncLock.Unlock()
|
2023-06-22 10:59:01 +08:00
|
|
|
|
common.SysLog("channels synced from database")
|
2023-06-20 19:09:49 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func SyncChannelCache(frequency int) {
|
|
|
|
|
|
for {
|
|
|
|
|
|
time.Sleep(time.Duration(frequency) * time.Second)
|
2023-06-22 10:59:01 +08:00
|
|
|
|
common.SysLog("syncing channels from database")
|
2023-06-20 19:09:49 +08:00
|
|
|
|
InitChannelCache()
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// CacheGetRandomSatisfiedChannel picks a channel serving (group, model) by
// weighted random selection among the channels that share the highest
// priority. Falls back to a direct database query when the memory cache is
// disabled.
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
	// All "gpt-4-gizmo-*" models are routed through a single wildcard entry.
	if strings.HasPrefix(model, "gpt-4-gizmo") {
		model = "gpt-4-gizmo-*"
	}

	// if memory cache is disabled, get channel directly from database
	if !common.MemoryCacheEnabled {
		return GetRandomSatisfiedChannel(group, model)
	}
	channelSyncLock.RLock()
	defer channelSyncLock.RUnlock()
	channels := group2model2channels[group][model]
	if len(channels) == 0 {
		return nil, errors.New("channel not found")
	}
	// Channels are sorted by priority descending; restrict the candidate
	// window to the leading run that shares the top priority.
	endIdx := len(channels)
	// choose by priority
	firstChannel := channels[0]
	if firstChannel.GetPriority() > 0 {
		for i := range channels {
			if channels[i].GetPriority() != firstChannel.GetPriority() {
				endIdx = i
				break
			}
		}
	}
	// Smoothing factor added to every weight so zero-weight channels still
	// get a share of the traffic.
	smoothingFactor := 10

	// Calculate the total weight of all channels up to endIdx
	totalWeight := 0
	for _, channel := range channels[:endIdx] {
		totalWeight += channel.GetWeight() + smoothingFactor
	}

	if totalWeight == 0 {
		// If all weights are 0, select a channel uniformly at random
		return channels[rand.Intn(endIdx)], nil
	}

	// Generate a random value in the range [0, totalWeight)
	randomWeight := rand.Intn(totalWeight)

	// Walk the candidates, subtracting each weight until the random value
	// goes negative — that channel is the weighted pick.
	for _, channel := range channels[:endIdx] {
		randomWeight -= channel.GetWeight() + smoothingFactor
		if randomWeight < 0 {
			return channel, nil
		}
	}
	// Unreachable when totalWeight is consistent with the loop above; kept
	// as a defensive fallback.
	return nil, errors.New("channel not found")
}
|
2023-12-23 23:14:58 +08:00
|
|
|
|
|
|
|
|
|
|
func CacheGetChannel(id int) (*Channel, error) {
|
|
|
|
|
|
if !common.MemoryCacheEnabled {
|
|
|
|
|
|
return GetChannelById(id, true)
|
|
|
|
|
|
}
|
|
|
|
|
|
channelSyncLock.RLock()
|
|
|
|
|
|
defer channelSyncLock.RUnlock()
|
|
|
|
|
|
|
|
|
|
|
|
c, ok := channelsIDM[id]
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
|
|
|
|
|
|
}
|
|
|
|
|
|
return c, nil
|
|
|
|
|
|
}
|