2023-06-19 10:28:55 +08:00
package controller
import (
"bytes"
2023-09-17 15:39:46 +08:00
"context"
2023-06-19 10:28:55 +08:00
"encoding/json"
2023-06-25 10:25:33 +08:00
"errors"
2023-06-19 10:28:55 +08:00
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
2023-08-12 23:41:44 +08:00
"time"
2023-10-22 17:56:20 +08:00
"github.com/gin-gonic/gin"
2023-06-19 10:28:55 +08:00
)
2023-07-22 16:18:03 +08:00
// API dialect spoken by an upstream channel. The relay converts the incoming
// OpenAI-format request into the matching native request format (and converts
// the response back) before forwarding. Order is significant: values are
// assigned by iota and must not be reordered.
const (
	APITypeOpenAI = iota // OpenAI and OpenAI-compatible endpoints (incl. Azure)
	APITypeClaude        // Anthropic Claude
	APITypePaLM          // Google PaLM 2
	APITypeBaidu         // Baidu ERNIE (wenxinworkshop)
	APITypeZhipu         // Zhipu ChatGLM
	APITypeAli           // Alibaba DashScope
	APITypeXunfei        // iFlytek Spark (uses websocket, not HTTP)
	APITypeAIProxyLibrary
	APITypeTencent // Tencent Hunyuan
	APITypeGemini  // Google Gemini
)
2023-07-23 15:18:58 +08:00
// Shared HTTP clients used for relaying requests upstream.
var (
	// httpClient performs the relay calls; its timeout is configured in init
	// from the RelayTimeout setting.
	httpClient *http.Client
	// impatientHTTPClient is for short control-plane calls that must fail fast.
	impatientHTTPClient *http.Client
)
2023-07-23 15:18:58 +08:00
func init ( ) {
2023-10-22 20:39:49 +08:00
if common . RelayTimeout == 0 {
httpClient = & http . Client { }
} else {
httpClient = & http . Client {
Timeout : time . Duration ( common . RelayTimeout ) * time . Second ,
}
}
2023-08-12 23:41:44 +08:00
impatientHTTPClient = & http . Client {
Timeout : 5 * time . Second ,
}
2023-07-23 15:18:58 +08:00
}
2023-06-19 15:00:22 +08:00
// relayTextHelper relays a text-mode request (chat completions, completions,
// embeddings, moderations, edits) to the upstream provider selected by the
// routing middleware. It validates the request, applies the channel's model
// name mapping, builds the provider-specific URL/body/headers, pre-consumes
// quota, forwards the request, and finally settles the real quota consumption
// from the reported usage. Returns nil on success, otherwise an OpenAI-style
// error carrying the HTTP status to send back to the client.
func relayTextHelper ( c * gin . Context , relayMode int ) * OpenAIErrorWithStatusCode {
2023-06-19 10:28:55 +08:00
// Values below are injected into the gin context by earlier middleware
// (auth / channel distributor).
channelType := c . GetInt ( "channel" )
2023-09-17 19:18:16 +08:00
channelId := c . GetInt ( "channel_id" )
2023-06-19 10:28:55 +08:00
tokenId := c . GetInt ( "token_id" )
2023-06-21 17:26:26 +08:00
userId := c . GetInt ( "id" )
2023-06-19 10:28:55 +08:00
group := c . GetString ( "group" )
2024-01-11 14:12:38 +08:00
tokenUnlimited := c . GetBool ( "token_unlimited_quota" )
2023-11-23 03:06:04 +08:00
startTime := time . Now ( )
2023-06-19 10:28:55 +08:00
var textRequest GeneralOpenAIRequest
2023-11-23 03:06:04 +08:00
// Decode the JSON body while keeping it re-readable for the later
// passthrough case (see requestBody selection below).
err := common . UnmarshalBodyReusable ( c , & textRequest )
if err != nil {
return errorWrapper ( err , "bind_request_body_failed" , http . StatusBadRequest )
2023-06-19 10:28:55 +08:00
}
2023-06-25 11:46:23 +08:00
// Fill in default model names for endpoints that allow omitting them.
if relayMode == RelayModeModerations && textRequest . Model == "" {
2023-06-19 10:28:55 +08:00
textRequest . Model = "text-moderation-latest"
}
2023-07-15 12:03:23 +08:00
if relayMode == RelayModeEmbeddings && textRequest . Model == "" {
textRequest . Model = c . Param ( "model" )
}
2023-06-25 10:25:33 +08:00
// request validation
if textRequest . Model == "" {
return errorWrapper ( errors . New ( "model is required" ) , "required_field_missing" , http . StatusBadRequest )
}
switch relayMode {
case RelayModeCompletions :
if textRequest . Prompt == "" {
2023-06-25 11:46:23 +08:00
return errorWrapper ( errors . New ( "field prompt is required" ) , "required_field_missing" , http . StatusBadRequest )
2023-06-25 10:25:33 +08:00
}
case RelayModeChatCompletions :
2023-06-25 11:46:23 +08:00
if textRequest . Messages == nil || len ( textRequest . Messages ) == 0 {
return errorWrapper ( errors . New ( "field messages is required" ) , "required_field_missing" , http . StatusBadRequest )
2023-06-25 10:25:33 +08:00
}
case RelayModeEmbeddings :
2023-06-25 11:46:23 +08:00
case RelayModeModerations :
2023-06-25 10:25:33 +08:00
if textRequest . Input == "" {
2023-06-25 11:46:23 +08:00
return errorWrapper ( errors . New ( "field input is required" ) , "required_field_missing" , http . StatusBadRequest )
}
case RelayModeEdits :
if textRequest . Instruction == "" {
return errorWrapper ( errors . New ( "field instruction is required" ) , "required_field_missing" , http . StatusBadRequest )
2023-06-25 10:25:33 +08:00
}
}
2023-06-27 13:42:45 +08:00
// map model name
// The channel may define a JSON object mapping requested model names to the
// names actually sent upstream. When a mapping applies, the request struct is
// mutated and must later be re-marshaled instead of streaming the original
// body (tracked via isModelMapped).
modelMapping := c . GetString ( "model_mapping" )
isModelMapped := false
2023-07-29 19:17:26 +08:00
if modelMapping != "" && modelMapping != "{}" {
2023-06-27 13:42:45 +08:00
modelMap := make ( map [ string ] string )
err := json . Unmarshal ( [ ] byte ( modelMapping ) , & modelMap )
if err != nil {
return errorWrapper ( err , "unmarshal_model_mapping_failed" , http . StatusInternalServerError )
}
if modelMap [ textRequest . Model ] != "" {
textRequest . Model = modelMap [ textRequest . Model ]
isModelMapped = true
}
}
2023-07-22 16:18:03 +08:00
// Translate the channel type into the API dialect used for request/response
// conversion below; anything not listed speaks the OpenAI protocol.
apiType := APITypeOpenAI
2023-07-24 23:34:14 +08:00
switch channelType {
case common . ChannelTypeAnthropic :
2023-07-22 16:18:03 +08:00
apiType = APITypeClaude
2023-07-24 23:34:14 +08:00
case common . ChannelTypeBaidu :
2023-07-22 23:24:09 +08:00
apiType = APITypeBaidu
2023-07-24 23:34:14 +08:00
case common . ChannelTypePaLM :
2023-07-23 00:32:47 +08:00
apiType = APITypePaLM
2023-07-24 23:34:14 +08:00
case common . ChannelTypeZhipu :
2023-07-23 11:51:44 +08:00
apiType = APITypeZhipu
2023-07-28 23:45:08 +08:00
case common . ChannelTypeAli :
apiType = APITypeAli
2023-07-29 21:55:57 +08:00
case common . ChannelTypeXunfei :
apiType = APITypeXunfei
2023-09-03 12:51:59 +08:00
case common . ChannelTypeAIProxyLibrary :
apiType = APITypeAIProxyLibrary
2023-10-03 14:19:03 +08:00
case common . ChannelTypeTencent :
apiType = APITypeTencent
2023-12-18 23:45:08 +08:00
case common . ChannelTypeGemini :
apiType = APITypeGemini
2023-07-22 16:18:03 +08:00
}
2023-06-19 10:28:55 +08:00
// Build the upstream URL: start from the channel type's default base URL,
// allow a per-channel override, then apply per-dialect URL rules.
baseURL := common . ChannelBaseURLs [ channelType ]
requestURL := c . Request . URL . String ( )
2023-06-20 22:32:56 +08:00
if c . GetString ( "base_url" ) != "" {
2023-06-19 10:28:55 +08:00
baseURL = c . GetString ( "base_url" )
}
2023-10-22 17:50:52 +08:00
fullRequestURL := getFullRequestURL ( baseURL , requestURL , channelType )
2023-07-22 16:18:03 +08:00
switch apiType {
case APITypeOpenAI :
if channelType == common . ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
query := c . Request . URL . Query ( )
apiVersion := query . Get ( "api-version" )
if apiVersion == "" {
apiVersion = c . GetString ( "api_version" )
}
// requestURL here deliberately shadows the outer variable: Azure routing
// needs the path without the client's query string.
requestURL := strings . Split ( requestURL , "?" ) [ 0 ]
requestURL = fmt . Sprintf ( "%s?api-version=%s" , requestURL , apiVersion )
baseURL = c . GetString ( "base_url" )
task := strings . TrimPrefix ( requestURL , "/v1/" )
// Azure deployment names cannot contain dots, and the dated model
// variants map to the base deployment.
model_ := textRequest . Model
model_ = strings . Replace ( model_ , "." , "" , - 1 )
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings . TrimSuffix ( model_ , "-0301" )
model_ = strings . TrimSuffix ( model_ , "-0314" )
model_ = strings . TrimSuffix ( model_ , "-0613" )
fullRequestURL = fmt . Sprintf ( "%s/openai/deployments/%s/%s" , baseURL , model_ , task )
}
case APITypeClaude :
fullRequestURL = "https://api.anthropic.com/v1/complete"
if baseURL != "" {
fullRequestURL = fmt . Sprintf ( "%s/v1/complete" , baseURL )
2023-06-19 10:28:55 +08:00
}
2023-07-22 23:24:09 +08:00
case APITypeBaidu :
// Baidu uses a distinct endpoint per model rather than a model field.
switch textRequest . Model {
case "ERNIE-Bot" :
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo" :
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
2023-10-22 18:48:35 +08:00
case "ERNIE-Bot-4" :
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
2023-07-22 23:24:09 +08:00
case "BLOOMZ-7B" :
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
2023-07-29 12:15:07 +08:00
case "Embedding-V1" :
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
2023-07-22 23:24:09 +08:00
}
// The stored credential is exchanged for a short-lived access token that
// Baidu expects as a query parameter.
apiKey := c . Request . Header . Get ( "Authorization" )
apiKey = strings . TrimPrefix ( apiKey , "Bearer " )
2023-08-12 23:41:44 +08:00
var err error
if apiKey , err = getBaiduAccessToken ( apiKey ) ; err != nil {
return errorWrapper ( err , "invalid_baidu_config" , http . StatusInternalServerError )
}
fullRequestURL += "?access_token=" + apiKey
2023-07-23 00:32:47 +08:00
case APITypePaLM :
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
2023-07-24 22:37:57 +08:00
if baseURL != "" {
fullRequestURL = fmt . Sprintf ( "%s/v1beta2/models/chat-bison-001:generateMessage" , baseURL )
}
2023-07-23 00:32:47 +08:00
// PaLM authenticates with the key as a query parameter, not a header.
apiKey := c . Request . Header . Get ( "Authorization" )
apiKey = strings . TrimPrefix ( apiKey , "Bearer " )
fullRequestURL += "?key=" + apiKey
2023-12-18 23:45:08 +08:00
case APITypeGemini :
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
requestBaseURL = baseURL
}
2023-12-19 12:53:56 +08:00
version := "v1beta"
2023-12-18 23:45:08 +08:00
if c . GetString ( "api_version" ) != "" {
version = c . GetString ( "api_version" )
}
// Streaming uses a different RPC action on the same model resource.
action := "generateContent"
if textRequest . Stream {
action = "streamGenerateContent"
}
fullRequestURL = fmt . Sprintf ( "%s/%s/models/%s:%s" , requestBaseURL , version , textRequest . Model , action )
// Gemini also authenticates via a key query parameter.
apiKey := c . Request . Header . Get ( "Authorization" )
apiKey = strings . TrimPrefix ( apiKey , "Bearer " )
fullRequestURL += "?key=" + apiKey
2023-12-19 12:53:56 +08:00
//log.Println(fullRequestURL)
2023-07-23 11:51:44 +08:00
case APITypeZhipu :
// Zhipu selects streaming vs blocking via the URL method segment.
method := "invoke"
if textRequest . Stream {
method = "sse-invoke"
}
fullRequestURL = fmt . Sprintf ( "https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s" , textRequest . Model , method )
2023-07-28 23:45:08 +08:00
case APITypeAli :
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
2023-09-03 22:12:35 +08:00
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
2023-10-03 14:19:03 +08:00
case APITypeTencent :
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
2023-09-03 12:51:59 +08:00
case APITypeAIProxyLibrary :
fullRequestURL = fmt . Sprintf ( "%s/api/library/ask" , baseURL )
2023-06-19 10:28:55 +08:00
}
// Estimate prompt tokens up front; used for pre-consumption and for stream
// responses where the upstream does not report usage.
var promptTokens int
2023-06-24 15:28:11 +08:00
var completionTokens int
2023-06-19 10:28:55 +08:00
switch relayMode {
case RelayModeChatCompletions :
2023-11-17 18:24:37 +08:00
promptTokens , err = countTokenMessages ( textRequest . Messages , textRequest . Model )
if err != nil {
return errorWrapper ( err , "count_token_messages_failed" , http . StatusInternalServerError )
}
2023-06-19 10:28:55 +08:00
case RelayModeCompletions :
promptTokens = countTokenInput ( textRequest . Prompt , textRequest . Model )
2023-06-25 11:46:23 +08:00
case RelayModeModerations :
2023-06-19 10:28:55 +08:00
promptTokens = countTokenInput ( textRequest . Input , textRequest . Model )
}
2023-12-21 20:14:04 +08:00
// modelPrice is a flat per-call price; -1 means the model is billed by
// tokens using modelRatio * groupRatio instead.
modelPrice := common . GetModelPrice ( textRequest . Model )
2023-06-19 10:28:55 +08:00
groupRatio := common . GetGroupRatio ( group )
2023-12-21 20:14:04 +08:00
var preConsumedQuota int
var ratio float64
var modelRatio float64
if modelPrice == - 1 {
preConsumedTokens := common . PreConsumedQuota
if textRequest . MaxTokens != 0 {
preConsumedTokens = promptTokens + int ( textRequest . MaxTokens )
}
modelRatio = common . GetModelRatio ( textRequest . Model )
ratio = modelRatio * groupRatio
preConsumedQuota = int ( float64 ( preConsumedTokens ) * ratio )
} else {
preConsumedQuota = int ( modelPrice * common . QuotaPerUnit * groupRatio )
}
2023-06-21 17:26:26 +08:00
userQuota , err := model . CacheGetUserQuota ( userId )
if err != nil {
2023-06-23 22:59:44 +08:00
return errorWrapper ( err , "get_user_quota_failed" , http . StatusInternalServerError )
2023-06-21 17:26:26 +08:00
}
2023-12-05 17:11:37 +08:00
if userQuota < 0 || userQuota - preConsumedQuota < 0 {
2023-10-01 12:49:40 +08:00
return errorWrapper ( errors . New ( "user quota is not enough" ) , "insufficient_user_quota" , http . StatusForbidden )
}
2023-08-16 23:40:24 +08:00
err = model . CacheDecreaseUserQuota ( userId , preConsumedQuota )
if err != nil {
return errorWrapper ( err , "decrease_user_quota_failed" , http . StatusInternalServerError )
}
// Trusted-balance optimization: when the user (and token) balance greatly
// exceeds the pre-consumed amount, skip the pre-consume DB write entirely
// and settle everything in the post-consume step.
if userQuota > 100 * preConsumedQuota {
2024-01-10 14:23:23 +08:00
// user quota is plentiful; next check whether the token quota is too
2024-01-11 14:12:38 +08:00
if ! tokenUnlimited {
2024-01-10 14:23:23 +08:00
// not an unlimited token: check whether the token quota is sufficient
2024-01-11 14:12:38 +08:00
tokenQuota := c . GetInt ( "token_quota" )
2024-01-10 14:23:23 +08:00
if tokenQuota > 100 * preConsumedQuota {
// token quota is sufficient: trust it and skip pre-consumption
preConsumedQuota = 0
common . LogInfo ( c . Request . Context ( ) , fmt . Sprintf ( "user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume" , userId , userQuota , tokenId , tokenQuota ) )
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common . LogInfo ( c . Request . Context ( ) , fmt . Sprintf ( "user %d with unlimited token has enough quota %d, trusted and no need to pre-consume" , userId , userQuota ) )
}
2023-06-21 17:26:26 +08:00
}
2023-11-23 02:56:18 +08:00
if preConsumedQuota > 0 {
2023-11-12 18:51:26 +08:00
userQuota , err = model . PreConsumeTokenQuota ( tokenId , preConsumedQuota )
2023-06-19 10:28:55 +08:00
if err != nil {
2023-06-23 22:59:44 +08:00
return errorWrapper ( err , "pre_consume_token_quota_failed" , http . StatusForbidden )
2023-06-19 10:28:55 +08:00
}
}
2023-06-27 13:42:45 +08:00
// Choose the outgoing request body: re-marshal when the model name was
// remapped, otherwise stream the client's original body straight through.
var requestBody io . Reader
if isModelMapped {
jsonStr , err := json . Marshal ( textRequest )
if err != nil {
return errorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
}
requestBody = bytes . NewBuffer ( jsonStr )
} else {
requestBody = c . Request . Body
}
2023-07-22 16:18:03 +08:00
// For non-OpenAI dialects, replace the body with the converted native
// request regardless of the passthrough choice above.
switch apiType {
case APITypeClaude :
2023-07-22 17:12:13 +08:00
claudeRequest := requestOpenAI2Claude ( textRequest )
2023-07-22 16:18:03 +08:00
jsonStr , err := json . Marshal ( claudeRequest )
if err != nil {
return errorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
}
requestBody = bytes . NewBuffer ( jsonStr )
2023-07-22 23:24:09 +08:00
case APITypeBaidu :
2023-07-29 12:15:07 +08:00
var jsonData [ ] byte
var err error
switch relayMode {
case RelayModeEmbeddings :
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu ( textRequest )
jsonData , err = json . Marshal ( baiduEmbeddingRequest )
default :
baiduRequest := requestOpenAI2Baidu ( textRequest )
jsonData , err = json . Marshal ( baiduRequest )
}
2023-07-22 23:24:09 +08:00
if err != nil {
return errorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
}
2023-07-29 12:15:07 +08:00
requestBody = bytes . NewBuffer ( jsonData )
2023-07-23 00:32:47 +08:00
case APITypePaLM :
palmRequest := requestOpenAI2PaLM ( textRequest )
jsonStr , err := json . Marshal ( palmRequest )
if err != nil {
return errorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
}
requestBody = bytes . NewBuffer ( jsonStr )
2023-12-18 23:45:08 +08:00
case APITypeGemini :
geminiChatRequest := requestOpenAI2Gemini ( textRequest )
jsonStr , err := json . Marshal ( geminiChatRequest )
if err != nil {
return errorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
}
requestBody = bytes . NewBuffer ( jsonStr )
2023-07-23 11:51:44 +08:00
case APITypeZhipu :
zhipuRequest := requestOpenAI2Zhipu ( textRequest )
jsonStr , err := json . Marshal ( zhipuRequest )
if err != nil {
return errorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
}
requestBody = bytes . NewBuffer ( jsonStr )
2023-07-28 23:45:08 +08:00
case APITypeAli :
2023-09-03 22:12:35 +08:00
var jsonStr [ ] byte
var err error
switch relayMode {
case RelayModeEmbeddings :
aliEmbeddingRequest := embeddingRequestOpenAI2Ali ( textRequest )
jsonStr , err = json . Marshal ( aliEmbeddingRequest )
default :
aliRequest := requestOpenAI2Ali ( textRequest )
jsonStr , err = json . Marshal ( aliRequest )
}
2023-07-28 23:45:08 +08:00
if err != nil {
return errorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
}
requestBody = bytes . NewBuffer ( jsonStr )
2023-10-03 14:19:03 +08:00
case APITypeTencent :
// Tencent stores "appId|secretId|secretKey" in the key slot; the request
// must be signed and the signature sent as the Authorization header.
apiKey := c . Request . Header . Get ( "Authorization" )
apiKey = strings . TrimPrefix ( apiKey , "Bearer " )
appId , secretId , secretKey , err := parseTencentConfig ( apiKey )
if err != nil {
return errorWrapper ( err , "invalid_tencent_config" , http . StatusInternalServerError )
}
tencentRequest := requestOpenAI2Tencent ( textRequest )
tencentRequest . AppId = appId
tencentRequest . SecretId = secretId
jsonStr , err := json . Marshal ( tencentRequest )
if err != nil {
return errorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
}
sign := getTencentSign ( * tencentRequest , secretKey )
// Overwrites the incoming header so the header-setting code below
// forwards the signature instead of the raw credential.
c . Request . Header . Set ( "Authorization" , sign )
requestBody = bytes . NewBuffer ( jsonStr )
2023-09-03 12:51:59 +08:00
case APITypeAIProxyLibrary :
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary ( textRequest )
aiProxyLibraryRequest . LibraryId = c . GetString ( "library_id" )
jsonStr , err := json . Marshal ( aiProxyLibraryRequest )
2023-07-28 23:45:08 +08:00
if err != nil {
return errorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
}
requestBody = bytes . NewBuffer ( jsonStr )
2023-07-22 16:18:03 +08:00
}
2023-07-29 21:55:57 +08:00
var req * http . Request
var resp * http . Response
isStream := textRequest . Stream
if apiType != APITypeXunfei { // cause xunfei use websocket
req , err = http . NewRequest ( c . Request . Method , fullRequestURL , requestBody )
2023-08-31 00:44:16 +08:00
// Set the GetBody function; it is meant to return a fresh io.ReadCloser
// yielding the same data as the original request body.
// NOTE(review): this closure returns the SAME requestBody reader, which is
// consumed after the first send — a transport retry would replay an empty
// body. A correct GetBody needs a buffered copy.
// NOTE(review): GetBody is assigned before the err check below; if
// NewRequest failed, req is nil and this assignment panics. The err check
// should come first.
req . GetBody = func ( ) ( io . ReadCloser , error ) {
return io . NopCloser ( requestBody ) , nil
}
2023-07-29 21:55:57 +08:00
if err != nil {
return errorWrapper ( err , "new_request_failed" , http . StatusInternalServerError )
2023-07-22 16:18:03 +08:00
}
2023-07-29 21:55:57 +08:00
// Per-dialect authentication headers.
apiKey := c . Request . Header . Get ( "Authorization" )
apiKey = strings . TrimPrefix ( apiKey , "Bearer " )
switch apiType {
case APITypeOpenAI :
if channelType == common . ChannelTypeAzure {
req . Header . Set ( "api-key" , apiKey )
} else {
req . Header . Set ( "Authorization" , c . Request . Header . Get ( "Authorization" ) )
2023-09-15 17:59:01 +08:00
if c . Request . Header . Get ( "OpenAI-Organization" ) != "" {
req . Header . Set ( "OpenAI-Organization" , c . Request . Header . Get ( "OpenAI-Organization" ) )
}
2023-08-27 16:16:45 +08:00
if channelType == common . ChannelTypeOpenRouter {
req . Header . Set ( "HTTP-Referer" , "https://github.com/songquanpeng/one-api" )
req . Header . Set ( "X-Title" , "One API" )
}
2023-07-29 21:55:57 +08:00
}
case APITypeClaude :
req . Header . Set ( "x-api-key" , apiKey )
anthropicVersion := c . Request . Header . Get ( "anthropic-version" )
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req . Header . Set ( "anthropic-version" , anthropicVersion )
case APITypeZhipu :
token := getZhipuToken ( apiKey )
req . Header . Set ( "Authorization" , token )
case APITypeAli :
req . Header . Set ( "Authorization" , "Bearer " + apiKey )
if textRequest . Stream {
req . Header . Set ( "X-DashScope-SSE" , "enable" )
}
2023-10-03 14:19:03 +08:00
case APITypeTencent :
// apiKey is unused here; the header was replaced with the signature above.
req . Header . Set ( "Authorization" , apiKey )
2023-12-21 23:08:09 +08:00
case APITypeGemini :
req . Header . Set ( "Content-Type" , "application/json" )
2023-09-03 12:51:59 +08:00
default :
req . Header . Set ( "Authorization" , "Bearer " + apiKey )
2023-07-22 16:18:03 +08:00
}
2023-12-21 23:08:09 +08:00
if apiType != APITypeGemini {
// set common headers (Gemini only needs the Content-Type set above)
req . Header . Set ( "Content-Type" , c . Request . Header . Get ( "Content-Type" ) )
req . Header . Set ( "Accept" , c . Request . Header . Get ( "Accept" ) )
if isStream && c . Request . Header . Get ( "Accept" ) == "" {
req . Header . Set ( "Accept" , "text/event-stream" )
}
2023-10-22 17:56:20 +08:00
}
2023-10-31 00:03:22 +08:00
//req.HeaderBar.Set("Connection", c.Request.HeaderBar.Get("Connection"))
2023-07-29 21:55:57 +08:00
resp , err = httpClient . Do ( req )
if err != nil {
return errorWrapper ( err , "do_request_failed" , http . StatusInternalServerError )
2023-07-28 23:45:08 +08:00
}
2023-07-29 21:55:57 +08:00
// Close both the outgoing and incoming request bodies now that the
// upstream call has been made.
err = req . Body . Close ( )
if err != nil {
return errorWrapper ( err , "close_request_body_failed" , http . StatusInternalServerError )
}
err = c . Request . Body . Close ( )
if err != nil {
return errorWrapper ( err , "close_request_body_failed" , http . StatusInternalServerError )
}
2023-08-06 18:09:00 +08:00
// The upstream may decide to stream even if the client did not ask for it.
isStream = isStream || strings . HasPrefix ( resp . Header . Get ( "Content-Type" ) , "text/event-stream" )
2023-07-29 21:55:57 +08:00
2023-08-20 22:07:50 +08:00
if resp . StatusCode != http . StatusOK {
2023-09-13 22:05:10 +08:00
// Upstream failed: refund any pre-consumed quota asynchronously and
// relay the upstream error.
if preConsumedQuota != 0 {
2023-09-17 15:39:46 +08:00
go func ( ctx context . Context ) {
2023-09-13 22:05:10 +08:00
// return pre-consumed quota
2023-11-15 18:27:13 +08:00
err := model . PostConsumeTokenQuota ( tokenId , userQuota , - preConsumedQuota , 0 , false )
2023-09-13 22:05:10 +08:00
if err != nil {
2023-09-17 15:39:46 +08:00
common . LogError ( ctx , "error return pre-consumed quota: " + err . Error ( ) )
2023-09-13 22:05:10 +08:00
}
2023-09-17 15:39:46 +08:00
} ( c . Request . Context ( ) )
2023-09-13 22:05:10 +08:00
}
2023-09-09 01:50:41 +08:00
return relayErrorHandler ( resp )
2023-07-22 16:18:03 +08:00
}
2023-06-19 10:28:55 +08:00
}
2023-08-19 17:58:45 +08:00
2023-06-19 10:28:55 +08:00
var textResponse TextResponse
2023-08-12 19:36:31 +08:00
tokenName := c . GetString ( "token_name" )
2023-06-19 10:28:55 +08:00
2023-09-17 15:39:46 +08:00
// Deferred billing settlement: runs after the response handlers below have
// filled textResponse.Usage. The inner goroutine charges the difference
// between the real cost and the pre-consumed amount, updates caches, and
// records the consumption log. The context is captured by value because the
// goroutine may outlive the request handler.
defer func ( ctx context . Context ) {
2023-08-12 19:36:31 +08:00
// c.Writer.Flush()
go func ( ) {
2023-11-23 02:56:18 +08:00
promptTokens = textResponse . Usage . PromptTokens
completionTokens = textResponse . Usage . CompletionTokens
2023-07-29 22:32:05 +08:00
2023-12-21 20:14:04 +08:00
quota := 0
if modelPrice == - 1 {
// Token-based billing: completion tokens are weighted by the model's
// completion ratio; charge at least 1 when anything was consumed.
completionRatio := common . GetCompletionRatio ( textRequest . Model )
quota = promptTokens + int ( float64 ( completionTokens ) * completionRatio )
quota = int ( float64 ( quota ) * ratio )
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int ( modelPrice * common . QuotaPerUnit * groupRatio )
2023-11-23 02:56:18 +08:00
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model . PostConsumeTokenQuota ( tokenId , userQuota , quotaDelta , preConsumedQuota , true )
if err != nil {
common . LogError ( ctx , "error consuming token remain quota: " + err . Error ( ) )
}
err = model . CacheUpdateUserQuota ( userId )
if err != nil {
common . LogError ( ctx , "error update user quota cache: " + err . Error ( ) )
2023-06-25 09:59:58 +08:00
}
2023-12-21 20:20:09 +08:00
2023-11-23 02:56:18 +08:00
// record all the consume log even if quota is 0
2023-11-23 03:06:04 +08:00
useTimeSeconds := time . Now ( ) . Unix ( ) - startTime . Unix ( )
2023-12-21 20:14:04 +08:00
var logContent string
if modelPrice == - 1 {
logContent = fmt . Sprintf ( "模型倍率 %.2f,分组倍率 %.2f,用时 %d秒" , modelRatio , groupRatio , useTimeSeconds )
} else {
logContent = fmt . Sprintf ( "模型价格 %.2f,分组倍率 %.2f,用时 %d秒" , modelPrice , groupRatio , useTimeSeconds )
}
2023-12-21 20:20:09 +08:00
// Collapse per-gizmo model names in the log while keeping the real name
// in the log content.
logModel := textRequest . Model
if strings . HasPrefix ( logModel , "gpt-4-gizmo" ) {
logModel = "gpt-4-gizmo-*"
logContent += fmt . Sprintf ( ",模型 %s" , textRequest . Model )
}
model . RecordConsumeLog ( ctx , userId , channelId , promptTokens , completionTokens , logModel , tokenName , quota , logContent , tokenId , userQuota )
2023-11-23 02:56:18 +08:00
model . UpdateUserUsedQuotaAndRequestCount ( userId , quota )
model . UpdateChannelUsedQuota ( channelId , quota )
//if quota != 0 {
//
//}
2023-08-12 19:36:31 +08:00
} ( )
2023-09-17 15:39:46 +08:00
} ( c . Request . Context ( ) )
2023-07-22 16:18:03 +08:00
// Dispatch the upstream response to the dialect-specific handler. Stream
// handlers estimate completion tokens locally; blocking handlers report
// usage from the upstream response when available.
switch apiType {
case APITypeOpenAI :
if isStream {
2023-07-22 17:48:45 +08:00
err , responseText := openaiStreamHandler ( c , resp , relayMode )
2023-07-22 16:18:03 +08:00
if err != nil {
2023-07-22 17:48:45 +08:00
return err
2023-06-19 10:28:55 +08:00
}
2023-07-29 22:32:05 +08:00
textResponse . Usage . PromptTokens = promptTokens
textResponse . Usage . CompletionTokens = countTokenText ( responseText , textRequest . Model )
2023-07-22 16:18:03 +08:00
return nil
} else {
2023-11-23 02:56:18 +08:00
err , usage := openaiHandler ( c , resp , promptTokens , textRequest . Model )
2023-07-22 16:18:03 +08:00
if err != nil {
2023-07-22 17:48:45 +08:00
return err
2023-07-22 16:18:03 +08:00
}
2023-07-23 11:51:44 +08:00
if usage != nil {
textResponse . Usage = * usage
}
2023-07-22 16:18:03 +08:00
return nil
}
case APITypeClaude :
if isStream {
2023-07-22 17:36:40 +08:00
err , responseText := claudeStreamHandler ( c , resp )
2023-07-22 16:18:03 +08:00
if err != nil {
2023-07-22 17:36:40 +08:00
return err
2023-06-19 10:28:55 +08:00
}
2023-07-29 22:32:05 +08:00
textResponse . Usage . PromptTokens = promptTokens
textResponse . Usage . CompletionTokens = countTokenText ( responseText , textRequest . Model )
2023-07-22 16:18:03 +08:00
return nil
} else {
2023-07-22 17:36:40 +08:00
err , usage := claudeHandler ( c , resp , promptTokens , textRequest . Model )
2023-06-19 10:28:55 +08:00
if err != nil {
2023-07-22 17:36:40 +08:00
return err
2023-06-19 10:28:55 +08:00
}
2023-07-23 11:51:44 +08:00
if usage != nil {
textResponse . Usage = * usage
}
2023-07-22 16:18:03 +08:00
return nil
2023-06-19 10:28:55 +08:00
}
2023-07-22 23:24:09 +08:00
case APITypeBaidu :
if isStream {
err , usage := baiduStreamHandler ( c , resp )
if err != nil {
return err
}
2023-07-23 11:51:44 +08:00
if usage != nil {
textResponse . Usage = * usage
}
2023-07-22 23:24:09 +08:00
return nil
} else {
2023-07-29 12:15:07 +08:00
var err * OpenAIErrorWithStatusCode
var usage * Usage
switch relayMode {
case RelayModeEmbeddings :
err , usage = baiduEmbeddingHandler ( c , resp )
default :
err , usage = baiduHandler ( c , resp )
}
2023-07-22 23:24:09 +08:00
if err != nil {
return err
}
2023-07-23 11:51:44 +08:00
if usage != nil {
textResponse . Usage = * usage
}
2023-07-22 23:24:09 +08:00
return nil
}
2023-07-23 00:32:47 +08:00
case APITypePaLM :
if textRequest . Stream { // PaLM2 API does not support stream
err , responseText := palmStreamHandler ( c , resp )
if err != nil {
return err
}
2023-07-29 22:32:05 +08:00
textResponse . Usage . PromptTokens = promptTokens
textResponse . Usage . CompletionTokens = countTokenText ( responseText , textRequest . Model )
2023-07-23 00:32:47 +08:00
return nil
} else {
err , usage := palmHandler ( c , resp , promptTokens , textRequest . Model )
if err != nil {
return err
}
2023-07-23 11:51:44 +08:00
if usage != nil {
textResponse . Usage = * usage
}
return nil
}
2023-12-18 23:45:08 +08:00
case APITypeGemini :
if textRequest . Stream {
err , responseText := geminiChatStreamHandler ( c , resp )
if err != nil {
return err
}
textResponse . Usage . PromptTokens = promptTokens
textResponse . Usage . CompletionTokens = countTokenText ( responseText , textRequest . Model )
return nil
} else {
err , usage := geminiChatHandler ( c , resp , promptTokens , textRequest . Model )
if err != nil {
return err
}
if usage != nil {
textResponse . Usage = * usage
}
return nil
}
2023-07-23 11:51:44 +08:00
case APITypeZhipu :
if isStream {
err , usage := zhipuStreamHandler ( c , resp )
if err != nil {
return err
}
if usage != nil {
textResponse . Usage = * usage
}
2023-07-29 22:32:05 +08:00
// zhipu's API does not return prompt tokens & completion tokens
textResponse . Usage . PromptTokens = textResponse . Usage . TotalTokens
2023-07-23 11:51:44 +08:00
return nil
} else {
err , usage := zhipuHandler ( c , resp )
if err != nil {
return err
}
if usage != nil {
textResponse . Usage = * usage
}
2023-07-29 22:32:05 +08:00
// zhipu's API does not return prompt tokens & completion tokens
textResponse . Usage . PromptTokens = textResponse . Usage . TotalTokens
2023-07-23 00:32:47 +08:00
return nil
}
2023-07-28 23:45:08 +08:00
case APITypeAli :
if isStream {
err , usage := aliStreamHandler ( c , resp )
if err != nil {
return err
}
if usage != nil {
textResponse . Usage = * usage
}
return nil
} else {
2023-09-03 22:12:35 +08:00
var err * OpenAIErrorWithStatusCode
var usage * Usage
switch relayMode {
case RelayModeEmbeddings :
err , usage = aliEmbeddingHandler ( c , resp )
default :
err , usage = aliHandler ( c , resp )
}
2023-07-28 23:45:08 +08:00
if err != nil {
return err
}
if usage != nil {
textResponse . Usage = * usage
}
return nil
}
2023-07-29 21:55:57 +08:00
case APITypeXunfei :
2023-09-17 18:16:12 +08:00
// Xunfei credentials are packed as "appId|apiSecret|apiKey" in the
// Authorization header; the websocket handlers take them separately.
auth := c . Request . Header . Get ( "Authorization" )
auth = strings . TrimPrefix ( auth , "Bearer " )
splits := strings . Split ( auth , "|" )
if len ( splits ) != 3 {
return errorWrapper ( errors . New ( "invalid auth" ) , "invalid_auth" , http . StatusBadRequest )
}
var err * OpenAIErrorWithStatusCode
var usage * Usage
2023-07-29 21:55:57 +08:00
if isStream {
2023-09-17 18:16:12 +08:00
err , usage = xunfeiStreamHandler ( c , textRequest , splits [ 0 ] , splits [ 1 ] , splits [ 2 ] )
2023-07-29 21:55:57 +08:00
} else {
2023-09-17 18:16:12 +08:00
err , usage = xunfeiHandler ( c , textRequest , splits [ 0 ] , splits [ 1 ] , splits [ 2 ] )
}
if err != nil {
return err
}
if usage != nil {
textResponse . Usage = * usage
2023-07-29 21:55:57 +08:00
}
2023-09-17 18:16:12 +08:00
return nil
2023-09-03 12:51:59 +08:00
case APITypeAIProxyLibrary :
if isStream {
err , usage := aiProxyLibraryStreamHandler ( c , resp )
if err != nil {
return err
}
if usage != nil {
textResponse . Usage = * usage
}
return nil
} else {
err , usage := aiProxyLibraryHandler ( c , resp )
if err != nil {
return err
}
if usage != nil {
textResponse . Usage = * usage
}
return nil
}
2023-10-03 14:19:03 +08:00
case APITypeTencent :
if isStream {
err , responseText := tencentStreamHandler ( c , resp )
if err != nil {
return err
}
textResponse . Usage . PromptTokens = promptTokens
textResponse . Usage . CompletionTokens = countTokenText ( responseText , textRequest . Model )
return nil
} else {
err , usage := tencentHandler ( c , resp )
if err != nil {
return err
}
if usage != nil {
textResponse . Usage = * usage
}
return nil
}
2023-07-22 16:18:03 +08:00
default :
return errorWrapper ( errors . New ( "unknown api type" ) , "unknown_api_type" , http . StatusInternalServerError )
2023-06-19 10:28:55 +08:00
}
}