Files
new-api/service/token_counter.go

491 lines
14 KiB
Go
Raw Normal View History

2024-02-29 01:08:18 +08:00
package service
import (
2025-03-12 21:31:46 +08:00
"encoding/json"
2023-11-17 20:32:11 +08:00
"errors"
"fmt"
"github.com/tiktoken-go/tokenizer"
"github.com/tiktoken-go/tokenizer/codec"
2023-11-17 20:32:11 +08:00
"image"
"log"
"math"
"one-api/common"
2024-07-18 17:04:19 +08:00
"one-api/constant"
2024-02-29 01:08:18 +08:00
"one-api/dto"
relaycommon "one-api/relay/common"
"strings"
"sync"
2023-11-15 21:05:14 +08:00
"unicode/utf8"
)
// defaultTokenEncoder is the fallback encoder. It is assigned by
// InitTokenEncoders and is used (and cached) by getTokenEncoder for any model
// name the tokenizer package cannot resolve.
var defaultTokenEncoder tokenizer.Codec
// tokenEncoderMap caches one encoder per model name. Entries are added lazily
// by getTokenEncoder the first time a model is seen.
var tokenEncoderMap = make(map[string]tokenizer.Codec)
// tokenEncoderMutex guards every read and write of tokenEncoderMap.
var tokenEncoderMutex sync.RWMutex
// InitTokenEncoders creates the package-wide default encoder (cl100k_base)
// and writes a system log line before and after doing so.
func InitTokenEncoders() {
	const (
		startMsg = "initializing token encoders"
		doneMsg  = "token encoders initialized"
	)

	common.SysLog(startMsg)
	defaultTokenEncoder = codec.NewCl100kBase()
	common.SysLog(doneMsg)
}
// getTokenEncoder returns the encoder cached for model, resolving and caching
// it on first use. When the tokenizer package cannot resolve the model name,
// defaultTokenEncoder is cached under that name instead, so the failing
// lookup is not repeated on later calls.
func getTokenEncoder(model string) tokenizer.Codec {
	// Fast path: a cache hit only needs the read lock.
	tokenEncoderMutex.RLock()
	cached, ok := tokenEncoderMap[model]
	tokenEncoderMutex.RUnlock()
	if ok {
		return cached
	}

	tokenEncoderMutex.Lock()
	defer tokenEncoderMutex.Unlock()

	// Look again under the write lock: another goroutine may have filled the
	// entry between releasing the read lock and acquiring this one.
	if cached, ok = tokenEncoderMap[model]; ok {
		return cached
	}

	resolved, err := tokenizer.ForModel(tokenizer.Model(model))
	if err != nil {
		resolved = defaultTokenEncoder
	}
	tokenEncoderMap[model] = resolved
	return resolved
}
// getTokenNum returns how many token ids tokenEncoder produces for text.
// Empty text yields 0 without calling the encoder. An encoding error is
// ignored: the length of whatever id slice came back is returned.
func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
	if len(text) == 0 {
		return 0
	}
	tokenIDs, _, _ := tokenEncoder.Encode(text)
	return len(tokenIDs)
}
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
2025-03-15 19:10:12 +08:00
if imageUrl == nil {
return 0, fmt.Errorf("image_url_is_nil")
}
2024-07-19 12:59:37 +08:00
baseTokens := 85
2024-05-18 15:14:49 +08:00
if model == "glm-4v" {
return 1047, nil
}
2024-05-30 23:15:06 +08:00
if imageUrl.Detail == "low" {
2024-07-19 12:59:37 +08:00
return baseTokens, nil
2024-05-30 23:15:06 +08:00
}
2024-07-18 17:04:19 +08:00
if !constant.GetMediaTokenNotStream && !stream {
2025-03-15 19:10:12 +08:00
return 3 * baseTokens, nil
2024-07-18 17:04:19 +08:00
}
2025-03-15 19:10:12 +08:00
2024-05-30 21:39:58 +08:00
// 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high"
}
2024-07-19 12:59:37 +08:00
tileTokens := 170
if strings.HasPrefix(model, "gpt-4o-mini") {
tileTokens = 5667
baseTokens = 2833
}
2025-01-02 16:33:00 +08:00
// 是否统计图片token
if !constant.GetMediaToken {
return 3 * baseTokens, nil
}
if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
return 3 * baseTokens, nil
}
2023-11-19 18:59:35 +08:00
var config image.Config
var err error
2023-12-27 16:32:54 +08:00
var format string
var b64str string
2023-11-19 18:59:35 +08:00
if strings.HasPrefix(imageUrl.Url, "http") {
config, format, err = DecodeUrlImageData(imageUrl.Url)
2023-11-19 18:59:35 +08:00
} else {
common.SysLog(fmt.Sprintf("decoding image"))
config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url)
2023-11-19 18:59:35 +08:00
}
2023-11-17 20:32:11 +08:00
if err != nil {
return 0, err
}
2025-03-15 19:10:12 +08:00
imageUrl.MimeType = format
2023-11-17 20:32:11 +08:00
if config.Width == 0 || config.Height == 0 {
// not an image
if format != "" && b64str != "" {
// file type
return 3 * baseTokens, nil
}
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url))
2023-11-17 20:32:11 +08:00
}
shortSide := config.Width
otherSide := config.Height
2023-12-27 16:32:54 +08:00
log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
2023-11-17 20:32:11 +08:00
// 缩放倍数
scale := 1.0
if config.Height < shortSide {
shortSide = config.Height
otherSide = config.Width
}
// 将最小变的尺寸缩小到768以下如果大于768则缩放到768
if shortSide > 768 {
scale = float64(shortSide) / 768
shortSide = 768
}
// 将另一边按照相同的比例缩小,向上取整
otherSide = int(math.Ceil(float64(otherSide) / scale))
log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
// 计算图片的token数量(边的长度除以512向上取整)
tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
log.Printf("tiles: %d", tiles)
2024-07-19 12:59:37 +08:00
return tiles*tileTokens + baseTokens, nil
2023-11-17 20:32:11 +08:00
}
func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
2024-04-23 23:01:06 +08:00
tkm := 0
msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
2024-04-23 23:01:06 +08:00
if err != nil {
2024-05-23 23:59:55 +08:00
return 0, err
2024-04-23 23:01:06 +08:00
}
tkm += msgTokens
if request.Tools != nil {
openaiTools := request.Tools
2024-04-23 23:01:06 +08:00
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
if tool.Function.Description != "" {
countStr += tool.Function.Description
}
if tool.Function.Parameters != nil {
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens, err := CountTokenInput(countStr, request.Model)
2024-04-23 23:01:06 +08:00
if err != nil {
2024-05-23 23:59:55 +08:00
return 0, err
2024-04-23 23:01:06 +08:00
}
tkm += 8
tkm += toolTokens
}
2024-05-23 23:59:55 +08:00
return tkm, nil
2024-04-23 23:01:06 +08:00
}
2025-03-12 21:31:46 +08:00
// CountTokenClaudeRequest returns the estimated prompt tokens of a Claude
// request: message tokens, plus the system prompt when it is not "", plus the
// tool definitions when request.Tools is a non-empty list.
//
// It returns an error when message, system or tool counting fails, when
// request.Tools is non-nil but not a []any, or when the list cannot be
// converted to []dto.Tool. Tool-related errors wrap the underlying cause.
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
	// Count tokens in messages.
	tkm, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
	if err != nil {
		return 0, err
	}

	// Count tokens in the system prompt.
	if request.System != "" {
		systemTokens, err := CountTokenInput(request.System, model)
		if err != nil {
			return 0, err
		}
		tkm += systemTokens
	}

	if request.Tools == nil {
		return tkm, nil
	}

	// Tools must be a list; an empty list contributes nothing.
	tools, ok := request.Tools.([]any)
	if !ok {
		return 0, errors.New("tools: Input should be a valid list")
	}
	if len(tools) == 0 {
		return tkm, nil
	}

	parsedTools, err := common.Any2Type[[]dto.Tool](request.Tools)
	if err != nil {
		return 0, fmt.Errorf("tools: Input should be a valid list: %w", err)
	}
	toolTokens, err := CountTokenClaudeTools(parsedTools, model)
	if err != nil {
		return 0, fmt.Errorf("tools: %w", err)
	}
	return tkm + toolTokens, nil
}
func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
tokenEncoder := getTokenEncoder(model)
tokenNum := 0
for _, message := range messages {
// Count tokens for role
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.IsStringContent() {
tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
} else {
content, err := message.ParseContent()
if err != nil {
return 0, err
}
for _, mediaMessage := range content {
switch mediaMessage.Type {
case "text":
tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
case "image":
//imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
//if err != nil {
// return 0, err
//}
tokenNum += 1000
case "tool_use":
if mediaMessage.Input != nil {
tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
inputJSON, _ := json.Marshal(mediaMessage.Input)
tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
}
2025-03-12 21:31:46 +08:00
case "tool_result":
if mediaMessage.Content != nil {
contentJSON, _ := json.Marshal(mediaMessage.Content)
tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
}
2025-03-12 21:31:46 +08:00
}
}
}
}
// Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
return tokenNum, nil
}
// CountTokenClaudeTools estimates the tokens of a list of Claude tool
// definitions using the encoder selected for model.
//
// Each tool contributes its name, its description, its JSON-encoded input
// schema and a flat 3 tokens of assumed formatting overhead (this may need
// adjusting to Claude's exact formatting).
//
// It returns an error, wrapping the marshal failure, when a tool's input
// schema cannot be JSON-encoded.
func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
	tokenEncoder := getTokenEncoder(model)
	tokenNum := 0

	for _, tool := range tools {
		tokenNum += getTokenNum(tokenEncoder, tool.Name)
		tokenNum += getTokenNum(tokenEncoder, tool.Description)

		schemaJSON, err := json.Marshal(tool.InputSchema)
		if err != nil {
			return 0, fmt.Errorf("marshal_tool_schema_fail: %w", err)
		}
		tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
	}

	// Assumed 3 tokens per tool for formatting.
	tokenNum += len(tools) * 3

	return tokenNum, nil
}
// CountTokenRealtime returns the (text, audio) token counts attributable to a
// single realtime event, depending on request.Type:
//
//   - session update: text tokens of the session instructions, if a session is set
//   - response audio delta: audio tokens of the delta (output audio format)
//   - audio transcription / function-call-arguments delta: text tokens of the delta
//   - input audio buffer append: audio tokens of the audio (input audio format)
//   - conversation item created: text tokens of every "input_text" content
//     part of a "message" item
//   - response done: for every entry of info.RealtimeTools, its tokens plus a
//     flat 8, but only when info.IsFirstRequest is false
//
// Any other event type yields (0, 0, nil). On error both counts are 0.
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
	var textToken, audioToken int

	switch request.Type {
	case dto.RealtimeEventTypeSessionUpdate:
		if request.Session != nil {
			n, err := CountTextToken(request.Session.Instructions, model)
			if err != nil {
				return 0, 0, err
			}
			textToken += n
		}

	case dto.RealtimeEventResponseAudioDelta:
		n, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
		if err != nil {
			return 0, 0, fmt.Errorf("error counting audio token: %v", err)
		}
		audioToken += n

	case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
		n, err := CountTextToken(request.Delta, model)
		if err != nil {
			return 0, 0, fmt.Errorf("error counting text token: %v", err)
		}
		textToken += n

	case dto.RealtimeEventInputAudioBufferAppend:
		n, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
		if err != nil {
			return 0, 0, fmt.Errorf("error counting audio token: %v", err)
		}
		audioToken += n

	case dto.RealtimeEventConversationItemCreated:
		if request.Item != nil && request.Item.Type == "message" {
			for _, part := range request.Item.Content {
				if part.Type != "input_text" {
					continue
				}
				n, err := CountTextToken(part.Text, model)
				if err != nil {
					return 0, 0, err
				}
				textToken += n
			}
		}

	case dto.RealtimeEventTypeResponseDone:
		// Tool definitions are only charged on follow-up requests.
		if !info.IsFirstRequest {
			for _, tool := range info.RealtimeTools {
				n, err := CountTokenInput(tool, model)
				if err != nil {
					return 0, 0, err
				}
				textToken += 8
				textToken += n
			}
		}
	}

	return textToken, audioToken, nil
}
func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
2023-11-17 18:24:37 +08:00
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Content != nil {
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
2025-03-15 19:10:12 +08:00
imageUrl := m.GetImageMedia()
imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
if err != nil {
return 0, err
2024-01-31 01:41:38 +08:00
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else if m.Type == dto.ContentTypeFile {
tokenNum += 5000
} else if m.Type == dto.ContentTypeVideoUrl {
tokenNum += 5000
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
2023-11-17 18:24:37 +08:00
}
}
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
2024-05-23 23:59:55 +08:00
return tokenNum, nil
}
2024-05-23 23:59:55 +08:00
func CountTokenInput(input any, model string) (int, error) {
switch v := input.(type) {
case string:
return CountTextToken(v, model)
case []string:
text := ""
for _, s := range v {
text += s
}
return CountTextToken(text, model)
case []interface{}:
text := ""
for _, item := range v {
text += fmt.Sprintf("%v", item)
}
return CountTextToken(text, model)
}
2024-05-23 23:59:55 +08:00
return CountTokenInput(fmt.Sprintf("%v", input), model)
}
2024-04-22 16:35:56 +08:00
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
tokens := 0
for _, message := range messages {
2024-05-23 23:59:55 +08:00
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
2024-04-22 16:35:56 +08:00
tokens += tkm
if message.Delta.ToolCalls != nil {
for _, tool := range message.Delta.ToolCalls {
2024-05-23 23:59:55 +08:00
tkm, _ := CountTokenInput(tool.Function.Name, model)
2024-04-22 16:35:56 +08:00
tokens += tkm
2024-05-23 23:59:55 +08:00
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
2024-04-22 16:35:56 +08:00
tokens += tkm
}
}
}
return tokens
}
func CountTTSToken(text string, model string) (int, error) {
2023-11-15 21:05:14 +08:00
if strings.HasPrefix(model, "tts") {
2024-05-23 23:59:55 +08:00
return utf8.RuneCountInString(text), nil
2023-11-15 21:05:14 +08:00
} else {
return CountTextToken(text, model)
}
}
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
2023-11-15 21:05:14 +08:00
}
return int(duration / 60 * 100 / 0.06), nil
2023-11-15 21:05:14 +08:00
}
// CountAudioTokenOutput converts base64-encoded output audio into a token
// count. Empty input yields 0. Otherwise the duration reported by parseAudio
// is divided by 60, multiplied by 200, divided by 0.24 and truncated to an
// int. A parseAudio error is returned with a count of 0.
func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
	if len(audioBase64) == 0 {
		return 0, nil
	}

	duration, err := parseAudio(audioBase64, audioFormat)
	if err != nil {
		return 0, err
	}

	minutes := duration / 60
	return int(minutes * 200 / 0.24), nil
}
//func CountAudioToken(sec float64, audioType string) {
// if audioType == "input" {
//
// }
//}
// CountTextToken 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTextToken(text string, model string) (int, error) {
2024-03-20 17:07:42 +08:00
var err error
tokenEncoder := getTokenEncoder(model)
2024-05-23 23:59:55 +08:00
return getTokenNum(tokenEncoder, text), err
}