2025-06-08 21:40:57 +08:00
|
|
|
|
package controller
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
2025-07-21 15:06:26 +08:00
|
|
|
|
"encoding/json"
|
2025-06-08 21:40:57 +08:00
|
|
|
|
"fmt"
|
|
|
|
|
|
"io"
|
2025-06-20 15:50:00 +08:00
|
|
|
|
"time"
|
2025-10-11 15:30:09 +08:00
|
|
|
|
|
|
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/constant"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/dto"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/logger"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/model"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/relay"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/relay/channel"
|
|
|
|
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
2025-06-08 21:40:57 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
2025-06-20 15:50:00 +08:00
|
|
|
|
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
2025-06-08 21:40:57 +08:00
|
|
|
|
for channelId, taskIds := range taskChannelM {
|
2025-06-20 15:50:00 +08:00
|
|
|
|
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
2025-08-14 20:05:06 +08:00
|
|
|
|
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
2025-06-08 21:40:57 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-06-20 15:50:00 +08:00
|
|
|
|
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
2025-08-14 20:05:06 +08:00
|
|
|
|
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
2025-06-08 21:40:57 +08:00
|
|
|
|
if len(taskIds) == 0 {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
|
|
|
|
|
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
|
|
|
|
|
"status": "FAILURE",
|
|
|
|
|
|
"progress": "100%",
|
|
|
|
|
|
})
|
|
|
|
|
|
if errUpdate != nil {
|
2025-08-14 21:10:04 +08:00
|
|
|
|
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
2025-06-08 21:40:57 +08:00
|
|
|
|
}
|
|
|
|
|
|
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
|
|
|
|
|
}
|
2025-06-20 15:50:00 +08:00
|
|
|
|
adaptor := relay.GetTaskAdaptor(platform)
|
2025-06-08 21:40:57 +08:00
|
|
|
|
if adaptor == nil {
|
|
|
|
|
|
return fmt.Errorf("video adaptor not found")
|
|
|
|
|
|
}
|
2025-10-09 10:59:05 +08:00
|
|
|
|
info := &relaycommon.RelayInfo{}
|
|
|
|
|
|
info.ChannelMeta = &relaycommon.ChannelMeta{
|
|
|
|
|
|
ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
|
|
|
|
|
|
}
|
2025-11-13 23:26:21 +08:00
|
|
|
|
info.ApiKey = cacheGetChannel.Key
|
2025-10-09 10:59:05 +08:00
|
|
|
|
adaptor.Init(info)
|
2025-06-08 21:40:57 +08:00
|
|
|
|
for _, taskId := range taskIds {
|
|
|
|
|
|
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
2025-08-14 20:05:06 +08:00
|
|
|
|
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
2025-06-08 21:40:57 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
2025-07-03 13:10:25 +08:00
|
|
|
|
baseURL := constant.ChannelBaseURLs[channel.Type]
|
2025-06-08 21:40:57 +08:00
|
|
|
|
if channel.GetBaseURL() != "" {
|
|
|
|
|
|
baseURL = channel.GetBaseURL()
|
|
|
|
|
|
}
|
2025-06-23 21:22:01 +08:00
|
|
|
|
|
|
|
|
|
|
task := taskM[taskId]
|
|
|
|
|
|
if task == nil {
|
2025-08-14 20:05:06 +08:00
|
|
|
|
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
2025-06-23 21:22:01 +08:00
|
|
|
|
return fmt.Errorf("task %s not found", taskId)
|
|
|
|
|
|
}
|
2025-06-08 21:40:57 +08:00
|
|
|
|
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
|
|
|
|
|
"task_id": taskId,
|
2025-06-23 21:22:01 +08:00
|
|
|
|
"action": task.Action,
|
2025-06-08 21:40:57 +08:00
|
|
|
|
})
|
|
|
|
|
|
if err != nil {
|
2025-06-20 15:50:00 +08:00
|
|
|
|
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
2025-06-08 21:40:57 +08:00
|
|
|
|
}
|
2025-06-20 15:50:00 +08:00
|
|
|
|
//if resp.StatusCode != http.StatusOK {
|
|
|
|
|
|
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
|
|
|
|
|
//}
|
2025-06-08 21:40:57 +08:00
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
|
|
|
|
if err != nil {
|
2025-06-20 15:50:00 +08:00
|
|
|
|
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
2025-06-08 21:40:57 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-10-14 23:03:17 +08:00
|
|
|
|
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
|
|
|
|
|
|
|
2025-07-21 15:06:26 +08:00
|
|
|
|
taskResult := &relaycommon.TaskInfo{}
|
|
|
|
|
|
// try parse as New API response format
|
|
|
|
|
|
var responseItems dto.TaskResponse[model.Task]
|
2025-10-14 23:03:17 +08:00
|
|
|
|
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
|
|
|
|
|
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
|
2025-07-21 15:06:26 +08:00
|
|
|
|
t := responseItems.Data
|
|
|
|
|
|
taskResult.TaskID = t.TaskID
|
|
|
|
|
|
taskResult.Status = string(t.Status)
|
|
|
|
|
|
taskResult.Url = t.FailReason
|
|
|
|
|
|
taskResult.Progress = t.Progress
|
|
|
|
|
|
taskResult.Reason = t.FailReason
|
2025-10-10 18:45:12 +08:00
|
|
|
|
task.Data = t.Data
|
2025-07-21 15:06:26 +08:00
|
|
|
|
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
2025-06-20 15:50:00 +08:00
|
|
|
|
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
2025-07-21 15:06:26 +08:00
|
|
|
|
} else {
|
2025-08-26 08:29:26 +08:00
|
|
|
|
task.Data = redactVideoResponseBody(responseBody)
|
2025-06-08 21:40:57 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-10-14 23:03:17 +08:00
|
|
|
|
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
|
|
|
|
|
|
|
2025-06-20 15:50:00 +08:00
|
|
|
|
now := time.Now().Unix()
|
|
|
|
|
|
if taskResult.Status == "" {
|
2025-10-14 23:03:17 +08:00
|
|
|
|
//return fmt.Errorf("task %s status is empty", taskId)
|
|
|
|
|
|
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
2025-06-20 15:50:00 +08:00
|
|
|
|
}
|
2025-10-16 12:38:21 +08:00
|
|
|
|
|
|
|
|
|
|
// 记录原本的状态,防止重复退款
|
|
|
|
|
|
shouldRefund := false
|
|
|
|
|
|
quota := task.Quota
|
|
|
|
|
|
preStatus := task.Status
|
|
|
|
|
|
|
2025-06-20 15:50:00 +08:00
|
|
|
|
task.Status = model.TaskStatus(taskResult.Status)
|
|
|
|
|
|
switch taskResult.Status {
|
|
|
|
|
|
case model.TaskStatusSubmitted:
|
|
|
|
|
|
task.Progress = "10%"
|
|
|
|
|
|
case model.TaskStatusQueued:
|
|
|
|
|
|
task.Progress = "20%"
|
|
|
|
|
|
case model.TaskStatusInProgress:
|
|
|
|
|
|
task.Progress = "30%"
|
|
|
|
|
|
if task.StartTime == 0 {
|
|
|
|
|
|
task.StartTime = now
|
2025-06-08 21:40:57 +08:00
|
|
|
|
}
|
2025-06-20 15:50:00 +08:00
|
|
|
|
case model.TaskStatusSuccess:
|
2025-08-26 08:29:26 +08:00
|
|
|
|
task.Progress = "100%"
|
2025-06-20 15:50:00 +08:00
|
|
|
|
if task.FinishTime == 0 {
|
|
|
|
|
|
task.FinishTime = now
|
|
|
|
|
|
}
|
2025-08-26 08:29:26 +08:00
|
|
|
|
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
|
|
|
|
|
task.FailReason = taskResult.Url
|
|
|
|
|
|
}
|
2025-10-02 02:46:47 +08:00
|
|
|
|
|
|
|
|
|
|
// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
|
|
|
|
|
|
if taskResult.TotalTokens > 0 {
|
|
|
|
|
|
// 获取模型名称
|
|
|
|
|
|
var taskData map[string]interface{}
|
|
|
|
|
|
if err := json.Unmarshal(task.Data, &taskData); err == nil {
|
|
|
|
|
|
if modelName, ok := taskData["model"].(string); ok && modelName != "" {
|
|
|
|
|
|
// 获取模型价格和倍率
|
|
|
|
|
|
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
|
|
|
|
|
// 只有配置了倍率(非固定价格)时才按 token 重新计费
|
|
|
|
|
|
if hasRatioSetting && modelRatio > 0 {
|
|
|
|
|
|
// 获取用户和组的倍率信息
|
2025-10-13 15:17:06 +08:00
|
|
|
|
group := task.Group
|
|
|
|
|
|
if group == "" {
|
|
|
|
|
|
user, err := model.GetUserById(task.UserId, false)
|
|
|
|
|
|
if err == nil {
|
|
|
|
|
|
group = user.Group
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
if group != "" {
|
|
|
|
|
|
groupRatio := ratio_setting.GetGroupRatio(group)
|
|
|
|
|
|
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
|
2025-10-02 02:46:47 +08:00
|
|
|
|
|
|
|
|
|
|
var finalGroupRatio float64
|
|
|
|
|
|
if hasUserGroupRatio {
|
|
|
|
|
|
finalGroupRatio = userGroupRatio
|
|
|
|
|
|
} else {
|
|
|
|
|
|
finalGroupRatio = groupRatio
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
|
|
|
|
|
actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
|
|
|
|
|
|
|
|
|
|
|
|
// 计算差额
|
|
|
|
|
|
preConsumedQuota := task.Quota
|
|
|
|
|
|
quotaDelta := actualQuota - preConsumedQuota
|
|
|
|
|
|
|
|
|
|
|
|
if quotaDelta > 0 {
|
|
|
|
|
|
// 需要补扣费
|
|
|
|
|
|
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
|
|
|
|
|
task.TaskID,
|
|
|
|
|
|
logger.LogQuota(quotaDelta),
|
|
|
|
|
|
logger.LogQuota(actualQuota),
|
|
|
|
|
|
logger.LogQuota(preConsumedQuota),
|
|
|
|
|
|
taskResult.TotalTokens,
|
|
|
|
|
|
))
|
|
|
|
|
|
if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
|
|
|
|
|
|
logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
|
|
|
|
|
|
} else {
|
|
|
|
|
|
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
|
|
|
|
|
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
|
|
|
|
|
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
|
|
|
|
|
|
|
|
|
|
|
// 记录消费日志
|
2025-10-02 03:46:00 +08:00
|
|
|
|
logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
|
|
|
|
|
|
modelRatio, finalGroupRatio, taskResult.TotalTokens,
|
|
|
|
|
|
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
|
2025-10-02 02:46:47 +08:00
|
|
|
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
|
|
|
|
|
}
|
|
|
|
|
|
} else if quotaDelta < 0 {
|
|
|
|
|
|
// 需要退还多扣的费用
|
|
|
|
|
|
refundQuota := -quotaDelta
|
|
|
|
|
|
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
|
|
|
|
|
task.TaskID,
|
|
|
|
|
|
logger.LogQuota(refundQuota),
|
|
|
|
|
|
logger.LogQuota(actualQuota),
|
|
|
|
|
|
logger.LogQuota(preConsumedQuota),
|
|
|
|
|
|
taskResult.TotalTokens,
|
|
|
|
|
|
))
|
|
|
|
|
|
if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
|
|
|
|
|
|
logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
|
|
|
|
|
|
} else {
|
|
|
|
|
|
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
|
|
|
|
|
|
|
|
|
|
|
// 记录退款日志
|
2025-10-02 03:46:00 +08:00
|
|
|
|
logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
|
|
|
|
|
|
modelRatio, finalGroupRatio, taskResult.TotalTokens,
|
|
|
|
|
|
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
|
2025-10-02 02:46:47 +08:00
|
|
|
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
// quotaDelta == 0, 预扣费刚好准确
|
|
|
|
|
|
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
|
|
|
|
|
|
task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2025-06-20 15:50:00 +08:00
|
|
|
|
case model.TaskStatusFailure:
|
2025-10-16 12:38:21 +08:00
|
|
|
|
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
2025-06-20 15:50:00 +08:00
|
|
|
|
task.Status = model.TaskStatusFailure
|
|
|
|
|
|
task.Progress = "100%"
|
|
|
|
|
|
if task.FinishTime == 0 {
|
|
|
|
|
|
task.FinishTime = now
|
|
|
|
|
|
}
|
|
|
|
|
|
task.FailReason = taskResult.Reason
|
2025-08-14 20:05:06 +08:00
|
|
|
|
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
2025-10-12 12:05:36 +08:00
|
|
|
|
taskResult.Progress = "100%"
|
2025-06-08 21:40:57 +08:00
|
|
|
|
if quota != 0 {
|
2025-10-12 12:05:36 +08:00
|
|
|
|
if preStatus != model.TaskStatusFailure {
|
2025-10-16 12:38:21 +08:00
|
|
|
|
shouldRefund = true
|
2025-10-12 12:05:36 +08:00
|
|
|
|
} else {
|
|
|
|
|
|
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
|
2025-06-08 21:40:57 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
2025-06-20 15:50:00 +08:00
|
|
|
|
default:
|
|
|
|
|
|
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
|
|
|
|
|
}
|
|
|
|
|
|
if taskResult.Progress != "" {
|
|
|
|
|
|
task.Progress = taskResult.Progress
|
2025-06-08 21:40:57 +08:00
|
|
|
|
}
|
|
|
|
|
|
if err := task.Update(); err != nil {
|
2025-08-14 21:10:04 +08:00
|
|
|
|
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
2025-10-16 12:46:07 +08:00
|
|
|
|
shouldRefund = false
|
2025-06-08 21:40:57 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-10-16 12:38:21 +08:00
|
|
|
|
if shouldRefund {
|
|
|
|
|
|
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
|
|
|
|
|
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
|
|
|
|
|
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
|
|
|
|
|
}
|
|
|
|
|
|
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
|
|
|
|
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-06-08 21:40:57 +08:00
|
|
|
|
return nil
|
|
|
|
|
|
}
|
2025-08-26 08:29:26 +08:00
|
|
|
|
|
|
|
|
|
|
func redactVideoResponseBody(body []byte) []byte {
|
|
|
|
|
|
var m map[string]any
|
|
|
|
|
|
if err := json.Unmarshal(body, &m); err != nil {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
resp, _ := m["response"].(map[string]any)
|
|
|
|
|
|
if resp != nil {
|
|
|
|
|
|
delete(resp, "bytesBase64Encoded")
|
|
|
|
|
|
if v, ok := resp["video"].(string); ok {
|
|
|
|
|
|
resp["video"] = truncateBase64(v)
|
|
|
|
|
|
}
|
|
|
|
|
|
if vs, ok := resp["videos"].([]any); ok {
|
|
|
|
|
|
for i := range vs {
|
|
|
|
|
|
if vm, ok := vs[i].(map[string]any); ok {
|
|
|
|
|
|
delete(vm, "bytesBase64Encoded")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
b, err := json.Marshal(m)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return body
|
|
|
|
|
|
}
|
|
|
|
|
|
return b
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// truncateBase64 shortens s to at most 256 bytes, appending "..." when
// anything was cut off. Strings of 256 bytes or fewer are returned as-is.
func truncateBase64(s string) string {
	const maxKeep = 256
	if len(s) > maxKeep {
		return s[:maxKeep] + "..."
	}
	return s
}
|