package relay
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func getAndValidateTextRequest ( c * gin . Context , relayInfo * relaycommon . RelayInfo ) ( * dto . GeneralOpenAIRequest , error ) {
textRequest := & dto . GeneralOpenAIRequest { }
err := common . UnmarshalBodyReusable ( c , textRequest )
if err != nil {
return nil , err
}
if relayInfo . RelayMode == relayconstant . RelayModeModerations && textRequest . Model == "" {
textRequest . Model = "text-moderation-latest"
}
if relayInfo . RelayMode == relayconstant . RelayModeEmbeddings && textRequest . Model == "" {
textRequest . Model = c . Param ( "model" )
}
if textRequest . MaxTokens < 0 || textRequest . MaxTokens > math . MaxInt32 / 2 {
return nil , errors . New ( "max_tokens is invalid" )
}
if textRequest . Model == "" {
return nil , errors . New ( "model is required" )
}
switch relayInfo . RelayMode {
case relayconstant . RelayModeCompletions :
if textRequest . Prompt == "" {
return nil , errors . New ( "field prompt is required" )
}
case relayconstant . RelayModeChatCompletions :
if textRequest . Messages == nil || len ( textRequest . Messages ) == 0 {
return nil , errors . New ( "field messages is required" )
}
case relayconstant . RelayModeEmbeddings :
case relayconstant . RelayModeModerations :
2024-09-02 01:11:19 +08:00
if textRequest . Input == "" || textRequest . Input == nil {
2024-02-29 01:08:18 +08:00
return nil , errors . New ( "field input is required" )
}
case relayconstant . RelayModeEdits :
if textRequest . Instruction == "" {
return nil , errors . New ( "field instruction is required" )
}
}
relayInfo . IsStream = textRequest . Stream
return textRequest , nil
}
func TextHelper ( c * gin . Context ) * dto . OpenAIErrorWithStatusCode {
relayInfo := relaycommon . GenRelayInfo ( c )
// get & validate textRequest 获取并验证文本请求
textRequest , err := getAndValidateTextRequest ( c , relayInfo )
if err != nil {
common . LogError ( c , fmt . Sprintf ( "getAndValidateTextRequest failed: %s" , err . Error ( ) ) )
2024-04-04 16:35:44 +08:00
return service . OpenAIErrorWrapperLocal ( err , "invalid_text_request" , http . StatusBadRequest )
2024-02-29 01:08:18 +08:00
}
// map model name
2024-10-15 18:37:44 +08:00
isModelMapped := false
2024-02-29 01:08:18 +08:00
modelMapping := c . GetString ( "model_mapping" )
2024-07-08 01:27:57 +08:00
//isModelMapped := false
2024-02-29 01:08:18 +08:00
if modelMapping != "" && modelMapping != "{}" {
modelMap := make ( map [ string ] string )
err := json . Unmarshal ( [ ] byte ( modelMapping ) , & modelMap )
if err != nil {
2024-04-04 16:35:44 +08:00
return service . OpenAIErrorWrapperLocal ( err , "unmarshal_model_mapping_failed" , http . StatusInternalServerError )
2024-02-29 01:08:18 +08:00
}
if modelMap [ textRequest . Model ] != "" {
2024-10-15 18:37:44 +08:00
isModelMapped = true
2024-02-29 01:08:18 +08:00
textRequest . Model = modelMap [ textRequest . Model ]
2024-03-11 18:28:51 +08:00
// set upstream model name
2024-07-08 01:27:57 +08:00
//isModelMapped = true
2024-02-29 01:08:18 +08:00
}
}
2024-03-11 18:28:51 +08:00
relayInfo . UpstreamModelName = textRequest . Model
2024-07-16 22:07:10 +08:00
modelPrice , getModelPriceSuccess := common . GetModelPrice ( textRequest . Model , false )
2024-02-29 01:08:18 +08:00
groupRatio := common . GetGroupRatio ( relayInfo . Group )
var preConsumedQuota int
var ratio float64
var modelRatio float64
2024-03-20 17:07:42 +08:00
//err := service.SensitiveWordsCheck(textRequest)
2024-02-29 01:08:18 +08:00
2024-05-23 23:59:55 +08:00
if constant . ShouldCheckPromptSensitive ( ) {
err = checkRequestSensitive ( textRequest , relayInfo )
if err != nil {
2024-04-04 16:35:44 +08:00
return service . OpenAIErrorWrapperLocal ( err , "sensitive_words_detected" , http . StatusBadRequest )
2024-03-20 20:36:55 +08:00
}
2024-05-23 23:59:55 +08:00
}
promptTokens , err := getPromptTokens ( textRequest , relayInfo )
// count messages token error 计算promptTokens错误
if err != nil {
2024-02-29 01:08:18 +08:00
return service . OpenAIErrorWrapper ( err , "count_token_messages_failed" , http . StatusInternalServerError )
}
2024-07-16 22:07:10 +08:00
if ! getModelPriceSuccess {
2024-02-29 01:08:18 +08:00
preConsumedTokens := common . PreConsumedQuota
if textRequest . MaxTokens != 0 {
preConsumedTokens = promptTokens + int ( textRequest . MaxTokens )
}
modelRatio = common . GetModelRatio ( textRequest . Model )
ratio = modelRatio * groupRatio
preConsumedQuota = int ( float64 ( preConsumedTokens ) * ratio )
} else {
preConsumedQuota = int ( modelPrice * common . QuotaPerUnit * groupRatio )
}
// pre-consume quota 预消耗配额
2024-02-29 16:39:52 +08:00
preConsumedQuota , userQuota , openaiErr := preConsumeQuota ( c , preConsumedQuota , relayInfo )
2024-03-03 22:05:00 +08:00
if openaiErr != nil {
2024-02-29 01:08:18 +08:00
return openaiErr
}
2024-07-19 14:06:10 +08:00
includeUsage := false
// 判断用户是否需要返回使用情况
if textRequest . StreamOptions != nil && textRequest . StreamOptions . IncludeUsage {
includeUsage = true
}
2024-07-08 02:00:39 +08:00
// 如果不支持StreamOptions, 将StreamOptions设置为nil
if ! relayInfo . SupportStreamOptions || ! textRequest . Stream {
textRequest . StreamOptions = nil
} else {
// 如果支持StreamOptions, 且请求中没有设置StreamOptions, 根据配置文件设置StreamOptions
if constant . ForceStreamOption {
textRequest . StreamOptions = & dto . StreamOptions {
IncludeUsage : true ,
}
}
}
2024-07-19 14:06:10 +08:00
if includeUsage {
relayInfo . ShouldIncludeUsage = true
2024-07-08 02:00:39 +08:00
}
2024-02-29 16:21:25 +08:00
adaptor := GetAdaptor ( relayInfo . ApiType )
2024-02-29 01:08:18 +08:00
if adaptor == nil {
2024-05-23 23:59:55 +08:00
return service . OpenAIErrorWrapperLocal ( fmt . Errorf ( "invalid api type: %d" , relayInfo . ApiType ) , "invalid_api_type" , http . StatusBadRequest )
2024-02-29 01:08:18 +08:00
}
2024-07-16 22:07:10 +08:00
adaptor . Init ( relayInfo )
2024-02-29 01:08:18 +08:00
var requestBody io . Reader
2024-07-08 01:27:57 +08:00
2024-10-15 18:37:44 +08:00
if relayInfo . ChannelType == common . ChannelTypeOpenAI && ! isModelMapped {
body , err := common . GetRequestBody ( c )
if err != nil {
return service . OpenAIErrorWrapperLocal ( err , "get_request_body_failed" , http . StatusInternalServerError )
}
requestBody = bytes . NewBuffer ( body )
} else {
convertedRequest , err := adaptor . ConvertRequest ( c , relayInfo , textRequest )
if err != nil {
return service . OpenAIErrorWrapperLocal ( err , "convert_request_failed" , http . StatusInternalServerError )
}
jsonData , err := json . Marshal ( convertedRequest )
if err != nil {
return service . OpenAIErrorWrapperLocal ( err , "json_marshal_failed" , http . StatusInternalServerError )
}
requestBody = bytes . NewBuffer ( jsonData )
2024-02-29 01:08:18 +08:00
}
2024-04-20 21:05:23 +08:00
statusCodeMappingStr := c . GetString ( "status_code_mapping" )
2024-10-04 16:08:18 +08:00
var httpResp * http . Response
2024-02-29 01:08:18 +08:00
resp , err := adaptor . DoRequest ( c , relayInfo , requestBody )
2024-03-06 17:41:55 +08:00
if err != nil {
return service . OpenAIErrorWrapper ( err , "do_request_failed" , http . StatusInternalServerError )
}
2024-02-29 01:08:18 +08:00
2024-04-23 11:44:40 +08:00
if resp != nil {
2024-10-04 16:08:18 +08:00
httpResp = resp . ( * http . Response )
relayInfo . IsStream = relayInfo . IsStream || strings . HasPrefix ( httpResp . Header . Get ( "Content-Type" ) , "text/event-stream" )
if httpResp . StatusCode != http . StatusOK {
2024-09-26 00:59:09 +08:00
returnPreConsumedQuota ( c , relayInfo , userQuota , preConsumedQuota )
2024-10-04 16:08:18 +08:00
openaiErr := service . RelayErrorHandler ( httpResp )
2024-04-23 11:44:40 +08:00
// reset status code 重置状态码
service . ResetStatusCode ( openaiErr , statusCodeMappingStr )
return openaiErr
}
2024-03-06 17:41:55 +08:00
}
2024-10-04 16:08:18 +08:00
usage , openaiErr := adaptor . DoResponse ( c , httpResp , relayInfo )
2024-02-29 01:08:18 +08:00
if openaiErr != nil {
2024-09-26 00:59:09 +08:00
returnPreConsumedQuota ( c , relayInfo , userQuota , preConsumedQuota )
2024-04-20 21:05:23 +08:00
// reset status code 重置状态码
service . ResetStatusCode ( openaiErr , statusCodeMappingStr )
2024-03-29 22:20:14 +08:00
return openaiErr
2024-02-29 01:08:18 +08:00
}
2024-10-04 16:08:18 +08:00
postConsumeQuota ( c , relayInfo , textRequest . Model , usage . ( * dto . Usage ) , ratio , preConsumedQuota , userQuota , modelRatio , groupRatio , modelPrice , getModelPriceSuccess , "" )
2024-02-29 01:08:18 +08:00
return nil
}
2024-05-23 23:59:55 +08:00
func getPromptTokens ( textRequest * dto . GeneralOpenAIRequest , info * relaycommon . RelayInfo ) ( int , error ) {
2024-02-29 01:08:18 +08:00
var promptTokens int
var err error
switch info . RelayMode {
case relayconstant . RelayModeChatCompletions :
2024-05-23 23:59:55 +08:00
promptTokens , err = service . CountTokenChatRequest ( * textRequest , textRequest . Model )
2024-02-29 01:08:18 +08:00
case relayconstant . RelayModeCompletions :
2024-05-23 23:59:55 +08:00
promptTokens , err = service . CountTokenInput ( textRequest . Prompt , textRequest . Model )
2024-02-29 01:08:18 +08:00
case relayconstant . RelayModeModerations :
2024-05-23 23:59:55 +08:00
promptTokens , err = service . CountTokenInput ( textRequest . Input , textRequest . Model )
2024-03-05 23:04:57 +08:00
case relayconstant . RelayModeEmbeddings :
2024-05-23 23:59:55 +08:00
promptTokens , err = service . CountTokenInput ( textRequest . Input , textRequest . Model )
2024-02-29 01:08:18 +08:00
default :
err = errors . New ( "unknown relay mode" )
promptTokens = 0
}
info . PromptTokens = promptTokens
2024-05-23 23:59:55 +08:00
return promptTokens , err
}
func checkRequestSensitive ( textRequest * dto . GeneralOpenAIRequest , info * relaycommon . RelayInfo ) error {
var err error
switch info . RelayMode {
case relayconstant . RelayModeChatCompletions :
err = service . CheckSensitiveMessages ( textRequest . Messages )
case relayconstant . RelayModeCompletions :
err = service . CheckSensitiveInput ( textRequest . Prompt )
case relayconstant . RelayModeModerations :
err = service . CheckSensitiveInput ( textRequest . Input )
case relayconstant . RelayModeEmbeddings :
err = service . CheckSensitiveInput ( textRequest . Input )
}
return err
2024-02-29 01:08:18 +08:00
}
// 预扣费并返回用户剩余配额
2024-02-29 16:39:52 +08:00
func preConsumeQuota ( c * gin . Context , preConsumedQuota int , relayInfo * relaycommon . RelayInfo ) ( int , int , * dto . OpenAIErrorWithStatusCode ) {
2024-02-29 01:08:18 +08:00
userQuota , err := model . CacheGetUserQuota ( relayInfo . UserId )
if err != nil {
2024-04-04 16:35:44 +08:00
return 0 , 0 , service . OpenAIErrorWrapperLocal ( err , "get_user_quota_failed" , http . StatusInternalServerError )
2024-02-29 01:08:18 +08:00
}
2024-08-09 18:34:51 +08:00
if userQuota <= 0 {
2024-04-04 16:35:44 +08:00
return 0 , 0 , service . OpenAIErrorWrapperLocal ( errors . New ( "user quota is not enough" ) , "insufficient_user_quota" , http . StatusForbidden )
2024-02-29 01:08:18 +08:00
}
2024-08-09 18:34:51 +08:00
if userQuota - preConsumedQuota < 0 {
2024-08-09 18:48:13 +08:00
return 0 , 0 , service . OpenAIErrorWrapperLocal ( errors . New ( fmt . Sprintf ( "chat pre-consumed quota failed, user quota: %d, need quota: %d" , userQuota , preConsumedQuota ) ) , "insufficient_user_quota" , http . StatusBadRequest )
2024-08-09 18:34:51 +08:00
}
2024-02-29 01:08:18 +08:00
err = model . CacheDecreaseUserQuota ( relayInfo . UserId , preConsumedQuota )
if err != nil {
2024-04-04 16:35:44 +08:00
return 0 , 0 , service . OpenAIErrorWrapperLocal ( err , "decrease_user_quota_failed" , http . StatusInternalServerError )
2024-02-29 01:08:18 +08:00
}
if userQuota > 100 * preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
if ! relayInfo . TokenUnlimited {
// 非无限令牌,判断令牌额度是否充足
tokenQuota := c . GetInt ( "token_quota" )
if tokenQuota > 100 * preConsumedQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
2024-08-03 17:46:13 +08:00
common . LogInfo ( c , fmt . Sprintf ( "user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume" , relayInfo . UserId , userQuota , relayInfo . TokenId , tokenQuota ) )
2024-02-29 01:08:18 +08:00
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
2024-08-03 17:46:13 +08:00
common . LogInfo ( c , fmt . Sprintf ( "user %d with unlimited token has enough quota %d, trusted and no need to pre-consume" , relayInfo . UserId , userQuota ) )
2024-02-29 01:08:18 +08:00
}
}
if preConsumedQuota > 0 {
2024-09-26 00:59:09 +08:00
userQuota , err = model . PreConsumeTokenQuota ( relayInfo , preConsumedQuota )
2024-02-29 01:08:18 +08:00
if err != nil {
2024-04-04 16:35:44 +08:00
return 0 , 0 , service . OpenAIErrorWrapperLocal ( err , "pre_consume_token_quota_failed" , http . StatusForbidden )
2024-02-29 01:08:18 +08:00
}
}
2024-02-29 16:39:52 +08:00
return preConsumedQuota , userQuota , nil
2024-02-29 01:08:18 +08:00
}
2024-09-26 00:59:09 +08:00
func returnPreConsumedQuota ( c * gin . Context , relayInfo * relaycommon . RelayInfo , userQuota int , preConsumedQuota int ) {
2024-03-06 17:41:55 +08:00
if preConsumedQuota != 0 {
go func ( ctx context . Context ) {
// return pre-consumed quota
2024-09-26 00:59:09 +08:00
err := model . PostConsumeTokenQuota ( relayInfo , userQuota , - preConsumedQuota , 0 , false )
2024-03-06 17:41:55 +08:00
if err != nil {
common . SysError ( "error return pre-consumed quota: " + err . Error ( ) )
}
} ( c )
}
}
2024-07-06 17:09:22 +08:00
func postConsumeQuota ( ctx * gin . Context , relayInfo * relaycommon . RelayInfo , modelName string ,
2024-03-20 19:00:51 +08:00
usage * dto . Usage , ratio float64 , preConsumedQuota int , userQuota int , modelRatio float64 , groupRatio float64 ,
2024-07-17 23:50:37 +08:00
modelPrice float64 , usePrice bool , extraContent string ) {
2024-08-01 16:13:08 +08:00
if usage == nil {
usage = & dto . Usage {
PromptTokens : relayInfo . PromptTokens ,
CompletionTokens : 0 ,
TotalTokens : relayInfo . PromptTokens ,
}
extraContent += " ,(可能是请求出错)"
}
2024-02-29 01:08:18 +08:00
useTimeSeconds := time . Now ( ) . Unix ( ) - relayInfo . StartTime . Unix ( )
promptTokens := usage . PromptTokens
completionTokens := usage . CompletionTokens
tokenName := ctx . GetString ( "token_name" )
2024-07-06 17:09:22 +08:00
completionRatio := common . GetCompletionRatio ( modelName )
2024-02-29 01:08:18 +08:00
quota := 0
2024-05-13 16:04:02 +08:00
if ! usePrice {
2024-05-13 15:08:01 +08:00
quota = promptTokens + int ( math . Round ( float64 ( completionTokens ) * completionRatio ) )
quota = int ( math . Round ( float64 ( quota ) * ratio ) )
2024-02-29 01:08:18 +08:00
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int ( modelPrice * common . QuotaPerUnit * groupRatio )
}
totalTokens := promptTokens + completionTokens
var logContent string
2024-07-16 23:40:52 +08:00
if ! usePrice {
2024-08-16 18:27:26 +08:00
logContent = fmt . Sprintf ( "模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f" , modelRatio , completionRatio , groupRatio )
2024-02-29 01:08:18 +08:00
} else {
logContent = fmt . Sprintf ( "模型价格 %.2f,分组倍率 %.2f" , modelPrice , groupRatio )
}
// record all the consume log even if quota is 0
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt . Sprintf ( "(可能是上游超时)" )
2024-07-06 17:09:22 +08:00
common . LogError ( ctx , fmt . Sprintf ( "total tokens is 0, cannot consume quota, userId %d, channelId %d, " +
"tokenId %d, model %s, pre-consumed quota %d" , relayInfo . UserId , relayInfo . ChannelId , relayInfo . TokenId , modelName , preConsumedQuota ) )
2024-02-29 01:08:18 +08:00
} else {
2024-03-29 22:20:14 +08:00
//if sensitiveResp != nil {
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
//}
2024-02-29 01:08:18 +08:00
quotaDelta := quota - preConsumedQuota
2024-04-04 20:10:30 +08:00
if quotaDelta != 0 {
2024-09-26 00:59:09 +08:00
err := model . PostConsumeTokenQuota ( relayInfo , userQuota , quotaDelta , preConsumedQuota , true )
2024-04-04 20:10:30 +08:00
if err != nil {
common . LogError ( ctx , "error consuming token remain quota: " + err . Error ( ) )
}
2024-02-29 01:08:18 +08:00
}
2024-04-04 20:10:30 +08:00
err := model . CacheUpdateUserQuota ( relayInfo . UserId )
2024-02-29 01:08:18 +08:00
if err != nil {
common . LogError ( ctx , "error update user quota cache: " + err . Error ( ) )
}
model . UpdateUserUsedQuotaAndRequestCount ( relayInfo . UserId , quota )
model . UpdateChannelUsedQuota ( relayInfo . ChannelId , quota )
}
2024-07-06 17:09:22 +08:00
logModel := modelName
2024-02-29 01:08:18 +08:00
if strings . HasPrefix ( logModel , "gpt-4-gizmo" ) {
logModel = "gpt-4-gizmo-*"
2024-07-06 17:09:22 +08:00
logContent += fmt . Sprintf ( ",模型 %s" , modelName )
2024-02-29 01:08:18 +08:00
}
2024-08-16 17:25:03 +08:00
if strings . HasPrefix ( logModel , "gpt-4o-gizmo" ) {
logModel = "gpt-4o-gizmo-*"
logContent += fmt . Sprintf ( ",模型 %s" , modelName )
}
2024-07-17 23:50:37 +08:00
if extraContent != "" {
2024-07-18 00:41:31 +08:00
logContent += ", " + extraContent
2024-07-17 23:50:37 +08:00
}
2024-06-26 18:04:49 +08:00
other := service . GenerateTextOtherInfo ( ctx , relayInfo , modelRatio , groupRatio , completionRatio , modelPrice )
2024-07-06 17:09:22 +08:00
model . RecordConsumeLog ( ctx , relayInfo . UserId , relayInfo . ChannelId , promptTokens , completionTokens , logModel ,
tokenName , quota , logContent , relayInfo . TokenId , userQuota , int ( useTimeSeconds ) , relayInfo . IsStream , other )
2024-02-29 01:08:18 +08:00
//if quota != 0 {
//
//}
}