Files
new-api/controller/relay.go

448 lines
13 KiB
Go
Raw Normal View History

2023-04-23 18:24:11 +08:00
package controller
import (
2024-04-04 16:35:44 +08:00
"bytes"
2024-08-03 17:12:16 +08:00
"errors"
2023-04-23 18:24:11 +08:00
"fmt"
2024-02-29 01:08:18 +08:00
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
2024-04-04 16:35:44 +08:00
"io"
2023-09-09 03:19:55 +08:00
"log"
2023-04-23 18:24:11 +08:00
"net/http"
"one-api/common"
2024-02-29 01:08:18 +08:00
"one-api/dto"
2024-04-04 16:35:44 +08:00
"one-api/middleware"
"one-api/model"
2024-02-29 01:08:18 +08:00
"one-api/relay"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/service"
2024-04-20 17:18:14 +08:00
"strings"
2023-12-27 16:32:54 +08:00
)
2024-04-04 16:35:44 +08:00
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
2024-02-29 01:08:18 +08:00
var err *dto.OpenAIErrorWithStatusCode
2023-06-19 10:28:55 +08:00
switch relayMode {
2024-02-29 01:08:18 +08:00
case relayconstant.RelayModeImagesGenerations:
2024-07-17 23:50:37 +08:00
err = relay.ImageHelper(c, relayMode)
2024-02-29 01:08:18 +08:00
case relayconstant.RelayModeAudioSpeech:
2024-01-09 15:46:45 +08:00
fallthrough
2024-02-29 01:08:18 +08:00
case relayconstant.RelayModeAudioTranslation:
2024-01-09 15:46:45 +08:00
fallthrough
2024-02-29 01:08:18 +08:00
case relayconstant.RelayModeAudioTranscription:
2024-07-17 23:50:37 +08:00
err = relay.AudioHelper(c)
2024-07-06 17:09:22 +08:00
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
2023-06-19 10:28:55 +08:00
default:
2024-02-29 01:08:18 +08:00
err = relay.TextHelper(c)
}
2024-04-04 16:35:44 +08:00
return err
}
// wsHandler is the WebSocket counterpart of relayHandler. No relay mode has
// dedicated handling here: every mode is served by relay.TextHelper, and
// neither ws nor relayMode is consulted.
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
	return relay.TextHelper(c)
}
2024-09-26 00:59:09 +08:00
func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode
defer func() {
if openaiErr != nil {
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}()
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
return
}
2024-09-26 00:59:09 +08:00
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
return
}
if playgroundRequest.Model == "" {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
return
}
c.Set("original_model", playgroundRequest.Model)
group := playgroundRequest.Group
2024-10-10 13:39:09 +08:00
userGroup := c.GetString("group")
2024-09-26 00:59:09 +08:00
if group == "" {
2024-10-10 13:39:09 +08:00
group = userGroup
2024-09-26 00:59:09 +08:00
} else {
2024-10-10 13:39:09 +08:00
if !common.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
}
2024-09-26 00:59:09 +08:00
c.Set("group", group)
}
2024-10-10 13:34:29 +08:00
c.Set("token_name", "playground-"+group)
2024-09-26 00:59:09 +08:00
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
2024-09-27 20:18:03 +08:00
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
2024-09-26 00:59:09 +08:00
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
Relay(c)
}
2024-04-04 16:35:44 +08:00
func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
2024-08-03 16:55:29 +08:00
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
2024-04-04 16:35:44 +08:00
if err != nil {
2024-08-03 17:12:16 +08:00
common.LogError(c, err.Error())
2024-08-03 17:07:14 +08:00
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
2024-08-03 17:12:16 +08:00
break
2023-07-15 19:06:51 +08:00
}
2024-04-04 16:35:44 +08:00
2024-08-03 16:55:29 +08:00
openaiErr = relayRequest(c, relayMode, channel)
if openaiErr == nil {
return // 成功处理请求,直接返回
}
2024-08-03 17:32:28 +08:00
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
2024-08-03 16:55:29 +08:00
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
2024-04-04 16:35:44 +08:00
}
2024-05-16 16:41:08 +08:00
useChannel := c.GetStringSlice("use_channel")
2024-04-20 17:18:14 +08:00
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
2024-08-03 17:12:16 +08:00
common.LogInfo(c, retryLogStr)
2024-04-20 17:18:14 +08:00
}
2024-04-04 16:35:44 +08:00
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
2024-07-27 17:55:36 +08:00
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
2024-04-04 16:35:44 +08:00
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}
// upgrader performs the HTTP-to-WebSocket handshake. Its CheckOrigin accepts
// every request, so cross-origin connections are allowed.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(*http.Request) bool { return true },
}
func WssRelay(c *gin.Context) {
// 将 HTTP 连接升级为 WebSocket 连接
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
defer ws.Close()
if err != nil {
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
service.WssError(c, ws, openaiErr.Error)
return
}
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
openaiErr = wssRequest(c, ws, relayMode, channel)
if openaiErr == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
service.WssError(c, ws, openaiErr.Error)
}
2024-04-04 16:35:44 +08:00
}
2024-08-03 16:55:29 +08:00
// relayRequest performs one relay attempt on the given channel: it records
// the channel id in the context's "use_channel" list, re-installs the request
// body so it can be read again, and dispatches through relayHandler.
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
	addUsedChannel(c, channel.Id)

	// The error from GetRequestBody is discarded; whatever bytes it returned
	// become the new body.
	body, _ := common.GetRequestBody(c)
	c.Request.Body = io.NopCloser(bytes.NewBuffer(body))

	return relayHandler(c, relayMode)
}
// wssRequest performs one WebSocket relay attempt on the given channel: it
// records the channel id in the context's "use_channel" list, re-installs the
// request body, and delegates to relay.WssHelper. relayMode is not consulted.
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
	addUsedChannel(c, channel.Id)

	// The error from GetRequestBody is discarded; whatever bytes it returned
	// become the new body.
	body, _ := common.GetRequestBody(c)
	c.Request.Body = io.NopCloser(bytes.NewBuffer(body))

	return relay.WssHelper(c, ws)
}
2024-08-03 16:55:29 +08:00
// addUsedChannel appends channelId, formatted as a decimal string, to the
// "use_channel" string slice stored on the context.
func addUsedChannel(c *gin.Context, channelId int) {
	tried := append(c.GetStringSlice("use_channel"), fmt.Sprintf("%d", channelId))
	c.Set("use_channel", tried)
}
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
if retryCount == 0 {
2024-08-03 17:32:28 +08:00
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
if !autoBan {
autoBanInt = 0
}
2024-08-03 16:55:29 +08:00
return &model.Channel{
2024-08-03 17:32:28 +08:00
Id: c.GetInt("channel_id"),
Type: c.GetInt("channel_type"),
Name: c.GetString("channel_name"),
AutoBan: &autoBanInt,
2024-08-03 16:55:29 +08:00
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
if err != nil {
2024-08-03 17:12:16 +08:00
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
2024-08-03 16:55:29 +08:00
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
2024-04-04 16:35:44 +08:00
if openaiErr == nil {
return false
}
2024-08-24 17:27:14 +08:00
if openaiErr.LocalError {
return false
}
2024-04-04 16:35:44 +08:00
if retryTimes <= 0 {
return false
}
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if openaiErr.StatusCode == http.StatusTooManyRequests {
return true
}
2024-04-08 14:10:09 +08:00
if openaiErr.StatusCode == 307 {
return true
}
2024-04-04 16:35:44 +08:00
if openaiErr.StatusCode/100 == 5 {
2024-04-04 21:21:44 +08:00
// 超时不重试
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
return false
}
2024-04-04 16:35:44 +08:00
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
2024-08-03 16:55:29 +08:00
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
return true
}
2024-04-04 16:35:44 +08:00
return false
}
2024-04-23 22:17:36 +08:00
if openaiErr.StatusCode == 408 {
// azure处理超时不重试
return false
}
2024-04-04 16:35:44 +08:00
if openaiErr.StatusCode/100 == 2 {
return false
}
return true
}
2024-08-03 17:32:28 +08:00
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
2024-08-03 17:46:13 +08:00
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
2024-07-04 22:46:33 +08:00
if service.ShouldDisableChannel(channelType, err) && autoBan {
2024-04-04 16:35:44 +08:00
service.DisableChannel(channelId, channelName, err.Error.Message)
}
}
2023-08-14 22:16:32 +08:00
func RelayMidjourney(c *gin.Context) {
relayMode := c.GetInt("relay_mode")
2024-02-29 01:08:18 +08:00
var err *dto.MidjourneyResponse
2023-08-14 22:16:32 +08:00
switch relayMode {
2024-02-29 01:08:18 +08:00
case relayconstant.RelayModeMidjourneyNotify:
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed:
err = relay.RelayMidjourneyTaskImageSeed(c)
2024-03-14 18:08:12 +08:00
case relayconstant.RelayModeSwapFace:
err = relay.RelaySwapFace(c)
2023-08-14 22:16:32 +08:00
default:
2024-02-29 01:08:18 +08:00
err = relay.RelayMidjourneySubmit(c, relayMode)
2023-08-14 22:16:32 +08:00
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
statusCode := http.StatusBadRequest
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
statusCode = http.StatusTooManyRequests
2023-08-14 22:16:32 +08:00
}
c.JSON(statusCode, gin.H{
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
"type": "upstream_error",
"code": err.Code,
})
2023-08-14 22:16:32 +08:00
channelId := c.GetInt("channel_id")
2024-04-25 20:47:18 +08:00
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
2023-08-14 22:16:32 +08:00
}
}
func RelayNotImplemented(c *gin.Context) {
2024-02-29 01:08:18 +08:00
err := dto.OpenAIError{
Message: "API not implemented",
2023-12-01 01:29:13 +08:00
Type: "new_api_error",
Param: "",
Code: "api_not_implemented",
}
2023-06-23 22:59:44 +08:00
c.JSON(http.StatusNotImplemented, gin.H{
"error": err,
})
}
func RelayNotFound(c *gin.Context) {
2024-02-29 01:08:18 +08:00
err := dto.OpenAIError{
2023-08-11 19:53:01 +08:00
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",
2023-08-11 19:53:01 +08:00
Code: "",
}
2023-06-23 22:59:44 +08:00
c.JSON(http.StatusNotFound, gin.H{
"error": err,
})
}
func RelayTask(c *gin.Context) {
retryTimes := common.RetryTimes
channelId := c.GetInt("channel_id")
relayMode := c.GetInt("relay_mode")
group := c.GetString("group")
originalModel := c.GetString("original_model")
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
taskErr := taskRelayHandler(c, relayMode)
if taskErr == nil {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
2024-08-03 17:46:13 +08:00
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
2024-08-03 17:46:13 +08:00
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
taskErr = taskRelayHandler(c, relayMode)
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
2024-08-03 17:46:13 +08:00
common.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
}
c.JSON(taskErr.StatusCode, taskErr)
}
}
// taskRelayHandler dispatches a task request by relay mode: the two Suno
// fetch modes go to relay.RelayTaskFetch, everything else is treated as a
// submission and goes to relay.RelayTaskSubmit.
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
	if relayMode == relayconstant.RelayModeSunoFetch || relayMode == relayconstant.RelayModeSunoFetchByID {
		return relay.RelayTaskFetch(c, relayMode)
	}
	return relay.RelayTaskSubmit(c, relayMode)
}
// shouldRetryTaskRelay reports whether a failed task relay attempt should be
// retried on another channel. channelId is accepted but not consulted.
//
// It never retries when there is no error, retryTimes is not positive, or the
// context carries "specific_channel_id". Otherwise the checks run in this
// order:
//   - 429 and 307: retry
//   - 5xx: retry, except 504 and 524 (timeouts)
//   - 400 and 408: do not retry
//   - error marked LocalError: do not retry
//   - 2xx: do not retry
//   - anything else: retry
//
// LocalError is only consulted after the 429/307/5xx checks, so a local error
// carrying one of those statuses is still retried.
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
	if taskErr == nil || retryTimes <= 0 {
		return false
	}
	if _, pinned := c.Get("specific_channel_id"); pinned {
		return false
	}

	status := taskErr.StatusCode
	switch {
	case status == http.StatusTooManyRequests, status == 307:
		return true
	case status/100 == 5:
		// Timeouts are not retried.
		return status != 504 && status != 524
	case status == http.StatusBadRequest:
		return false
	case status == 408:
		// Azure processing timeout: not retried.
		return false
	case taskErr.LocalError:
		return false
	case status/100 == 2:
		return false
	default:
		return true
	}
}