Files
new-api/controller/task_video.go

148 lines
4.7 KiB
Go
Raw Normal View History

2025-06-08 21:40:57 +08:00
package controller
import (
"context"
2025-07-21 15:06:26 +08:00
"encoding/json"
2025-06-08 21:40:57 +08:00
"fmt"
"io"
"one-api/common"
"one-api/constant"
2025-07-21 15:06:26 +08:00
"one-api/dto"
2025-06-08 21:40:57 +08:00
"one-api/model"
"one-api/relay"
"one-api/relay/channel"
2025-07-21 15:06:26 +08:00
relaycommon "one-api/relay/common"
2025-06-20 15:50:00 +08:00
"time"
2025-06-08 21:40:57 +08:00
)
2025-06-20 15:50:00 +08:00
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
2025-06-08 21:40:57 +08:00
for channelId, taskIds := range taskChannelM {
2025-06-20 15:50:00 +08:00
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
2025-06-08 21:40:57 +08:00
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
2025-06-20 15:50:00 +08:00
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
2025-06-08 21:40:57 +08:00
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
cacheGetChannel, err := model.CacheGetChannel(channelId)
if err != nil {
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if errUpdate != nil {
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
2025-06-20 15:50:00 +08:00
adaptor := relay.GetTaskAdaptor(platform)
2025-06-08 21:40:57 +08:00
if adaptor == nil {
return fmt.Errorf("video adaptor not found")
}
for _, taskId := range taskIds {
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
}
}
return nil
}
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
baseURL := constant.ChannelBaseURLs[channel.Type]
2025-06-08 21:40:57 +08:00
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
2025-06-23 21:22:01 +08:00
task := taskM[taskId]
if task == nil {
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
2025-06-08 21:40:57 +08:00
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId,
2025-06-23 21:22:01 +08:00
"action": task.Action,
2025-06-08 21:40:57 +08:00
})
if err != nil {
2025-06-20 15:50:00 +08:00
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
2025-06-08 21:40:57 +08:00
}
2025-06-20 15:50:00 +08:00
//if resp.StatusCode != http.StatusOK {
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
//}
2025-06-08 21:40:57 +08:00
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
2025-06-20 15:50:00 +08:00
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
2025-06-08 21:40:57 +08:00
}
2025-07-21 15:06:26 +08:00
taskResult := &relaycommon.TaskInfo{}
// try parse as New API response format
var responseItems dto.TaskResponse[model.Task]
2025-07-23 10:22:52 +08:00
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
2025-07-21 15:06:26 +08:00
t := responseItems.Data
taskResult.TaskID = t.TaskID
taskResult.Status = string(t.Status)
taskResult.Url = t.FailReason
taskResult.Progress = t.Progress
taskResult.Reason = t.FailReason
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
2025-06-20 15:50:00 +08:00
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
2025-07-21 15:06:26 +08:00
} else {
task.Data = responseBody
2025-06-08 21:40:57 +08:00
}
2025-06-20 15:50:00 +08:00
now := time.Now().Unix()
if taskResult.Status == "" {
return fmt.Errorf("task %s status is empty", taskId)
}
task.Status = model.TaskStatus(taskResult.Status)
switch taskResult.Status {
case model.TaskStatusSubmitted:
task.Progress = "10%"
case model.TaskStatusQueued:
task.Progress = "20%"
case model.TaskStatusInProgress:
task.Progress = "30%"
if task.StartTime == 0 {
task.StartTime = now
2025-06-08 21:40:57 +08:00
}
2025-06-20 15:50:00 +08:00
case model.TaskStatusSuccess:
task.Progress = "100%"
if task.FinishTime == 0 {
task.FinishTime = now
}
task.FailReason = taskResult.Url
case model.TaskStatusFailure:
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if task.FinishTime == 0 {
task.FinishTime = now
}
task.FailReason = taskResult.Reason
2025-06-08 21:40:57 +08:00
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota
if quota != 0 {
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
2025-06-20 15:50:00 +08:00
default:
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
}
if taskResult.Progress != "" {
task.Progress = taskResult.Progress
2025-06-08 21:40:57 +08:00
}
if err := task.Update(); err != nil {
common.SysError("UpdateVideoTask task error: " + err.Error())
}
return nil
}