Files
new-api/middleware/auth.go

220 lines
5.1 KiB
Go
Raw Normal View History

2023-04-22 20:39:27 +08:00
package middleware
import (
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
2023-04-22 21:14:09 +08:00
"one-api/common"
"one-api/model"
"strconv"
2023-04-23 18:24:11 +08:00
"strings"
2023-04-22 20:39:27 +08:00
)
2024-09-24 17:48:09 +08:00
// validUserInfo reports whether an identity is usable for authorization.
// The username must contain at least one non-whitespace character, and the
// role must be accepted by common.IsValidateRole. The role is only checked
// once the username has passed.
func validUserInfo(username string, role int) bool {
	hasUsername := strings.TrimSpace(username) != ""
	return hasUsername && common.IsValidateRole(role)
}
2023-04-22 20:39:27 +08:00
func authHelper(c *gin.Context, minRole int) {
session := sessions.Default(c)
username := session.Get("username")
role := session.Get("role")
id := session.Get("id")
status := session.Get("status")
useAccessToken := false
2023-04-22 20:39:27 +08:00
if username == nil {
// Check access token
accessToken := c.Request.Header.Get("Authorization")
if accessToken == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,未登录且未提供 access token",
})
c.Abort()
return
}
user := model.ValidateAccessToken(accessToken)
if user != nil && user.Username != "" {
2024-09-24 17:48:09 +08:00
if !validUserInfo(user.Username, user.Role) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
// Token is valid
username = user.Username
role = user.Role
id = user.Id
status = user.Status
useAccessToken = true
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作access token 无效",
})
c.Abort()
return
}
2023-04-22 20:39:27 +08:00
}
if !useAccessToken {
// get header New-Api-User
apiUserIdStr := c.Request.Header.Get("New-Api-User")
if apiUserIdStr == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,请刷新页面或清空缓存后重试",
})
c.Abort()
return
}
apiUserId, err := strconv.Atoi(apiUserIdStr)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,登录信息无效,请重新登录",
})
c.Abort()
return
}
if id != apiUserId {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,与登录用户不匹配,请重新登录",
})
c.Abort()
return
}
}
2023-04-22 20:39:27 +08:00
if status.(int) == common.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已被封禁",
})
c.Abort()
return
}
if role.(int) < minRole {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,权限不足",
})
c.Abort()
return
}
2024-09-24 17:48:09 +08:00
if !validUserInfo(username.(string), role.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
2023-04-22 20:39:27 +08:00
c.Set("username", username)
c.Set("role", role)
c.Set("id", id)
2024-09-26 00:59:09 +08:00
c.Set("group", session.Get("group"))
c.Set("use_access_token", useAccessToken)
2023-04-22 20:39:27 +08:00
c.Next()
}
2024-05-13 23:02:35 +08:00
// TryUserAuth returns a middleware that never rejects a request. If the
// session carries an "id" value, it is copied into the gin context under
// "id". The chain always continues.
func TryUserAuth() func(c *gin.Context) {
	return func(c *gin.Context) {
		if sessionID := sessions.Default(c).Get("id"); sessionID != nil {
			c.Set("id", sessionID)
		}
		c.Next()
	}
}
2023-04-22 20:39:27 +08:00
// UserAuth returns a middleware that delegates to authHelper, requiring at
// least common.RoleCommonUser.
func UserAuth() func(c *gin.Context) {
	handler := func(ctx *gin.Context) {
		authHelper(ctx, common.RoleCommonUser)
	}
	return handler
}
// AdminAuth returns a middleware that delegates to authHelper, requiring at
// least common.RoleAdminUser.
func AdminAuth() func(c *gin.Context) {
	handler := func(ctx *gin.Context) {
		authHelper(ctx, common.RoleAdminUser)
	}
	return handler
}
// RootAuth returns a middleware that delegates to authHelper, requiring at
// least common.RoleRootUser.
func RootAuth() func(c *gin.Context) {
	handler := func(ctx *gin.Context) {
		authHelper(ctx, common.RoleRootUser)
	}
	return handler
}
2023-04-23 18:24:11 +08:00
func TokenAuth() func(c *gin.Context) {
2023-04-22 20:39:27 +08:00
return func(c *gin.Context) {
2023-04-23 18:24:11 +08:00
key := c.Request.Header.Get("Authorization")
2023-08-31 00:44:16 +08:00
parts := make([]string, 0)
2023-11-28 23:15:17 +08:00
key = strings.TrimPrefix(key, "Bearer ")
if key == "" || key == "midjourney-proxy" {
2023-08-31 00:44:16 +08:00
key = c.Request.Header.Get("mj-api-secret")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(key, "sk-")
2023-12-14 16:43:20 +08:00
parts = strings.Split(key, "-")
2023-08-31 00:44:16 +08:00
key = parts[0]
} else {
key = strings.TrimPrefix(key, "sk-")
2023-12-14 16:43:20 +08:00
parts = strings.Split(key, "-")
2023-08-31 00:44:16 +08:00
key = parts[0]
}
2023-04-23 18:24:11 +08:00
token, err := model.ValidateUserToken(key)
2024-08-04 14:35:16 +08:00
if token != nil {
id := c.GetInt("id")
if id == 0 {
2024-08-07 02:50:22 +08:00
c.Set("id", token.UserId)
2024-08-04 14:35:16 +08:00
}
}
2023-04-23 18:24:11 +08:00
if err != nil {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
2023-04-22 20:39:27 +08:00
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
2023-09-03 21:31:58 +08:00
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
2023-09-03 21:31:58 +08:00
return
}
if !userEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
2023-04-23 18:24:11 +08:00
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
2024-01-11 14:12:38 +08:00
c.Set("token_unlimited_quota", token.UnlimitedQuota)
if !token.UnlimitedQuota {
c.Set("token_quota", token.RemainQuota)
}
if token.ModelLimitsEnabled {
c.Set("token_model_limit_enabled", true)
c.Set("token_model_limit", token.GetModelLimitsMap())
} else {
c.Set("token_model_limit_enabled", false)
}
2024-09-17 20:49:51 +08:00
c.Set("allow_ips", token.GetIpLimitsMap())
2024-09-18 05:19:10 +08:00
c.Set("token_group", token.Group)
2023-04-23 18:24:11 +08:00
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
2024-04-04 16:35:44 +08:00
c.Set("specific_channel_id", parts[1])
} else {
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
}
2023-04-22 20:39:27 +08:00
}
c.Next()
}
}