Files
new-api/relay/channel/openai/adaptor.go

313 lines
10 KiB
Go
Raw Normal View History

2024-02-29 01:08:18 +08:00
package openai
import (
2024-07-16 22:07:10 +08:00
"bytes"
"encoding/json"
2024-02-29 01:08:18 +08:00
"errors"
"fmt"
"io"
2024-07-16 22:07:10 +08:00
"mime/multipart"
2024-02-29 01:08:18 +08:00
"net/http"
"one-api/common"
constant2 "one-api/constant"
2024-02-29 01:08:18 +08:00
"one-api/dto"
2024-02-29 16:21:25 +08:00
"one-api/relay/channel"
2024-02-29 18:31:03 +08:00
"one-api/relay/channel/ai360"
2024-03-20 23:44:03 +08:00
"one-api/relay/channel/lingyiwanwu"
2024-05-16 19:03:42 +08:00
"one-api/relay/channel/minimax"
2024-03-25 22:33:33 +08:00
"one-api/relay/channel/moonshot"
"one-api/relay/channel/openrouter"
"one-api/relay/channel/xinference"
2024-02-29 01:08:18 +08:00
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
2024-07-16 22:07:10 +08:00
"one-api/relay/constant"
"one-api/service"
2024-02-29 01:08:18 +08:00
"strings"
"github.com/gin-gonic/gin"
2024-02-29 01:08:18 +08:00
)
// Adaptor relays requests to OpenAI and OpenAI-compatible channels. One type
// serves several channel types; ChannelType selects which built-in model list
// and channel name GetModelList and GetChannelName report.
type Adaptor struct {
	// ChannelType is copied from the relay info in Init.
	ChannelType int

	// ResponseFormat holds the response_format of the current audio request,
	// captured in ConvertAudioRequest and handed to the speech-to-text handler
	// in DoResponse.
	ResponseFormat string
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if !strings.Contains(request.Model, "claude") {
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
}
aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
if err != nil {
return nil, err
}
if info.SupportStreamOptions {
aiRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
return a.ConvertOpenAIRequest(c, info, aiRequest)
2025-03-12 21:31:46 +08:00
}
2024-07-16 22:07:10 +08:00
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
2024-02-29 18:31:03 +08:00
a.ChannelType = info.ChannelType
// initialize ThinkingContentInfo when thinking_to_content is enabled
if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
HasSentThinkingContent: false,
}
}
2024-02-29 01:08:18 +08:00
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
if info.RelayMode == constant.RelayModeRealtime {
2024-12-04 16:20:24 +08:00
if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
baseUrl = "wss://" + baseUrl
info.BaseUrl = baseUrl
} else if strings.HasPrefix(info.BaseUrl, "http://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "http://")
baseUrl = "ws://" + baseUrl
info.BaseUrl = baseUrl
}
}
2024-05-16 19:03:42 +08:00
switch info.ChannelType {
case common.ChannelTypeAzure:
apiVersion := info.ApiVersion
if apiVersion == "" {
apiVersion = constant2.AzureDefaultAPIVersion
}
2024-02-29 01:08:18 +08:00
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(info.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
2024-02-29 01:08:18 +08:00
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
}
2024-02-29 01:08:18 +08:00
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
2024-05-16 19:03:42 +08:00
case common.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)
case common.ChannelTypeCustom:
url := info.BaseUrl
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
return url, nil
2024-05-16 19:03:42 +08:00
default:
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
2024-02-29 01:08:18 +08:00
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, header)
2024-02-29 01:08:18 +08:00
if info.ChannelType == common.ChannelTypeAzure {
header.Set("api-key", info.ApiKey)
2024-02-29 01:08:18 +08:00
return nil
}
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
header.Set("OpenAI-Organization", info.Organization)
}
if info.RelayMode == constant.RelayModeRealtime {
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
if swp != "" {
items := []string{
"realtime",
"openai-insecure-api-key." + info.ApiKey,
"openai-beta.realtime-v1",
}
header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
//req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
//req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
} else {
header.Set("openai-beta", "realtime=v1")
header.Set("Authorization", "Bearer "+info.ApiKey)
}
} else {
header.Set("Authorization", "Bearer "+info.ApiKey)
}
if info.ChannelType == common.ChannelTypeOpenRouter {
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
header.Set("X-Title", "New API")
}
2024-02-29 01:08:18 +08:00
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
2024-02-29 01:08:18 +08:00
if request == nil {
return nil, errors.New("request is nil")
}
2024-12-26 22:51:06 +08:00
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
2024-07-15 16:05:30 +08:00
request.StreamOptions = nil
}
if strings.HasPrefix(request.Model, "o1") || strings.HasPrefix(request.Model, "o3") {
2024-09-14 16:11:38 +08:00
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
2025-02-01 14:11:34 +08:00
request.Temperature = nil
}
if strings.HasSuffix(request.Model, "-high") {
request.ReasoningEffort = "high"
request.Model = strings.TrimSuffix(request.Model, "-high")
} else if strings.HasSuffix(request.Model, "-low") {
request.ReasoningEffort = "low"
request.Model = strings.TrimSuffix(request.Model, "-low")
} else if strings.HasSuffix(request.Model, "-medium") {
request.ReasoningEffort = "medium"
request.Model = strings.TrimSuffix(request.Model, "-medium")
}
info.ReasoningEffort = request.ReasoningEffort
info.UpstreamModelName = request.Model
2024-09-14 16:11:38 +08:00
}
if request.Model == "o1" || request.Model == "o1-2024-12-17" || strings.HasPrefix(request.Model, "o3") {
2024-12-20 22:07:53 +08:00
//修改第一个Message的内容将system改为developer
2024-12-20 23:14:10 +08:00
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
request.Messages[0].Role = "developer"
2024-12-20 22:07:53 +08:00
}
}
2024-12-20 23:14:10 +08:00
2024-02-29 01:08:18 +08:00
return request, nil
}
2024-07-16 22:07:10 +08:00
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
2024-07-16 22:07:10 +08:00
}
2025-01-23 05:54:39 +08:00
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
2025-01-23 05:54:39 +08:00
}
2024-07-16 22:07:10 +08:00
// ConvertAudioRequest builds the upstream body for an audio request and
// remembers request.ResponseFormat on the adaptor for DoResponse.
//
// Text-to-speech (RelayModeAudioSpeech) requests are sent as JSON. Every other
// audio mode is re-encoded as multipart/form-data:
//   - "model" is written from request.Model;
//   - every other field in the incoming request's PostForm is copied;
//   - the uploaded "file" part is streamed into the new body.
//
// The incoming request's Content-Type header is overwritten with the new
// multipart content type (including its boundary).
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
	a.ResponseFormat = request.ResponseFormat
	if info.RelayMode == constant.RelayModeAudioSpeech {
		jsonData, err := json.Marshal(request)
		if err != nil {
			return nil, fmt.Errorf("error marshalling object: %w", err)
		}
		return bytes.NewReader(jsonData), nil
	}

	// Fetch the upload first: FormFile parses the multipart body if that has
	// not happened yet, which is what populates PostForm read below.
	file, fileHeader, err := c.Request.FormFile("file")
	if err != nil {
		return nil, fmt.Errorf("file is required: %w", err)
	}
	defer file.Close()

	var requestBody bytes.Buffer
	writer := multipart.NewWriter(&requestBody)
	if err := writer.WriteField("model", request.Model); err != nil {
		return nil, fmt.Errorf("write form field model: %w", err)
	}
	// Copy the remaining form fields; "model" was already written above.
	for key, values := range c.Request.PostForm {
		if key == "model" {
			continue
		}
		for _, value := range values {
			if err := writer.WriteField(key, value); err != nil {
				return nil, fmt.Errorf("write form field %q: %w", key, err)
			}
		}
	}

	// Attach the uploaded file as the last part.
	part, err := writer.CreateFormFile("file", fileHeader.Filename)
	if err != nil {
		return nil, fmt.Errorf("create form file failed: %w", err)
	}
	if _, err := io.Copy(part, file); err != nil {
		return nil, fmt.Errorf("copy file failed: %w", err)
	}
	// Close writes the terminating boundary; without it the body is malformed.
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("close multipart writer: %w", err)
	}
	c.Request.Header.Set("Content-Type", writer.FormDataContentType())
	return &requestBody, nil
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
2024-07-17 23:50:37 +08:00
return request, nil
2024-07-16 22:07:10 +08:00
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
2024-07-16 22:07:10 +08:00
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
return channel.DoFormRequest(a, c, info, requestBody)
} else if info.RelayMode == constant.RelayModeRealtime {
return channel.DoWssRequest(a, c, info, requestBody)
2024-07-16 22:07:10 +08:00
} else {
return channel.DoApiRequest(a, c, info, requestBody)
}
2024-02-29 01:08:18 +08:00
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
2024-07-16 22:07:10 +08:00
switch info.RelayMode {
case constant.RelayModeRealtime:
err, usage = OpenaiRealtimeHandler(c, info)
2024-07-16 22:07:10 +08:00
case constant.RelayModeAudioSpeech:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
2024-07-17 23:50:37 +08:00
case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeRerank:
err, usage = common_handler.RerankHandler(c, info, resp)
2024-07-16 22:07:10 +08:00
default:
if info.IsStream {
2024-07-19 01:07:37 +08:00
err, usage = OaiStreamHandler(c, resp, info)
2024-07-16 22:07:10 +08:00
} else {
err, usage = OpenaiHandler(c, resp, info)
2024-07-16 22:07:10 +08:00
}
2024-02-29 01:08:18 +08:00
}
return
}
func (a *Adaptor) GetModelList() []string {
2024-02-29 18:31:03 +08:00
switch a.ChannelType {
case common.ChannelType360:
return ai360.ModelList
case common.ChannelTypeMoonshot:
return moonshot.ModelList
2024-03-20 23:44:03 +08:00
case common.ChannelTypeLingYiWanWu:
return lingyiwanwu.ModelList
2024-05-16 19:03:42 +08:00
case common.ChannelTypeMiniMax:
return minimax.ModelList
case common.ChannelTypeXinference:
return xinference.ModelList
case common.ChannelTypeOpenRouter:
return openrouter.ModelList
2024-02-29 18:31:03 +08:00
default:
return ModelList
}
2024-02-29 01:08:18 +08:00
}
func (a *Adaptor) GetChannelName() string {
2024-05-16 19:03:42 +08:00
switch a.ChannelType {
case common.ChannelType360:
return ai360.ChannelName
case common.ChannelTypeMoonshot:
return moonshot.ChannelName
case common.ChannelTypeLingYiWanWu:
return lingyiwanwu.ChannelName
case common.ChannelTypeMiniMax:
return minimax.ChannelName
case common.ChannelTypeXinference:
return xinference.ChannelName
case common.ChannelTypeOpenRouter:
return openrouter.ChannelName
2024-05-16 19:03:42 +08:00
default:
return ChannelName
}
2024-02-29 01:08:18 +08:00
}