Files
new-api/controller/channel.go

337 lines
6.9 KiB
Go
Raw Normal View History

2023-04-22 20:39:27 +08:00
package controller
import (
"encoding/json"
"fmt"
2023-04-22 20:39:27 +08:00
"github.com/gin-gonic/gin"
"net/http"
2023-04-22 21:14:09 +08:00
"one-api/common"
"one-api/model"
2023-04-22 20:39:27 +08:00
"strconv"
"strings"
2023-04-22 20:39:27 +08:00
)
// OpenAIModel is one entry of the "data" array returned by an upstream
// OpenAI-compatible /v1/models endpoint; FetchUpstreamModels decodes into it
// and only reads ID.
type OpenAIModel struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
	// Permission mirrors the per-model permission objects some upstreams
	// return; unknown or missing fields simply decode to zero values.
	Permission []struct {
		ID                 string `json:"id"`
		Object             string `json:"object"`
		Created            int64  `json:"created"`
		AllowCreateEngine  bool   `json:"allow_create_engine"`
		AllowSampling      bool   `json:"allow_sampling"`
		AllowLogprobs      bool   `json:"allow_logprobs"`
		AllowSearchIndices bool   `json:"allow_search_indices"`
		AllowView          bool   `json:"allow_view"`
		AllowFineTuning    bool   `json:"allow_fine_tuning"`
		Organization       string `json:"organization"`
		Group              string `json:"group"`
		IsBlocking         bool   `json:"is_blocking"`
	} `json:"permission"`
	Root   string `json:"root"`
	Parent string `json:"parent"`
}
// OpenAIModelsResponse is the decoded body of an upstream /v1/models call.
// Data holds the model list. Success is only populated by upstreams that add a
// "success" field (NOTE(review): presumably other one-api style gateways —
// confirm); it stays false for payloads that omit it.
type OpenAIModelsResponse struct {
	Data    []OpenAIModel `json:"data"`
	Success bool          `json:"success"`
}
2023-04-22 21:41:16 +08:00
func GetAllChannels(c *gin.Context) {
2023-04-22 20:39:27 +08:00
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
2023-04-22 20:39:27 +08:00
if p < 0 {
p = 0
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
2023-12-05 18:15:40 +08:00
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
2023-04-22 20:39:27 +08:00
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
2023-04-22 22:02:59 +08:00
"data": channels,
2023-04-22 20:39:27 +08:00
})
return
}
// FetchUpstreamModels lists the model IDs advertised by an OpenAI-type
// channel's upstream /v1/models endpoint.
//
// The channel is selected by the "id" path parameter. On success the model IDs
// are returned under "data" (always a JSON array, possibly empty). Every
// failure is reported with "success": false and a message, and the handler
// returns immediately so exactly one JSON body is written per request.
func FetchUpstreamModels(c *gin.Context) {
	fail := func(message string) {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": message,
		})
	}

	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		fail(err.Error())
		return
	}
	channel, err := model.GetChannelById(id, true)
	if err != nil {
		fail(err.Error())
		return
	}
	if channel.Type != common.ChannelTypeOpenAI {
		fail("仅支持 OpenAI 类型渠道")
		return
	}

	// BaseURL is a pointer and may be nil or empty; dereferencing it blindly
	// panics. NOTE(review): assumes an unset base URL means the official
	// OpenAI endpoint for OpenAI-type channels — confirm against channel defaults.
	baseURL := "https://api.openai.com"
	if channel.BaseURL != nil && *channel.BaseURL != "" {
		// Avoid "//v1/models" when the configured URL ends with a slash.
		baseURL = strings.TrimSuffix(*channel.BaseURL, "/")
	}
	url := fmt.Sprintf("%s/v1/models", baseURL)

	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		fail(err.Error())
		return
	}

	var result OpenAIModelsResponse
	if err := json.Unmarshal(body, &result); err != nil {
		fail(err.Error())
		return
	}
	// OpenAI's own /v1/models payload ({"object":"list","data":[...]}) has no
	// "success" field, so Success alone cannot signal failure; only treat the
	// response as an upstream error when no models came back either.
	if !result.Success && len(result.Data) == 0 {
		fail("上游返回错误")
		return
	}

	// Pre-sized, non-nil slice so an empty list encodes as [] rather than null.
	ids := make([]string, 0, len(result.Data))
	for _, m := range result.Data {
		ids = append(ids, m.ID)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    ids,
	})
}
2024-01-10 13:23:43 +08:00
// FixChannelsAbilities runs model.FixAbility and reports the count it returns
// under "data". If FixAbility fails, the response carries "success": false and
// the error text in "message".
func FixChannelsAbilities(c *gin.Context) {
	fixed, err := model.FixAbility()
	if err == nil {
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "",
			"data":    fixed,
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": false,
		"message": err.Error(),
	})
}
2023-04-22 21:41:16 +08:00
func SearchChannels(c *gin.Context) {
2023-04-22 20:39:27 +08:00
keyword := c.Query("keyword")
2023-12-05 18:15:40 +08:00
group := c.Query("group")
modelKeyword := c.Query("model")
2023-12-05 18:15:40 +08:00
//idSort, _ := strconv.ParseBool(c.Query("id_sort"))
channels, err := model.SearchChannels(keyword, group, modelKeyword)
2023-04-22 20:39:27 +08:00
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
2023-04-22 22:02:59 +08:00
"data": channels,
2023-04-22 20:39:27 +08:00
})
return
}
2023-04-22 22:02:59 +08:00
func GetChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
2023-04-22 20:39:27 +08:00
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
2023-04-23 18:24:11 +08:00
channel, err := model.GetChannelById(id, false)
2023-04-22 22:02:59 +08:00
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
2023-04-22 20:39:27 +08:00
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
2023-04-22 22:02:59 +08:00
"data": channel,
2023-04-22 20:39:27 +08:00
})
return
}
2023-04-22 22:02:59 +08:00
func AddChannel(c *gin.Context) {
channel := model.Channel{}
err := c.ShouldBindJSON(&channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
2023-04-22 20:39:27 +08:00
"success": false,
2023-04-22 22:02:59 +08:00
"message": err.Error(),
2023-04-22 20:39:27 +08:00
})
return
}
2023-04-23 15:42:23 +08:00
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
2024-08-27 20:19:51 +08:00
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "必须包含default字段",
})
return
}
}
}
2024-08-27 20:19:51 +08:00
keys = []string{channel.Key}
}
2023-08-30 21:15:56 +08:00
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {
if key == "" {
continue
}
localChannel := channel
localChannel.Key = key
channels = append(channels, localChannel)
}
err = model.BatchInsertChannels(channels)
2023-04-22 22:02:59 +08:00
if err != nil {
2023-04-22 20:39:27 +08:00
c.JSON(http.StatusOK, gin.H{
"success": false,
2023-04-22 22:02:59 +08:00
"message": err.Error(),
2023-04-22 20:39:27 +08:00
})
return
}
2023-04-22 22:02:59 +08:00
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func DeleteChannel(c *gin.Context) {
id, _ := strconv.Atoi(c.Param("id"))
channel := model.Channel{Id: id}
err := channel.Delete()
2023-04-22 20:39:27 +08:00
if err != nil {
c.JSON(http.StatusOK, gin.H{
2023-04-22 22:02:59 +08:00
"success": false,
2023-04-22 20:39:27 +08:00
"message": err.Error(),
})
return
}
2023-04-22 22:02:59 +08:00
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
2023-04-22 20:39:27 +08:00
}
// DeleteDisabledChannel calls model.DeleteDisabledChannel and reports the
// value it returns under "data"; an error yields "success": false with the
// error text instead.
func DeleteDisabledChannel(c *gin.Context) {
	removed, err := model.DeleteDisabledChannel()

	response := gin.H{"success": true, "message": "", "data": removed}
	if err != nil {
		response = gin.H{"success": false, "message": err.Error()}
	}
	c.JSON(http.StatusOK, response)
}
2023-12-14 16:35:03 +08:00
// ChannelBatch is the JSON request body for batch channel operations,
// e.g. {"ids": [1, 2, 3]} for DeleteChannelBatch.
type ChannelBatch struct {
	Ids []int `json:"ids"`
}
// DeleteChannelBatch deletes every channel whose id is listed in the JSON body
// (see ChannelBatch) via model.BatchDeleteChannels.
//
// A body that fails to bind, or that lists no ids, is rejected with
// "参数错误". On success "data" holds the number of ids submitted (not a count
// reported by the database).
func DeleteChannelBatch(c *gin.Context) {
	var batch ChannelBatch
	if err := c.ShouldBindJSON(&batch); err != nil || len(batch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "参数错误"})
		return
	}

	if err := model.BatchDeleteChannels(batch.Ids); err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": len(batch.Ids)})
}
2023-04-22 22:02:59 +08:00
func UpdateChannel(c *gin.Context) {
channel := model.Channel{}
err := c.ShouldBindJSON(&channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
2023-04-22 20:39:27 +08:00
return
}
2023-04-22 22:02:59 +08:00
err = channel.Update()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
2023-04-23 15:42:23 +08:00
"data": channel,
2023-04-22 22:02:59 +08:00
})
return
2023-04-22 20:39:27 +08:00
}