2024-02-29 01:08:18 +08:00
package relay
2023-06-19 10:28:55 +08:00
import (
2023-07-15 12:30:06 +08:00
"bytes"
2023-09-17 15:39:46 +08:00
"context"
2023-07-15 12:30:06 +08:00
"encoding/json"
"errors"
"fmt"
2023-10-22 17:50:52 +08:00
"github.com/gin-gonic/gin"
2023-06-19 10:28:55 +08:00
"io"
"net/http"
2023-07-15 12:30:06 +08:00
"one-api/common"
2024-05-23 23:59:55 +08:00
"one-api/constant"
2024-02-29 01:08:18 +08:00
"one-api/dto"
2023-07-15 12:30:06 +08:00
"one-api/model"
2024-02-29 16:21:25 +08:00
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
2023-11-09 17:08:32 +08:00
"strings"
2024-01-21 15:01:59 +08:00
"time"
2023-06-19 10:28:55 +08:00
)
2024-02-29 16:21:25 +08:00
func RelayImageHelper ( c * gin . Context , relayMode int ) * dto . OpenAIErrorWithStatusCode {
2023-07-15 12:30:06 +08:00
tokenId := c . GetInt ( "token_id" )
channelType := c . GetInt ( "channel" )
2023-09-17 19:18:16 +08:00
channelId := c . GetInt ( "channel_id" )
2023-07-15 12:30:06 +08:00
userId := c . GetInt ( "id" )
group := c . GetString ( "group" )
2024-01-21 15:01:59 +08:00
startTime := time . Now ( )
2023-07-15 12:30:06 +08:00
2024-02-29 01:08:18 +08:00
var imageRequest dto . ImageRequest
2024-03-13 22:30:10 +08:00
err := common . UnmarshalBodyReusable ( c , & imageRequest )
if err != nil {
return service . OpenAIErrorWrapper ( err , "bind_request_body_failed" , http . StatusBadRequest )
2023-07-15 12:30:06 +08:00
}
2023-11-09 17:08:32 +08:00
if imageRequest . Model == "" {
2024-04-24 15:08:15 +08:00
imageRequest . Model = "dall-e-3"
2023-11-09 17:08:32 +08:00
}
2023-11-10 17:08:33 +08:00
if imageRequest . Size == "" {
imageRequest . Size = "1024x1024"
}
2023-11-27 14:24:28 +08:00
if imageRequest . N == 0 {
imageRequest . N = 1
}
2023-07-15 12:30:06 +08:00
// Prompt validation
if imageRequest . Prompt == "" {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( errors . New ( "prompt is required" ) , "required_field_missing" , http . StatusBadRequest )
2023-07-15 12:30:06 +08:00
}
2024-05-23 23:59:55 +08:00
if constant . ShouldCheckPromptSensitive ( ) {
err = service . CheckSensitiveInput ( imageRequest . Prompt )
if err != nil {
return service . OpenAIErrorWrapper ( err , "sensitive_words_detected" , http . StatusBadRequest )
}
}
2023-11-09 17:08:32 +08:00
if strings . Contains ( imageRequest . Size , "× " ) {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( errors . New ( "size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '× '" ) , "invalid_field_value" , http . StatusBadRequest )
2023-11-09 17:08:32 +08:00
}
2023-07-15 12:30:06 +08:00
// Not "256x256", "512x512", or "1024x1024"
2023-11-10 17:08:33 +08:00
if imageRequest . Model == "dall-e-2" || imageRequest . Model == "dall-e" {
if imageRequest . Size != "" && imageRequest . Size != "256x256" && imageRequest . Size != "512x512" && imageRequest . Size != "1024x1024" {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( errors . New ( "size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024" ) , "invalid_field_value" , http . StatusBadRequest )
2023-11-10 17:08:33 +08:00
}
} else if imageRequest . Model == "dall-e-3" {
if imageRequest . Size != "" && imageRequest . Size != "1024x1024" && imageRequest . Size != "1024x1792" && imageRequest . Size != "1792x1024" {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( errors . New ( "size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024" ) , "invalid_field_value" , http . StatusBadRequest )
2023-11-10 17:08:33 +08:00
}
2023-12-01 01:29:13 +08:00
if imageRequest . N != 1 {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( errors . New ( "n must be 1" ) , "invalid_field_value" , http . StatusBadRequest )
2023-12-01 01:29:13 +08:00
}
2023-07-15 12:30:06 +08:00
}
// N should between 1 and 10
if imageRequest . N != 0 && ( imageRequest . N < 1 || imageRequest . N > 10 ) {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( errors . New ( "n must be between 1 and 10" ) , "invalid_field_value" , http . StatusBadRequest )
2023-07-15 12:30:06 +08:00
}
// map model name
modelMapping := c . GetString ( "model_mapping" )
isModelMapped := false
if modelMapping != "" {
modelMap := make ( map [ string ] string )
err := json . Unmarshal ( [ ] byte ( modelMapping ) , & modelMap )
if err != nil {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( err , "unmarshal_model_mapping_failed" , http . StatusInternalServerError )
2023-07-15 12:30:06 +08:00
}
2023-11-09 17:08:32 +08:00
if modelMap [ imageRequest . Model ] != "" {
imageRequest . Model = modelMap [ imageRequest . Model ]
2023-07-15 12:30:06 +08:00
isModelMapped = true
}
}
baseURL := common . ChannelBaseURLs [ channelType ]
requestURL := c . Request . URL . String ( )
if c . GetString ( "base_url" ) != "" {
baseURL = c . GetString ( "base_url" )
}
2024-02-29 16:21:25 +08:00
fullRequestURL := relaycommon . GetFullRequestURL ( baseURL , requestURL , channelType )
if channelType == common . ChannelTypeAzure && relayMode == relayconstant . RelayModeImagesGenerations {
2024-01-09 15:46:45 +08:00
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
2024-02-29 16:21:25 +08:00
apiVersion := relaycommon . GetAPIVersion ( c )
2024-01-09 15:46:45 +08:00
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt . Sprintf ( "%s/openai/deployments/%s/images/generations?api-version=%s" , baseURL , imageRequest . Model , apiVersion )
}
2023-07-15 12:30:06 +08:00
var requestBody io . Reader
2024-01-09 15:46:45 +08:00
if isModelMapped || channelType == common . ChannelTypeAzure { // make Azure channel request body
2023-07-15 12:30:06 +08:00
jsonStr , err := json . Marshal ( imageRequest )
if err != nil {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
2023-07-15 12:30:06 +08:00
}
requestBody = bytes . NewBuffer ( jsonStr )
} else {
requestBody = c . Request . Body
}
2024-05-13 16:04:02 +08:00
modelPrice , success := common . GetModelPrice ( imageRequest . Model , true )
if ! success {
modelRatio := common . GetModelRatio ( imageRequest . Model )
// modelRatio 16 = modelPrice $0.04
// per 1 modelRatio = $0.04 / 16
modelPrice = 0.0025 * modelRatio
}
2023-07-15 12:30:06 +08:00
groupRatio := common . GetGroupRatio ( group )
userQuota , err := model . CacheGetUserQuota ( userId )
sizeRatio := 1.0
// Size
if imageRequest . Size == "256x256" {
2024-05-13 16:04:02 +08:00
sizeRatio = 0.4
2023-07-15 12:30:06 +08:00
} else if imageRequest . Size == "512x512" {
2024-05-13 16:04:02 +08:00
sizeRatio = 0.45
2023-07-15 12:30:06 +08:00
} else if imageRequest . Size == "1024x1024" {
2024-05-13 16:04:02 +08:00
sizeRatio = 1
2023-11-09 17:08:32 +08:00
} else if imageRequest . Size == "1024x1792" || imageRequest . Size == "1792x1024" {
2024-05-13 16:04:02 +08:00
sizeRatio = 2
2023-11-09 17:08:32 +08:00
}
qualityRatio := 1.0
if imageRequest . Model == "dall-e-3" && imageRequest . Quality == "hd" {
qualityRatio = 2.0
if imageRequest . Size == "1024× 1792" || imageRequest . Size == "1792× 1024" {
qualityRatio = 1.5
}
2023-07-15 12:30:06 +08:00
}
2023-11-09 17:08:32 +08:00
2024-05-13 16:04:02 +08:00
quota := int ( modelPrice * groupRatio * common . QuotaPerUnit * sizeRatio * qualityRatio ) * imageRequest . N
2023-07-15 12:30:06 +08:00
2024-03-13 22:30:10 +08:00
if userQuota - quota < 0 {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( errors . New ( "user quota is not enough" ) , "insufficient_user_quota" , http . StatusForbidden )
2023-07-15 12:30:06 +08:00
}
req , err := http . NewRequest ( c . Request . Method , fullRequestURL , requestBody )
if err != nil {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( err , "new_request_failed" , http . StatusInternalServerError )
2023-07-15 12:30:06 +08:00
}
2024-01-09 15:46:45 +08:00
token := c . Request . Header . Get ( "Authorization" )
if channelType == common . ChannelTypeAzure { // Azure authentication
token = strings . TrimPrefix ( token , "Bearer " )
req . Header . Set ( "api-key" , token )
} else {
req . Header . Set ( "Authorization" , token )
}
2023-07-15 12:30:06 +08:00
req . Header . Set ( "Content-Type" , c . Request . Header . Get ( "Content-Type" ) )
req . Header . Set ( "Accept" , c . Request . Header . Get ( "Accept" ) )
2024-02-29 16:21:25 +08:00
resp , err := service . GetHttpClient ( ) . Do ( req )
2023-06-19 10:28:55 +08:00
if err != nil {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( err , "do_request_failed" , http . StatusInternalServerError )
2023-06-19 10:28:55 +08:00
}
2023-07-15 12:30:06 +08:00
2023-06-19 10:28:55 +08:00
err = req . Body . Close ( )
if err != nil {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( err , "close_request_body_failed" , http . StatusInternalServerError )
2023-07-15 12:30:06 +08:00
}
err = c . Request . Body . Close ( )
if err != nil {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( err , "close_request_body_failed" , http . StatusInternalServerError )
2023-07-15 12:30:06 +08:00
}
2023-11-27 19:21:16 +08:00
if resp . StatusCode != http . StatusOK {
2024-02-29 16:21:25 +08:00
return relaycommon . RelayErrorHandler ( resp )
2023-11-27 19:21:16 +08:00
}
2024-02-29 16:21:25 +08:00
var textResponse dto . ImageResponse
2023-09-17 15:39:46 +08:00
defer func ( ctx context . Context ) {
2024-01-21 15:01:59 +08:00
useTimeSeconds := time . Now ( ) . Unix ( ) - startTime . Unix ( )
2024-03-13 22:30:10 +08:00
if resp . StatusCode != http . StatusOK {
return
2023-07-15 12:30:06 +08:00
}
2024-03-13 22:30:10 +08:00
err := model . PostConsumeTokenQuota ( tokenId , userQuota , quota , 0 , true )
2023-07-15 12:30:06 +08:00
if err != nil {
2024-03-13 22:30:10 +08:00
common . SysError ( "error consuming token remain quota: " + err . Error ( ) )
2023-07-15 12:30:06 +08:00
}
2024-03-13 22:30:10 +08:00
err = model . CacheUpdateUserQuota ( userId )
2023-07-15 12:30:06 +08:00
if err != nil {
2024-03-13 22:30:10 +08:00
common . SysError ( "error update user quota cache: " + err . Error ( ) )
2023-07-15 12:30:06 +08:00
}
2024-03-13 22:30:10 +08:00
if quota != 0 {
tokenName := c . GetString ( "token_name" )
2024-04-24 15:14:16 +08:00
quality := "normal"
if imageRequest . Quality == "hd" {
quality = "hd"
}
2024-05-13 16:04:02 +08:00
logContent := fmt . Sprintf ( "模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s" , modelPrice , groupRatio , imageRequest . Size , quality )
2024-05-12 15:35:57 +08:00
other := make ( map [ string ] interface { } )
2024-05-13 16:04:02 +08:00
other [ "model_price" ] = modelPrice
2024-05-12 15:35:57 +08:00
other [ "group_ratio" ] = groupRatio
model . RecordConsumeLog ( ctx , userId , channelId , 0 , 0 , imageRequest . Model , tokenName , quota , logContent , tokenId , userQuota , int ( useTimeSeconds ) , false , other )
2024-03-13 22:30:10 +08:00
model . UpdateUserUsedQuotaAndRequestCount ( userId , quota )
channelId := c . GetInt ( "channel_id" )
model . UpdateChannelUsedQuota ( channelId , quota )
2023-07-15 12:30:06 +08:00
}
2024-03-13 22:30:10 +08:00
} ( c . Request . Context ( ) )
responseBody , err := io . ReadAll ( resp . Body )
2023-07-15 12:30:06 +08:00
2024-03-13 22:30:10 +08:00
if err != nil {
return service . OpenAIErrorWrapper ( err , "read_response_body_failed" , http . StatusInternalServerError )
}
err = resp . Body . Close ( )
if err != nil {
return service . OpenAIErrorWrapper ( err , "close_response_body_failed" , http . StatusInternalServerError )
}
err = json . Unmarshal ( responseBody , & textResponse )
if err != nil {
return service . OpenAIErrorWrapper ( err , "unmarshal_response_body_failed" , http . StatusInternalServerError )
2023-06-19 10:28:55 +08:00
}
2023-07-15 12:30:06 +08:00
2024-03-13 22:30:10 +08:00
resp . Body = io . NopCloser ( bytes . NewBuffer ( responseBody ) )
2023-06-19 10:28:55 +08:00
for k , v := range resp . Header {
c . Writer . Header ( ) . Set ( k , v [ 0 ] )
}
c . Writer . WriteHeader ( resp . StatusCode )
2023-07-15 12:30:06 +08:00
2023-06-19 10:28:55 +08:00
_ , err = io . Copy ( c . Writer , resp . Body )
if err != nil {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( err , "copy_response_body_failed" , http . StatusInternalServerError )
2023-06-19 10:28:55 +08:00
}
err = resp . Body . Close ( )
if err != nil {
2024-02-29 16:21:25 +08:00
return service . OpenAIErrorWrapper ( err , "close_response_body_failed" , http . StatusInternalServerError )
2023-06-19 10:28:55 +08:00
}
return nil
}