Files
new-api/controller/relay.go

484 lines
15 KiB
Go
Raw Normal View History

2023-04-23 18:24:11 +08:00
package controller
import (
2024-04-04 16:35:44 +08:00
"bytes"
2024-08-03 17:12:16 +08:00
"errors"
2023-04-23 18:24:11 +08:00
"fmt"
2024-04-04 16:35:44 +08:00
"io"
2023-09-09 03:19:55 +08:00
"log"
2023-04-23 18:24:11 +08:00
"net/http"
"one-api/common"
"one-api/constant"
constant2 "one-api/constant"
2024-02-29 01:08:18 +08:00
"one-api/dto"
2024-04-04 16:35:44 +08:00
"one-api/middleware"
"one-api/model"
2024-02-29 01:08:18 +08:00
"one-api/relay"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
2024-02-29 01:08:18 +08:00
"one-api/service"
"one-api/types"
2024-04-20 17:18:14 +08:00
"strings"
2025-05-02 13:59:46 +08:00
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
2023-12-27 16:32:54 +08:00
)
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
var err *types.NewAPIError
2023-06-19 10:28:55 +08:00
switch relayMode {
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
err = relay.ImageHelper(c)
2024-02-29 01:08:18 +08:00
case relayconstant.RelayModeAudioSpeech:
2024-01-09 15:46:45 +08:00
fallthrough
2024-02-29 01:08:18 +08:00
case relayconstant.RelayModeAudioTranslation:
2024-01-09 15:46:45 +08:00
fallthrough
2024-02-29 01:08:18 +08:00
case relayconstant.RelayModeAudioTranscription:
2024-07-17 23:50:37 +08:00
err = relay.AudioHelper(c)
2024-07-06 17:09:22 +08:00
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
2025-01-23 05:54:39 +08:00
case relayconstant.RelayModeEmbeddings:
err = relay.EmbeddingHelper(c)
2025-05-02 13:59:46 +08:00
case relayconstant.RelayModeResponses:
err = relay.ResponsesHelper(c)
2025-05-26 13:34:41 +08:00
case relayconstant.RelayModeGemini:
err = relay.GeminiHelper(c)
2023-06-19 10:28:55 +08:00
default:
2024-02-29 01:08:18 +08:00
err = relay.TextHelper(c)
}
if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
// 保存错误日志到mysql中
userId := c.GetInt("id")
tokenName := c.GetString("token_name")
modelName := c.GetString("original_model")
tokenId := c.GetInt("token_id")
userGroup := c.GetString("group")
channelId := c.GetInt("channel_id")
other := make(map[string]interface{})
2025-07-23 20:01:03 +08:00
other["error_type"] = err.GetErrorType()
other["error_code"] = err.GetErrorCode()
other["status_code"] = err.StatusCode
other["channel_id"] = channelId
other["channel_name"] = c.GetString("channel_name")
other["channel_type"] = c.GetInt("channel_type")
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
if isMultiKey {
adminInfo["is_multi_key"] = true
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
}
other["admin_info"] = adminInfo
2025-07-30 19:08:35 +08:00
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
}
2024-04-04 16:35:44 +08:00
return err
}
func Relay(c *gin.Context) {
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
2024-04-04 16:35:44 +08:00
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
var newAPIError *types.NewAPIError
2024-08-03 16:55:29 +08:00
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
2024-04-04 16:35:44 +08:00
if err != nil {
2024-08-03 17:12:16 +08:00
common.LogError(c, err.Error())
newAPIError = err
2024-08-03 17:12:16 +08:00
break
2023-07-15 19:06:51 +08:00
}
2024-04-04 16:35:44 +08:00
newAPIError = relayRequest(c, relayMode, channel)
2024-08-03 16:55:29 +08:00
if newAPIError == nil {
2024-08-03 16:55:29 +08:00
return // 成功处理请求,直接返回
}
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
2024-08-03 16:55:29 +08:00
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
2024-08-03 16:55:29 +08:00
break
}
2024-04-04 16:35:44 +08:00
}
2024-05-16 16:41:08 +08:00
useChannel := c.GetStringSlice("use_channel")
2024-04-20 17:18:14 +08:00
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
2024-08-03 17:12:16 +08:00
common.LogInfo(c, retryLogStr)
2024-04-20 17:18:14 +08:00
}
2024-04-04 16:35:44 +08:00
if newAPIError != nil {
//if newAPIError.StatusCode == http.StatusTooManyRequests {
// common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
//}
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
c.JSON(newAPIError.StatusCode, gin.H{
"error": newAPIError.ToOpenAIError(),
2024-04-04 16:35:44 +08:00
})
}
}
// upgrader turns incoming HTTP requests into WebSocket connections for WssRelay.
var upgrader = websocket.Upgrader{
	// Subprotocols accepted during the handshake. Any protocol a client sends in
	// Sec-WebSocket-Protocol must be declared here.
	// TODO: add other protocols.
	Subprotocols: []string{"realtime"},
	// CheckOrigin accepts every Origin, i.e. cross-origin connections are allowed.
	// NOTE(review): this turns off the upgrader's same-origin check; confirm that
	// authentication on this route does not rely on browser cookies.
	CheckOrigin: func(r *http.Request) bool {
		return true
	},
}
func WssRelay(c *gin.Context) {
// 将 HTTP 连接升级为 WebSocket 连接
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
defer ws.Close()
if err != nil {
2025-07-30 22:35:31 +08:00
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
return
}
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
originalModel := c.GetString("original_model")
var newAPIError *types.NewAPIError
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
newAPIError = err
break
}
newAPIError = wssRequest(c, ws, relayMode, channel)
if newAPIError == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if newAPIError != nil {
//if newAPIError.StatusCode == http.StatusTooManyRequests {
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
//}
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
helper.WssError(c, ws, newAPIError.ToOpenAIError())
}
2024-04-04 16:35:44 +08:00
}
2025-03-12 21:31:46 +08:00
func RelayClaude(c *gin.Context) {
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
var newAPIError *types.NewAPIError
2025-03-12 21:31:46 +08:00
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
newAPIError = err
2025-03-12 21:31:46 +08:00
break
}
newAPIError = claudeRequest(c, channel)
2025-03-12 21:31:46 +08:00
if newAPIError == nil {
2025-03-12 21:31:46 +08:00
return // 成功处理请求,直接返回
}
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
2025-03-12 21:31:46 +08:00
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
2025-03-12 21:31:46 +08:00
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if newAPIError != nil {
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
c.JSON(newAPIError.StatusCode, gin.H{
2025-03-12 21:31:46 +08:00
"type": "error",
"error": newAPIError.ToClaudeError(),
2025-03-12 21:31:46 +08:00
})
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
2024-08-03 16:55:29 +08:00
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
// wssRequest performs one WebSocket relay attempt on channel. It records the
// channel in "use_channel", replays the request body returned by
// common.GetRequestBody, and hands the socket to relay.WssHelper.
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
	addUsedChannel(c, channel.Id)
	// The read error is ignored on purpose; whatever bytes were returned are replayed.
	body, _ := common.GetRequestBody(c)
	c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
	return relay.WssHelper(c, ws)
}
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
2025-03-12 21:31:46 +08:00
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.ClaudeHelper(c)
}
2024-08-03 16:55:29 +08:00
// addUsedChannel appends channelId, formatted as a decimal string, to the
// "use_channel" slice in the gin context, which lists every channel tried for
// this request.
func addUsedChannel(c *gin.Context, channelId int) {
	tried := append(c.GetStringSlice("use_channel"), fmt.Sprintf("%d", channelId))
	c.Set("use_channel", tried)
}
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
2024-08-03 16:55:29 +08:00
if retryCount == 0 {
2024-08-03 17:32:28 +08:00
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
if !autoBan {
autoBanInt = 0
}
2024-08-03 16:55:29 +08:00
return &model.Channel{
2024-08-03 17:32:28 +08:00
Id: c.GetInt("channel_id"),
Type: c.GetInt("channel_type"),
Name: c.GetString("channel_name"),
AutoBan: &autoBanInt,
2024-08-03 16:55:29 +08:00
}, nil
}
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
2024-08-03 16:55:29 +08:00
if err != nil {
2025-07-30 22:35:31 +08:00
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
if channel == nil {
2025-07-30 22:35:31 +08:00
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在数据库一致性已被破坏retry", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
if newAPIError != nil {
return nil, newAPIError
2024-08-03 16:55:29 +08:00
}
return channel, nil
}
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
2024-04-04 16:35:44 +08:00
if openaiErr == nil {
return false
}
if types.IsChannelError(openaiErr) {
return true
}
2025-07-30 22:35:31 +08:00
if types.IsSkipRetryError(openaiErr) {
2024-08-24 17:27:14 +08:00
return false
}
2024-04-04 16:35:44 +08:00
if retryTimes <= 0 {
return false
}
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if openaiErr.StatusCode == http.StatusTooManyRequests {
return true
}
2024-04-08 14:10:09 +08:00
if openaiErr.StatusCode == 307 {
return true
}
2024-04-04 16:35:44 +08:00
if openaiErr.StatusCode/100 == 5 {
2024-04-04 21:21:44 +08:00
// 超时不重试
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
return false
}
2024-04-04 16:35:44 +08:00
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
2024-08-03 16:55:29 +08:00
channelType := c.GetInt("channel_type")
if channelType == constant.ChannelTypeAnthropic {
2024-08-03 16:55:29 +08:00
return true
}
2024-04-04 16:35:44 +08:00
return false
}
2024-04-23 22:17:36 +08:00
if openaiErr.StatusCode == 408 {
// azure处理超时不重试
return false
}
2024-04-04 16:35:44 +08:00
if openaiErr.StatusCode/100 == 2 {
return false
}
return true
}
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
2024-08-03 17:32:28 +08:00
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
service.DisableChannel(channelError, err.Error())
}
}
2023-08-14 22:16:32 +08:00
func RelayMidjourney(c *gin.Context) {
relayMode := c.GetInt("relay_mode")
2024-02-29 01:08:18 +08:00
var err *dto.MidjourneyResponse
2023-08-14 22:16:32 +08:00
switch relayMode {
2024-02-29 01:08:18 +08:00
case relayconstant.RelayModeMidjourneyNotify:
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed:
err = relay.RelayMidjourneyTaskImageSeed(c)
2024-03-14 18:08:12 +08:00
case relayconstant.RelayModeSwapFace:
err = relay.RelaySwapFace(c)
2023-08-14 22:16:32 +08:00
default:
2024-02-29 01:08:18 +08:00
err = relay.RelayMidjourneySubmit(c, relayMode)
2023-08-14 22:16:32 +08:00
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
statusCode := http.StatusBadRequest
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
statusCode = http.StatusTooManyRequests
2023-08-14 22:16:32 +08:00
}
c.JSON(statusCode, gin.H{
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
"type": "upstream_error",
"code": err.Code,
})
2023-08-14 22:16:32 +08:00
channelId := c.GetInt("channel_id")
2024-04-25 20:47:18 +08:00
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
2023-08-14 22:16:32 +08:00
}
}
func RelayNotImplemented(c *gin.Context) {
2024-02-29 01:08:18 +08:00
err := dto.OpenAIError{
Message: "API not implemented",
2023-12-01 01:29:13 +08:00
Type: "new_api_error",
Param: "",
Code: "api_not_implemented",
}
2023-06-23 22:59:44 +08:00
c.JSON(http.StatusNotImplemented, gin.H{
"error": err,
})
}
func RelayNotFound(c *gin.Context) {
2024-02-29 01:08:18 +08:00
err := dto.OpenAIError{
2023-08-11 19:53:01 +08:00
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",
2023-08-11 19:53:01 +08:00
Code: "",
}
2023-06-23 22:59:44 +08:00
c.JSON(http.StatusNotFound, gin.H{
"error": err,
})
}
func RelayTask(c *gin.Context) {
retryTimes := common.RetryTimes
channelId := c.GetInt("channel_id")
relayMode := c.GetInt("relay_mode")
group := c.GetString("group")
originalModel := c.GetString("original_model")
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
taskErr := taskRelayHandler(c, relayMode)
if taskErr == nil {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, newAPIError := getChannel(c, group, originalModel, i)
if newAPIError != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
2024-08-03 17:46:13 +08:00
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
taskErr = taskRelayHandler(c, relayMode)
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
2024-08-03 17:46:13 +08:00
common.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
}
c.JSON(taskErr.StatusCode, taskErr)
}
}
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError
switch relayMode {
2025-07-22 17:36:38 +08:00
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
err = relay.RelayTaskFetch(c, relayMode)
default:
err = relay.RelayTaskSubmit(c, relayMode)
}
return err
}
// shouldRetryTaskRelay reports whether a failed task attempt should be retried on
// another channel. channelId is currently unused.
//
// Checks are applied in this order:
//   - a nil error never retries;
//   - nothing retries once retryTimes is exhausted or a specific channel was requested.
//
// After that the status and error kind decide:
//   - 429 and 307 retry;
//   - 504/524 timeouts do not;
//   - other 5xx retry;
//   - 400 and 408 do not;
//   - local errors do not;
//   - 2xx does not;
//   - anything else retries.
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
	if taskErr == nil || retryTimes <= 0 {
		return false
	}
	if _, pinned := c.Get("specific_channel_id"); pinned {
		return false
	}

	code := taskErr.StatusCode
	switch {
	case code == http.StatusTooManyRequests, code == http.StatusTemporaryRedirect:
		return true
	case code == http.StatusGatewayTimeout, code == 524:
		// Timeouts are not retried.
		return false
	case code/100 == 5:
		return true
	case code == http.StatusBadRequest:
		return false
	case code == http.StatusRequestTimeout:
		// Azure processing timeout: not retried.
		return false
	case taskErr.LocalError:
		return false
	case code/100 == 2:
		return false
	}
	return true
}