3588AdminBackend/internal/service/task.go

514 lines
13 KiB
Go

package service
import (
"bytes"
"encoding/json"
"fmt"
"io"
"sync"
"3588AdminBackend/internal/config"
"3588AdminBackend/internal/models"
"github.com/google/uuid"
)
// TaskRepository is the persistence contract TaskService uses for tasks.
// A TaskService may run without one; persistTask and LoadPersistedTasks are
// no-ops when no repository is configured.
type TaskRepository interface {
// Save is called by TaskService.persistTask whenever task state has changed.
Save(task *models.Task) error
// List is called by TaskService.LoadPersistedTasks to rebuild the in-memory
// task map.
List() ([]models.Task, error)
}
// DeviceConfigStateRepository stores the configuration state last applied to
// a device. TaskService calls it from persistConfigState after a successful
// "config_apply" on that device.
type DeviceConfigStateRepository interface {
	// UpsertState records the template, profile, JSON-encoded overlay list,
	// config identifier/version and the ID of the task that applied them for
	// the given device.
	UpsertState(deviceID, templateName, profileName, overlaysJSON, configID, configVersion, lastAppliedTaskID string) error
}
// AuditLogRepository receives audit entries. TaskService calls it from
// appendAuditLog with a JSON document describing the task outcome.
type AuditLogRepository interface {
	// AppendLog adds one audit entry: who acted, what action was taken, the
	// kind and ID of the target, and a JSON-encoded details string.
	AppendLog(actor, action, targetType, targetID, detailsJSON string) error
}
// TaskService creates tasks, executes them against devices through the agent
// client, keeps their state in memory, and fans device status updates out to
// subscribers.
type TaskService struct {
// cfg supplies runtime settings; runTask reads cfg.Concurrency.
cfg *config.Config
// agent is the client executeOnDevice uses to call a device's agent; a nil
// agent makes every device execution fail.
agent *AgentClient
// registry provides the device list executeOnDevice searches by device ID.
registry *RegistryService
// repo optionally persists tasks; nil disables persistence.
repo TaskRepository
// stateRepo optionally records applied config state; nil disables it.
stateRepo DeviceConfigStateRepository
// auditRepo optionally receives audit entries; nil disables it.
auditRepo AuditLogRepository
// tasks maps task ID to the live task; guarded by mu.
tasks map[string]*models.Task
mu sync.RWMutex
// listeners maps task ID to the channels registered via Subscribe; guarded
// by lmu.
listeners map[string][]chan *models.DeviceTaskStatus
lmu sync.RWMutex
}
// SetDeviceConfigStateRepo installs the repository used to record applied
// device configuration state. Calling it on a nil *TaskService is a no-op.
func (s *TaskService) SetDeviceConfigStateRepo(repo DeviceConfigStateRepository) {
	if s != nil {
		s.stateRepo = repo
	}
}
// SetAuditLogRepo installs the repository that receives audit entries.
// Calling it on a nil *TaskService is a no-op.
func (s *TaskService) SetAuditLogRepo(repo AuditLogRepository) {
	if s != nil {
		s.auditRepo = repo
	}
}
// NewTaskService builds a TaskService with empty task and listener maps.
//
// repo is optional: when one or more repositories are passed, only the first
// is used; when none is passed, task persistence is disabled.
func NewTaskService(cfg *config.Config, agent *AgentClient, registry *RegistryService, repo ...TaskRepository) *TaskService {
	svc := &TaskService{
		cfg:       cfg,
		agent:     agent,
		registry:  registry,
		tasks:     make(map[string]*models.Task),
		listeners: make(map[string][]chan *models.DeviceTaskStatus),
	}
	if len(repo) > 0 {
		svc.repo = repo[0]
	}
	return svc
}
// ListTasks returns a point-in-time snapshot of every task held in memory.
//
// Each returned task carries copies of ID, Type, DeviceIDs, Payload, Status
// and the per-device statuses, so callers can read the result without holding
// any of the service's locks. Payload is copied by reference, not deep-copied.
// The order of the result follows map iteration and is therefore unspecified.
func (s *TaskService) ListTasks() []models.Task {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// copyInto fills dst from src while holding src's read lock. The unlock is
	// deferred so the task lock is released even if copying panics; otherwise a
	// recovered panic would leave the task permanently read-locked and block
	// every later status update.
	copyInto := func(dst, src *models.Task) {
		src.Mu.RLock()
		defer src.Mu.RUnlock()

		dst.ID = src.ID
		dst.Type = src.Type
		dst.DeviceIDs = append([]string(nil), src.DeviceIDs...)
		dst.Payload = src.Payload
		dst.Status = src.Status
		dst.Devices = make(map[string]*models.DeviceTaskStatus, len(src.Devices))
		for did, ds := range src.Devices {
			if ds == nil {
				// Keep the key so the snapshot mirrors the task, but do not
				// dereference a missing status.
				dst.Devices[did] = nil
				continue
			}
			dst.Devices[did] = &models.DeviceTaskStatus{
				DeviceID: ds.DeviceID,
				Status:   ds.Status,
				Progress: ds.Progress,
				Error:    ds.Error,
			}
		}
	}

	// Elements are filled in place so no lock-bearing struct value is copied.
	items := make([]models.Task, len(s.tasks))
	next := 0
	for _, t := range s.tasks {
		copyInto(&items[next], t)
		next++
	}
	return items
}
// CreateTask registers a new task under a freshly generated UUID, persists it
// (when a repository is configured) and starts executing it in a background
// goroutine. It returns immediately with the created task; the returned error
// is always nil in the current implementation.
func (s *TaskService) CreateTask(tType string, deviceIDs []string, payload interface{}) (*models.Task, error) {
	taskID := uuid.New().String()
	created := models.NewTask(taskID, tType, deviceIDs, payload)

	s.mu.Lock()
	s.tasks[taskID] = created
	s.mu.Unlock()

	s.persistTask(created)

	// Execution is asynchronous; progress is observable via ListTasks/Subscribe.
	go s.runTask(created)

	return created, nil
}
// LoadPersistedTasks reads all tasks from the repository and installs them in
// the in-memory task map, replacing any entry with the same ID.
//
// It is a no-op returning nil when the receiver is nil or no repository is
// configured, and returns the repository's error unchanged when listing fails.
// Loaded tasks are only registered; this method does not start executing them.
func (s *TaskService) LoadPersistedTasks() error {
	if s == nil || s.repo == nil {
		return nil
	}

	stored, err := s.repo.List()
	if err != nil {
		return err
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	for i := range stored {
		// Work through a pointer so the stored element is not copied.
		src := &stored[i]

		restored := models.NewTask(src.ID, src.Type, append([]string(nil), src.DeviceIDs...), src.Payload)
		restored.Status = src.Status
		for did, ds := range src.Devices {
			if ds != nil {
				restored.Devices[did] = &models.DeviceTaskStatus{
					DeviceID: ds.DeviceID,
					Status:   ds.Status,
					Progress: ds.Progress,
					Error:    ds.Error,
				}
			}
		}
		s.tasks[src.ID] = restored
	}
	return nil
}
// runTask drives a task to completion: it marks the task running, executes it
// on every device in task.DeviceIDs with bounded parallelism, then sets the
// overall status and persists it. The task is persisted after each of the two
// task-level status changes.
func (s *TaskService) runTask(task *models.Task) {
	task.Mu.Lock()
	task.Status = models.TaskRunning
	task.Mu.Unlock()
	s.persistTask(task)

	// At most `limit` devices are processed at once; a non-positive configured
	// value falls back to 5.
	limit := s.cfg.Concurrency
	if limit <= 0 {
		limit = 5
	}
	slots := make(chan struct{}, limit)

	var pending sync.WaitGroup
	pending.Add(len(task.DeviceIDs))
	for _, deviceID := range task.DeviceIDs {
		go func(deviceID string) {
			defer pending.Done()
			// Acquire a slot before doing any work; release it on exit.
			slots <- struct{}{}
			defer func() { <-slots }()
			s.executeOnDevice(task, deviceID)
		}(deviceID)
	}
	pending.Wait()

	// The task succeeds only when every device entry exists and succeeded.
	everyDeviceOK := func() bool {
		for _, ds := range task.Devices {
			if ds == nil || ds.Status != models.TaskSuccess {
				return false
			}
		}
		return true
	}

	task.Mu.Lock()
	if everyDeviceOK() {
		task.Status = models.TaskSuccess
	} else {
		task.Status = models.TaskFailed
	}
	task.Mu.Unlock()
	s.persistTask(task)
}
// extractConfigPayload returns the configuration document to send to a device.
//
// A nil payload is rejected with an error. For backward compatibility, a map
// payload that has a "config" key is unwrapped to that key's value (whatever
// it is, including nil); any other payload is returned as-is.
func extractConfigPayload(payload any) (any, error) {
	if payload == nil {
		return nil, fmt.Errorf("payload is required")
	}

	wrapper, isMap := payload.(map[string]any)
	if !isMap {
		return payload, nil
	}
	if inner, wrapped := wrapper["config"]; wrapped {
		return inner, nil
	}
	return payload, nil
}
// optionalConfigRequestBody builds the optional JSON request body
// {"config":"<name>"} from a task payload.
//
// A body is produced only when payload is a map whose "config" value is a
// non-empty string. Every other shape (nil, non-map, missing key, non-string
// or empty value such as the UI default {"config":{}}) yields a nil reader and
// zero length, so the request is sent without a body instead of being
// rejected. The returned length is the encoded body size in bytes.
func optionalConfigRequestBody(payload any) (io.Reader, int64, error) {
	var name string
	if fields, isMap := payload.(map[string]any); isMap {
		// A missing key or non-string value leaves name empty.
		name, _ = fields["config"].(string)
	}
	if name == "" {
		return nil, 0, nil
	}

	encoded, err := json.Marshal(map[string]any{"config": name})
	if err != nil {
		return nil, 0, err
	}
	return bytes.NewReader(encoded), int64(len(encoded)), nil
}
// executeOnDevice runs one task against one device and records the outcome
// through updateDeviceStatus.
//
// The device is first marked running. It is then marked failed when the agent
// client is missing, the device is not in the registry, the device is offline,
// the payload cannot be prepared, the agent call returns an error, the agent
// answers with a status code >= 400, or the task type is unsupported.
//
// Supported task types:
//   - "config_apply": PUT /v1/config with the JSON-encoded config payload; on
//     success the applied config state is also recorded.
//   - "reload", "rollback", "media_stop": POST to the matching
//     /v1/media-server/... endpoint without a body.
//   - "media_start", "media_restart": POST to the matching endpoint with the
//     optional {"config": "<name>"} body.
//
// An audit entry is appended on success only.
//
// NOTE(review): the first value returned by agent.Do/DoStream is discarded
// here; if it is a response body that must be closed, confirm AgentClient
// handles that itself.
func (s *TaskService) executeOnDevice(task *models.Task, did string) {
	// fail marks this device as failed with the given reason.
	fail := func(reason string) {
		s.updateDeviceStatus(task.ID, did, models.TaskFailed, 0, reason)
	}

	s.updateDeviceStatus(task.ID, did, models.TaskRunning, 0, "")
	if s.agent == nil {
		fail("agent client not initialized")
		return
	}

	// Look the target up in the registry's current device list.
	var dev *models.Device
	for _, candidate := range s.registry.GetDevices() {
		if candidate.DeviceID == did {
			dev = candidate
			break
		}
	}
	switch {
	case dev == nil:
		fail("device not found")
		return
	case !dev.Online:
		fail("device offline")
		return
	}

	if task.Type == "config_apply" {
		cfgPayload, err := extractConfigPayload(task.Payload)
		if err != nil {
			fail(err.Error())
			return
		}
		encoded, err := json.Marshal(cfgPayload)
		if err != nil {
			fail("invalid payload: " + err.Error())
			return
		}
		_, code, err := s.agent.Do("PUT", dev.IP, dev.AgentPort, "/v1/config", encoded)
		if err != nil {
			fail(err.Error())
			return
		}
		if code >= 400 {
			fail(fmt.Sprintf("agent error: %d", code))
			return
		}
		s.updateDeviceStatus(task.ID, did, models.TaskSuccess, 1.0, "")
		s.persistConfigState(task, did)
		s.appendAuditLog(task, did, models.TaskSuccess, "")
		return
	}

	// Every remaining supported type is a POST to a media-server endpoint.
	mediaEndpoints := map[string]string{
		"reload":        "/v1/media-server/reload",
		"rollback":      "/v1/media-server/rollback",
		"media_start":   "/v1/media-server/start",
		"media_restart": "/v1/media-server/restart",
		"media_stop":    "/v1/media-server/stop",
	}
	endpoint, supported := mediaEndpoints[task.Type]
	if !supported {
		fail("unsupported task type")
		return
	}

	// Only start/restart may carry the optional config-name body; the other
	// endpoints are always called with a nil body and zero length.
	var (
		reqBody io.Reader
		reqLen  int64
	)
	if task.Type == "media_start" || task.Type == "media_restart" {
		var err error
		reqBody, reqLen, err = optionalConfigRequestBody(task.Payload)
		if err != nil {
			fail(err.Error())
			return
		}
	}

	_, code, err := s.agent.DoStream("POST", dev.IP, dev.AgentPort, endpoint, reqBody, "", reqLen)
	if err != nil {
		fail(err.Error())
		return
	}
	if code >= 400 {
		fail(fmt.Sprintf("agent error: %d", code))
		return
	}
	s.updateDeviceStatus(task.ID, did, models.TaskSuccess, 1.0, "")
	s.appendAuditLog(task, did, models.TaskSuccess, "")
}
// updateDeviceStatus records a device's status, progress and error text on the
// task, persists the task, and notifies every subscriber of that task.
//
// Unknown task IDs are ignored. When the device has no (or a nil) entry in the
// task's device map nothing is recorded on the task, but the task is still
// persisted and subscribers are still notified. Notification never blocks: a
// subscriber whose channel buffer is full misses the update.
func (s *TaskService) updateDeviceStatus(taskID, did string, status models.TaskStatus, progress float64, errStr string) {
	s.mu.RLock()
	task, ok := s.tasks[taskID]
	s.mu.RUnlock()
	if !ok {
		return
	}

	task.Mu.Lock()
	// The nil check keeps a missing status from panicking while the task lock
	// is held.
	if ds, found := task.Devices[did]; found && ds != nil {
		ds.Status = status
		ds.Progress = progress
		ds.Error = errStr
	}
	task.Mu.Unlock()
	s.persistTask(task)

	// Copy the subscriber list while holding the lock. Iterating the shared
	// slice after unlocking would race with a Subscribe cleanup that rewrites
	// the same backing array.
	s.lmu.RLock()
	subscribers := append([]chan *models.DeviceTaskStatus(nil), s.listeners[taskID]...)
	s.lmu.RUnlock()
	if len(subscribers) == 0 {
		return
	}

	update := &models.DeviceTaskStatus{
		DeviceID: did,
		Status:   status,
		Progress: progress,
		Error:    errStr,
	}
	for _, ch := range subscribers {
		select {
		case ch <- update:
		default:
			// Slow subscriber: drop the update rather than block the task.
		}
	}
}
// persistConfigState records, for one device, the config metadata carried by a
// "config_apply" task together with the task ID that applied it.
//
// It does nothing when the receiver, the state repository or the task is nil,
// or when the task is of another type. Overlays are stored as a JSON array
// string, "[]" when there are none. The repository error is deliberately
// ignored: recording state is best-effort and does not affect the task result.
func (s *TaskService) persistConfigState(task *models.Task, did string) {
	if s == nil || s.stateRepo == nil {
		return
	}
	if task == nil || task.Type != "config_apply" {
		return
	}

	meta := taskPayloadMetadata(task.Payload)

	overlays := "[]"
	if len(meta.Overlays) > 0 {
		encoded, err := json.Marshal(meta.Overlays)
		if err == nil {
			overlays = string(encoded)
		}
	}

	_ = s.stateRepo.UpsertState(did, meta.Template, meta.Profile, overlays, meta.ConfigID, meta.ConfigVersion, task.ID)
}
// appendAuditLog writes one audit entry for a task's outcome on a device.
//
// The entry is attributed to actor "system", uses the task type as the action
// and targets the device. Its details JSON always contains task_id, type and
// status; template, profile, config_id, config_version, overlays and error are
// included only when non-empty. Nothing happens when the receiver, the audit
// repository or the task is nil. Marshal and repository errors are deliberately
// ignored: auditing is best-effort.
func (s *TaskService) appendAuditLog(task *models.Task, did string, status models.TaskStatus, errText string) {
	if s == nil || s.auditRepo == nil || task == nil {
		return
	}

	meta := taskPayloadMetadata(task.Payload)
	details := map[string]any{
		"task_id": task.ID,
		"type":    task.Type,
		"status":  status,
	}

	// Optional string fields are emitted only when they carry a value.
	optional := []struct{ key, value string }{
		{"template", meta.Template},
		{"profile", meta.Profile},
		{"config_id", meta.ConfigID},
		{"config_version", meta.ConfigVersion},
		{"error", errText},
	}
	for _, field := range optional {
		if field.value != "" {
			details[field.key] = field.value
		}
	}
	if len(meta.Overlays) > 0 {
		details["overlays"] = meta.Overlays
	}

	encoded, _ := json.Marshal(details)
	_ = s.auditRepo.AppendLog("system", task.Type, "device", did, string(encoded))
}
// taskMetadata is the subset of a config payload's "metadata" object that is
// recorded in device config state and audit entries.
type taskMetadata struct {
	Template      string
	Profile       string
	Overlays      []string
	ConfigID      string
	ConfigVersion string
}

// taskPayloadMetadata extracts template, profile, config_id, config_version
// and overlays from the "metadata" object of a task payload.
//
// The config root is located the same way extractConfigPayload does it: when
// the payload map has a "config" key, that value is the root (and must itself
// be a map); otherwise the payload map is the root. Anything that does not
// match (non-map payload, non-map "config", missing or non-map "metadata")
// yields the zero taskMetadata. Non-string fields are treated as empty, and
// empty or non-string overlay entries are skipped. Overlays are accepted both
// as []any (the JSON-decoded form) and as []string.
func taskPayloadMetadata(payload any) taskMetadata {
	var out taskMetadata

	root, ok := payload.(map[string]any)
	if !ok {
		return out
	}

	configRoot := root
	if raw, wrapped := root["config"]; wrapped {
		nested, isMap := raw.(map[string]any)
		if !isMap {
			// e.g. {"config": "cam1"} on media tasks: no metadata to read.
			return out
		}
		configRoot = nested
	}

	metadata, ok := configRoot["metadata"].(map[string]any)
	if !ok {
		return out
	}

	out.Template = stringAny(metadata["template"])
	out.Profile = stringAny(metadata["profile"])
	out.ConfigID = stringAny(metadata["config_id"])
	out.ConfigVersion = stringAny(metadata["config_version"])

	switch overlays := metadata["overlays"].(type) {
	case []any:
		for _, item := range overlays {
			if name := stringAny(item); name != "" {
				out.Overlays = append(out.Overlays, name)
			}
		}
	case []string:
		for _, name := range overlays {
			if name != "" {
				out.Overlays = append(out.Overlays, name)
			}
		}
	}
	return out
}

// stringAny returns v when it holds a string and "" for any other value,
// including nil.
func stringAny(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	return ""
}
// persistTask saves the task through the configured repository. It is a no-op
// when the receiver, the repository or the task is nil. The Save error is
// deliberately discarded: persistence is best-effort and never interrupts task
// execution.
func (s *TaskService) persistTask(task *models.Task) {
	if s != nil && s.repo != nil && task != nil {
		_ = s.repo.Save(task)
	}
}
// Subscribe registers a listener for device status updates of the given task.
//
// It returns a channel buffered for 10 updates and a cleanup function that
// unregisters the channel. Updates are delivered without blocking, so a
// subscriber that lets the buffer fill up misses updates. The cleanup function
// is safe to call more than once; it does not close the channel.
func (s *TaskService) Subscribe(taskID string) (chan *models.DeviceTaskStatus, func()) {
	ch := make(chan *models.DeviceTaskStatus, 10)

	s.lmu.Lock()
	s.listeners[taskID] = append(s.listeners[taskID], ch)
	s.lmu.Unlock()

	cleanup := func() {
		s.lmu.Lock()
		defer s.lmu.Unlock()

		current := s.listeners[taskID]
		for i, c := range current {
			if c != ch {
				continue
			}
			if len(current) == 1 {
				// Last listener gone: drop the key so the map does not grow
				// with every task that was ever subscribed to.
				delete(s.listeners, taskID)
				return
			}
			// Build a fresh slice instead of shifting elements in place: a
			// notifier may still be iterating the old backing array after
			// releasing the lock.
			remaining := make([]chan *models.DeviceTaskStatus, 0, len(current)-1)
			remaining = append(remaining, current[:i]...)
			remaining = append(remaining, current[i+1:]...)
			s.listeners[taskID] = remaining
			return
		}
	}
	return ch, cleanup
}