2023-05-23 10:01:09 +08:00
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
2024-02-29 01:08:18 +08:00
"io"
2024-07-05 20:51:25 +08:00
"math"
2023-05-23 10:01:09 +08:00
"net/http"
2024-02-29 01:08:18 +08:00
"net/http/httptest"
"net/url"
2023-05-23 10:01:09 +08:00
"one-api/common"
2025-07-03 13:10:25 +08:00
"one-api/constant"
2024-02-29 01:08:18 +08:00
"one-api/dto"
2024-07-13 19:55:22 +08:00
"one-api/middleware"
2023-05-23 10:01:09 +08:00
"one-api/model"
2024-02-29 16:21:25 +08:00
"one-api/relay"
2024-02-29 01:08:18 +08:00
relaycommon "one-api/relay/common"
2025-03-02 15:47:12 +08:00
"one-api/relay/helper"
2024-02-29 01:08:18 +08:00
"one-api/service"
2025-07-10 15:02:40 +08:00
"one-api/types"
2023-05-23 10:01:09 +08:00
"strconv"
2024-09-13 03:17:04 +08:00
"strings"
2023-05-23 10:01:09 +08:00
"sync"
"time"
2023-10-22 17:50:52 +08:00
2024-12-31 12:49:13 +08:00
"github.com/bytedance/gopkg/util/gopool"
2023-10-22 17:50:52 +08:00
"github.com/gin-gonic/gin"
2023-05-23 10:01:09 +08:00
)
2025-07-10 15:02:40 +08:00
func testChannel ( channel * model . Channel , testModel string ) ( err error , newAPIError * types . NewAPIError ) {
2024-07-05 20:51:25 +08:00
tik := time . Now ( )
2025-07-03 13:10:25 +08:00
if channel . Type == constant . ChannelTypeMidjourney {
2024-03-06 17:41:55 +08:00
return errors . New ( "midjourney channel test is not supported" ) , nil
}
2025-07-03 13:10:25 +08:00
if channel . Type == constant . ChannelTypeMidjourneyPlus {
return errors . New ( "midjourney plus channel test is not supported" ) , nil
2024-12-31 12:49:13 +08:00
}
2025-07-03 13:10:25 +08:00
if channel . Type == constant . ChannelTypeSunoAPI {
2024-06-12 20:37:42 +08:00
return errors . New ( "suno channel test is not supported" ) , nil
}
2025-07-03 13:10:25 +08:00
if channel . Type == constant . ChannelTypeKling {
2025-06-17 15:35:40 +08:00
return errors . New ( "kling channel test is not supported" ) , nil
}
2025-07-03 13:10:25 +08:00
if channel . Type == constant . ChannelTypeJimeng {
2025-06-20 15:50:00 +08:00
return errors . New ( "jimeng channel test is not supported" ) , nil
}
2024-02-29 01:08:18 +08:00
w := httptest . NewRecorder ( )
c , _ := gin . CreateTestContext ( w )
2025-02-13 16:39:17 +08:00
2025-01-22 04:21:08 +08:00
requestPath := "/v1/chat/completions"
2025-02-13 16:39:17 +08:00
2025-01-22 04:21:08 +08:00
// 先判断是否为 Embedding 模型
if strings . Contains ( strings . ToLower ( testModel ) , "embedding" ) ||
2025-02-13 16:39:17 +08:00
strings . HasPrefix ( testModel , "m3e" ) || // m3e 系列模型
strings . Contains ( testModel , "bge-" ) || // bge 系列模型
2025-02-27 16:49:32 +08:00
strings . Contains ( testModel , "embed" ) ||
2025-07-03 13:10:25 +08:00
channel . Type == constant . ChannelTypeMokaAI { // 其他 embedding 模型
2025-02-13 16:39:17 +08:00
requestPath = "/v1/embeddings" // 修改请求路径
2025-01-22 04:21:08 +08:00
}
2025-02-13 16:39:17 +08:00
2024-02-29 01:08:18 +08:00
c . Request = & http . Request {
Method : "POST" ,
2025-02-13 16:39:17 +08:00
URL : & url . URL { Path : requestPath } , // 使用动态路径
2024-02-29 01:08:18 +08:00
Body : nil ,
Header : make ( http . Header ) ,
}
2024-03-08 21:38:43 +08:00
2024-02-29 01:08:18 +08:00
if testModel == "" {
2024-04-04 17:28:56 +08:00
if channel . TestModel != nil && * channel . TestModel != "" {
testModel = * channel . TestModel
} else {
2024-07-08 17:06:29 +08:00
if len ( channel . GetModels ( ) ) > 0 {
testModel = channel . GetModels ( ) [ 0 ]
2024-07-06 01:32:40 +08:00
} else {
2025-02-13 16:39:17 +08:00
testModel = "gpt-4o-mini"
2024-07-06 01:32:40 +08:00
}
2024-04-04 17:28:56 +08:00
}
2025-02-10 22:39:56 +08:00
}
2025-02-27 16:49:32 +08:00
cache , err := model . GetUserCache ( 1 )
if err != nil {
return err , nil
}
cache . WriteContext ( c )
2024-07-13 19:55:22 +08:00
c . Request . Header . Set ( "Authorization" , "Bearer " + channel . Key )
c . Request . Header . Set ( "Content-Type" , "application/json" )
c . Set ( "channel" , channel . Type )
c . Set ( "base_url" , channel . GetBaseURL ( ) )
2025-03-08 19:53:07 +08:00
group , _ := model . GetUserGroup ( 1 , false )
c . Set ( "group" , group )
2024-07-13 19:55:22 +08:00
middleware . SetupContextForSelectedChannel ( c , channel , testModel )
2025-03-02 15:47:12 +08:00
info := relaycommon . GenRelayInfo ( c )
2025-06-20 16:02:23 +08:00
err = helper . ModelMappedHelper ( c , info , nil )
2025-03-02 15:47:12 +08:00
if err != nil {
2025-07-10 15:02:40 +08:00
return err , types . NewError ( err , types . ErrorCodeChannelModelMappedError )
2025-03-02 15:47:12 +08:00
}
2025-03-02 23:53:10 +08:00
testModel = info . UpstreamModelName
2025-03-02 15:47:12 +08:00
2025-07-03 13:10:25 +08:00
apiType , _ := common . ChannelType2APIType ( channel . Type )
2024-07-13 19:55:22 +08:00
adaptor := relay . GetAdaptor ( apiType )
if adaptor == nil {
2025-07-10 15:02:40 +08:00
return fmt . Errorf ( "invalid api type: %d, adaptor is nil" , apiType ) , types . NewError ( fmt . Errorf ( "invalid api type: %d, adaptor is nil" , apiType ) , types . ErrorCodeInvalidApiType )
2024-07-13 19:55:22 +08:00
}
2024-09-13 03:17:04 +08:00
request := buildTestRequest ( testModel )
2025-04-25 17:08:26 +08:00
// 创建一个用于日志的 info 副本,移除 ApiKey
logInfo := * info
logInfo . ApiKey = ""
common . SysLog ( fmt . Sprintf ( "testing channel %d with model %s , info %+v " , channel . Id , testModel , logInfo ) )
2024-02-01 18:11:00 +08:00
2025-04-03 17:32:48 +08:00
priceData , err := helper . ModelPriceHelper ( c , info , 0 , int ( request . MaxTokens ) )
if err != nil {
2025-07-10 15:02:40 +08:00
return err , types . NewError ( err , types . ErrorCodeModelPriceError )
2025-04-03 17:32:48 +08:00
}
2025-03-02 15:47:12 +08:00
adaptor . Init ( info )
2023-05-23 10:01:09 +08:00
2025-03-13 19:32:08 +08:00
convertedRequest , err := adaptor . ConvertOpenAIRequest ( c , info , request )
2023-05-23 10:01:09 +08:00
if err != nil {
2025-07-10 15:02:40 +08:00
return err , types . NewError ( err , types . ErrorCodeConvertRequestFailed )
2023-05-23 10:01:09 +08:00
}
2024-02-29 01:08:18 +08:00
jsonData , err := json . Marshal ( convertedRequest )
2023-05-23 10:01:09 +08:00
if err != nil {
2025-07-10 15:02:40 +08:00
return err , types . NewError ( err , types . ErrorCodeJsonMarshalFailed )
2023-05-23 10:01:09 +08:00
}
2024-02-29 01:08:18 +08:00
requestBody := bytes . NewBuffer ( jsonData )
c . Request . Body = io . NopCloser ( requestBody )
2025-03-02 15:47:12 +08:00
resp , err := adaptor . DoRequest ( c , info , requestBody )
2023-05-23 10:01:09 +08:00
if err != nil {
2025-07-10 15:02:40 +08:00
return err , types . NewError ( err , types . ErrorCodeDoRequestFailed )
2023-05-23 10:01:09 +08:00
}
2024-10-04 16:08:18 +08:00
var httpResp * http . Response
if resp != nil {
httpResp = resp . ( * http . Response )
if httpResp . StatusCode != http . StatusOK {
2025-03-11 17:25:06 +08:00
err := service . RelayErrorHandler ( httpResp , true )
2025-07-10 15:02:40 +08:00
return err , types . NewError ( err , types . ErrorCodeBadResponse )
2024-10-04 16:08:18 +08:00
}
2024-02-29 01:08:18 +08:00
}
2025-03-02 15:47:12 +08:00
usageA , respErr := adaptor . DoResponse ( c , httpResp , info )
2024-02-29 01:08:18 +08:00
if respErr != nil {
2025-07-10 15:02:40 +08:00
return respErr , respErr
2024-02-29 01:08:18 +08:00
}
2024-10-04 16:08:18 +08:00
if usageA == nil {
2025-07-10 15:02:40 +08:00
return errors . New ( "usage is nil" ) , types . NewError ( errors . New ( "usage is nil" ) , types . ErrorCodeBadResponseBody )
2024-02-29 01:08:18 +08:00
}
2024-10-12 14:13:11 +08:00
usage := usageA . ( * dto . Usage )
2024-02-29 01:08:18 +08:00
result := w . Result ( )
respBody , err := io . ReadAll ( result . Body )
2023-05-23 10:01:09 +08:00
if err != nil {
2025-07-10 15:02:40 +08:00
return err , types . NewError ( err , types . ErrorCodeReadResponseBodyFailed )
2023-05-23 10:01:09 +08:00
}
2025-03-02 15:59:39 +08:00
info . PromptTokens = usage . PromptTokens
2025-04-03 17:32:48 +08:00
2024-07-05 20:51:25 +08:00
quota := 0
2025-03-02 15:47:12 +08:00
if ! priceData . UsePrice {
quota = usage . PromptTokens + int ( math . Round ( float64 ( usage . CompletionTokens ) * priceData . CompletionRatio ) )
quota = int ( math . Round ( float64 ( quota ) * priceData . ModelRatio ) )
if priceData . ModelRatio != 0 && quota <= 0 {
2024-07-05 20:51:25 +08:00
quota = 1
}
} else {
2025-03-02 15:47:12 +08:00
quota = int ( priceData . ModelPrice * common . QuotaPerUnit )
2024-07-05 20:51:25 +08:00
}
tok := time . Now ( )
milliseconds := tok . Sub ( tik ) . Milliseconds ( )
consumedTime := float64 ( milliseconds ) / 1000.0
2025-06-17 21:05:35 +08:00
other := service . GenerateTextOtherInfo ( c , info , priceData . ModelRatio , priceData . GroupRatioInfo . GroupRatio , priceData . CompletionRatio ,
usage . PromptTokensDetails . CachedTokens , priceData . CacheRatio , priceData . ModelPrice , priceData . GroupRatioInfo . GroupSpecialRatio )
2025-07-07 14:26:37 +08:00
model . RecordConsumeLog ( c , 1 , model . RecordConsumeLogParams {
ChannelId : channel . Id ,
PromptTokens : usage . PromptTokens ,
CompletionTokens : usage . CompletionTokens ,
ModelName : info . OriginModelName ,
TokenName : "模型测试" ,
Quota : quota ,
Content : "模型测试" ,
UseTimeSeconds : int ( consumedTime ) ,
IsStream : false ,
Group : info . UsingGroup ,
Other : other ,
} )
2024-02-29 01:08:18 +08:00
common . SysLog ( fmt . Sprintf ( "testing channel #%d, response: \n%s" , channel . Id , string ( respBody ) ) )
2023-07-22 18:15:30 +08:00
return nil , nil
2023-05-23 10:01:09 +08:00
}
2024-09-13 03:17:04 +08:00
func buildTestRequest ( model string ) * dto . GeneralOpenAIRequest {
2024-02-29 01:08:18 +08:00
testRequest := & dto . GeneralOpenAIRequest {
2024-09-13 03:17:04 +08:00
Model : "" , // this will be set later
Stream : false ,
}
2025-02-04 22:52:37 +08:00
2025-01-22 04:21:08 +08:00
// 先判断是否为 Embedding 模型
2025-03-09 15:03:07 +08:00
if strings . Contains ( strings . ToLower ( model ) , "embedding" ) || // 其他 embedding 模型
2025-02-13 16:39:17 +08:00
strings . HasPrefix ( model , "m3e" ) || // m3e 系列模型
2025-03-09 15:03:07 +08:00
strings . Contains ( model , "bge-" ) {
testRequest . Model = model
2025-01-22 04:21:08 +08:00
// Embedding 请求
testRequest . Input = [ ] string { "hello world" }
return testRequest
}
// 并非Embedding 模型
2025-04-17 16:50:52 +08:00
if strings . HasPrefix ( model , "o" ) {
2024-12-25 17:55:20 +08:00
testRequest . MaxCompletionTokens = 10
2025-03-09 18:31:16 +08:00
} else if strings . Contains ( model , "thinking" ) {
2025-03-17 23:44:32 +08:00
if ! strings . Contains ( model , "claude" ) {
testRequest . MaxTokens = 50
}
2025-04-06 23:23:53 +08:00
} else if strings . Contains ( model , "gemini" ) {
2025-06-29 13:36:19 +08:00
testRequest . MaxTokens = 3000
2024-09-13 03:17:04 +08:00
} else {
2025-02-11 20:00:05 +08:00
testRequest . MaxTokens = 10
2023-05-23 10:01:09 +08:00
}
2025-06-07 23:05:01 +08:00
2024-02-29 01:08:18 +08:00
testMessage := dto . Message {
2023-05-23 10:01:09 +08:00
Role : "user" ,
2025-06-07 23:05:01 +08:00
Content : "hi" ,
2023-05-23 10:01:09 +08:00
}
2024-09-13 03:17:04 +08:00
testRequest . Model = model
2023-05-23 10:01:09 +08:00
testRequest . Messages = append ( testRequest . Messages , testMessage )
return testRequest
}
func TestChannel ( c * gin . Context ) {
2024-07-05 20:51:25 +08:00
channelId , err := strconv . Atoi ( c . Param ( "id" ) )
2023-05-23 10:01:09 +08:00
if err != nil {
c . JSON ( http . StatusOK , gin . H {
"success" : false ,
"message" : err . Error ( ) ,
} )
return
}
2024-07-05 20:51:25 +08:00
channel , err := model . GetChannelById ( channelId , true )
2023-05-23 10:01:09 +08:00
if err != nil {
c . JSON ( http . StatusOK , gin . H {
"success" : false ,
"message" : err . Error ( ) ,
} )
return
}
2024-02-29 01:08:18 +08:00
testModel := c . Query ( "model" )
2023-05-23 10:01:09 +08:00
tik := time . Now ( )
2025-07-10 15:02:40 +08:00
_ , newAPIError := testChannel ( channel , testModel )
2023-05-23 10:01:09 +08:00
tok := time . Now ( )
milliseconds := tok . Sub ( tik ) . Milliseconds ( )
go channel . UpdateResponseTime ( milliseconds )
consumedTime := float64 ( milliseconds ) / 1000.0
2025-07-10 15:02:40 +08:00
if newAPIError != nil {
2023-05-23 10:01:09 +08:00
c . JSON ( http . StatusOK , gin . H {
"success" : false ,
2025-07-10 15:02:40 +08:00
"message" : newAPIError . Error ( ) ,
2023-05-23 10:01:09 +08:00
"time" : consumedTime ,
} )
return
}
c . JSON ( http . StatusOK , gin . H {
"success" : true ,
"message" : "" ,
"time" : consumedTime ,
} )
return
}
// Run-state for testAllChannels. testAllChannelsRunning is true while a bulk
// channel test is in progress, and it is read and written only while holding
// testAllChannelsLock.
var (
	testAllChannelsLock    sync.Mutex
	testAllChannelsRunning bool
)
2023-06-22 22:01:03 +08:00
func testAllChannels ( notify bool ) error {
2025-02-18 15:59:17 +08:00
2023-05-23 10:01:09 +08:00
testAllChannelsLock . Lock ( )
if testAllChannelsRunning {
testAllChannelsLock . Unlock ( )
return errors . New ( "测试已在运行中" )
}
testAllChannelsRunning = true
testAllChannelsLock . Unlock ( )
2023-12-05 18:15:40 +08:00
channels , err := model . GetAllChannels ( 0 , 0 , true , false )
2023-05-23 10:01:09 +08:00
if err != nil {
return err
}
var disableThreshold = int64 ( common . ChannelDisableThreshold * 1000 )
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
2024-07-19 01:07:37 +08:00
gopool . Go ( func ( ) {
2025-06-10 03:54:18 +08:00
// 使用 defer 确保无论如何都会重置运行状态,防止死锁
defer func ( ) {
testAllChannelsLock . Lock ( )
testAllChannelsRunning = false
testAllChannelsLock . Unlock ( )
} ( )
2023-05-23 10:01:09 +08:00
for _ , channel := range channels {
2024-02-01 18:52:39 +08:00
isChannelEnabled := channel . Status == common . ChannelStatusEnabled
2023-05-23 10:01:09 +08:00
tik := time . Now ( )
2025-07-10 15:02:40 +08:00
err , newAPIError := testChannel ( channel , "" )
2023-05-23 10:01:09 +08:00
tok := time . Now ( )
milliseconds := tok . Sub ( tik ) . Milliseconds ( )
2023-11-17 16:22:13 +08:00
2024-09-13 03:17:04 +08:00
shouldBanChannel := false
2024-07-14 00:14:07 +08:00
// request error disables the channel
2025-07-10 15:02:40 +08:00
if err != nil {
shouldBanChannel = service . ShouldDisableChannel ( channel . Type , newAPIError )
2024-02-01 18:52:39 +08:00
}
2024-07-14 00:14:07 +08:00
2024-09-13 03:17:04 +08:00
if milliseconds > disableThreshold {
err = errors . New ( fmt . Sprintf ( "响应时间 %.2fs 超过阈值 %.2fs" , float64 ( milliseconds ) / 1000.0 , float64 ( disableThreshold ) / 1000.0 ) )
shouldBanChannel = true
2024-07-14 00:14:07 +08:00
}
// disable channel
2024-09-13 03:17:04 +08:00
if isChannelEnabled && shouldBanChannel && channel . GetAutoBan ( ) {
2024-07-14 00:14:07 +08:00
service . DisableChannel ( channel . Id , channel . Name , err . Error ( ) )
}
// enable channel
2025-07-10 15:02:40 +08:00
if ! isChannelEnabled && service . ShouldEnableChannel ( err , newAPIError , channel . Status ) {
2024-07-14 00:14:07 +08:00
service . EnableChannel ( channel . Id , channel . Name )
}
2023-05-23 10:01:09 +08:00
channel . UpdateResponseTime ( milliseconds )
2023-06-22 22:01:03 +08:00
time . Sleep ( common . RequestInterval )
2023-05-23 10:01:09 +08:00
}
2025-06-17 21:05:35 +08:00
2023-06-22 22:01:03 +08:00
if notify {
2025-02-18 15:59:17 +08:00
service . NotifyRootUser ( dto . NotifyTypeChannelTest , "通道测试完成" , "所有通道测试已完成" )
2023-06-22 22:01:03 +08:00
}
2024-07-19 01:07:37 +08:00
} )
2023-05-23 10:01:09 +08:00
return nil
}
func TestAllChannels ( c * gin . Context ) {
2023-06-22 22:01:03 +08:00
err := testAllChannels ( true )
2023-05-23 10:01:09 +08:00
if err != nil {
c . JSON ( http . StatusOK , gin . H {
"success" : false ,
"message" : err . Error ( ) ,
} )
return
}
c . JSON ( http . StatusOK , gin . H {
"success" : true ,
"message" : "" ,
} )
return
}
2023-06-22 22:01:03 +08:00
func AutomaticallyTestChannels ( frequency int ) {
2025-07-10 11:31:07 +08:00
if frequency <= 0 {
common . SysLog ( "CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test" )
return
}
2023-06-22 22:01:03 +08:00
for {
time . Sleep ( time . Duration ( frequency ) * time . Minute )
common . SysLog ( "testing all channels" )
_ = testAllChannels ( false )
common . SysLog ( "channel test finished" )
}
}