3588AdminBackend/internal/service/task_test.go

316 lines
9.0 KiB
Go

package service
import (
"3588AdminBackend/internal/config"
"3588AdminBackend/internal/models"
"3588AdminBackend/internal/storage"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
)
// waitForTaskDone polls task.Status (under task.Mu's read lock) every 10ms
// until it becomes models.TaskSuccess or models.TaskFailed, and returns that
// terminal status. If timeout elapses first, the test is failed with t.Fatalf.
func waitForTaskDone(t *testing.T, task *models.Task, timeout time.Duration) models.TaskStatus {
	t.Helper()
	currentStatus := func() models.TaskStatus {
		task.Mu.RLock()
		defer task.Mu.RUnlock()
		return task.Status
	}
	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) {
		switch status := currentStatus(); status {
		case models.TaskSuccess, models.TaskFailed:
			return status
		}
	}
	t.Fatalf("timed out waiting for task to finish")
	return ""
}
// TestTaskService_CreateTask verifies that CreateTask assigns a task ID and
// that the dispatched device leaves the pending state.
//
// The registered device points at 127.0.0.1:9100 and this test starts no
// agent there, so the request is expected to fail; the test only asserts that
// the device's status moves away from models.TaskPending.
func TestTaskService_CreateTask(t *testing.T) {
	cfg := &config.Config{
		Concurrency: 5,
	}
	agent := NewAgentClient(cfg)
	reg := NewRegistryService(cfg, agent)
	reg.UpdateDevice(&models.Device{
		DeviceID:  "dev1",
		IP:        "127.0.0.1",
		AgentPort: 9100,
		Online:    true,
	})
	svc := NewTaskService(cfg, agent, reg)

	task, err := svc.CreateTask("config_apply", []string{"dev1"}, map[string]string{"foo": "bar"})
	if err != nil {
		t.Fatalf("failed to create task: %v", err)
	}
	if task.ID == "" {
		t.Error("expected task ID to be set")
	}

	// Poll with a generous deadline instead of sleeping a fixed 100ms:
	// a fixed sleep makes the assertion flaky on slow or loaded machines.
	leftPending := func() bool {
		task.Mu.RLock()
		defer task.Mu.RUnlock()
		return task.Devices["dev1"].Status != models.TaskPending
	}
	deadline := time.Now().Add(2 * time.Second)
	for !leftPending() && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
	}
	if !leftPending() {
		t.Error("expected task status to change from pending")
	}
}
// TestTaskService_Subscribe verifies that a subscriber to a task receives the
// device status update published by updateDeviceStatus.
func TestTaskService_Subscribe(t *testing.T) {
	cfg := &config.Config{
		Concurrency: 5,
	}
	// Share one agent client between the registry and the task service, as the
	// other tests in this file do.
	agent := NewAgentClient(cfg)
	svc := NewTaskService(cfg, agent, NewRegistryService(cfg, agent))

	// Seed the task directly so no agent round trip is involved.
	taskID := "test-task"
	svc.tasks[taskID] = models.NewTask(taskID, "test", []string{"dev1"}, nil)

	ch, cleanup := svc.Subscribe(taskID)
	defer cleanup()

	// Publish from another goroutine while this one waits on the channel.
	go svc.updateDeviceStatus(taskID, "dev1", models.TaskRunning, 0.5, "")

	select {
	case update := <-ch:
		if update.DeviceID != "dev1" || update.Status != models.TaskRunning {
			t.Errorf("unexpected update: %+v", update)
		}
	case <-time.After(1 * time.Second):
		t.Error("timed out waiting for event")
	}
}
// TestTaskService_ConfigApply_UsesPayloadConfigField verifies that a
// config_apply task forwards payload["config"] (rather than the whole payload)
// as the JSON body of PUT /v1/config on the device agent.
func TestTaskService_ConfigApply_UsesPayloadConfigField(t *testing.T) {
	// The handler runs on a server goroutine, where t.Fatalf is not allowed
	// (FailNow must be called from the test goroutine). It therefore reports
	// with t.Errorf and hands the decoded body to the test over a buffered
	// channel instead of writing a shared variable.
	bodies := make(chan any, 1)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPut || r.URL.Path != "/v1/config" {
			t.Errorf("expected PUT /v1/config, got %s %s", r.Method, r.URL.Path)
			http.Error(w, "unexpected request", http.StatusNotFound)
			return
		}
		var body any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
		}
		select {
		case bodies <- body:
		default:
			t.Errorf("unexpected extra request to %s", r.URL.Path)
		}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"ok":true}`))
	}))
	defer server.Close()

	u, err := url.Parse(server.URL)
	if err != nil {
		t.Fatalf("url.Parse(%q): %v", server.URL, err)
	}
	host, portStr, err := net.SplitHostPort(u.Host)
	if err != nil {
		t.Fatalf("SplitHostPort(%q): %v", u.Host, err)
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		t.Fatalf("Atoi(%q): %v", portStr, err)
	}

	cfg := &config.Config{Concurrency: 1}
	agent := NewAgentClient(cfg)
	reg := NewRegistryService(cfg, agent)
	reg.UpdateDevice(&models.Device{DeviceID: "dev1", IP: host, AgentPort: port, Online: true})
	svc := NewTaskService(cfg, agent, reg)

	payload := map[string]any{"config": map[string]any{"a": 1}}
	task, err := svc.CreateTask("config_apply", []string{"dev1"}, payload)
	if err != nil {
		t.Fatalf("failed to create task: %v", err)
	}
	st := waitForTaskDone(t, task, 2*time.Second)
	if st != models.TaskSuccess {
		t.Fatalf("expected task success, got %s", st)
	}

	var gotBody any
	select {
	case gotBody = <-bodies:
	default:
		t.Fatal("agent never received a config request")
	}
	// JSON numbers decode as float64. Use the comma-ok form so a missing or
	// mistyped "a" fails the test instead of panicking.
	m, ok := gotBody.(map[string]any)
	a, isNum := m["a"].(float64)
	if !ok || !isNum || a != 1 {
		t.Fatalf("expected body {a:1}, got %#v", gotBody)
	}
}
// TestTaskService_MediaStart_IgnoresInvalidConfigShape verifies that a
// media_start task whose payload carries the UI's default {"config":{}} sends
// POST /v1/media-server/start with an empty request body.
func TestTaskService_MediaStart_IgnoresInvalidConfigShape(t *testing.T) {
	// The handler runs on a server goroutine, so it reports with t.Errorf
	// (t.Fatalf is only legal on the test goroutine). It passes the raw body to
	// the test through a buffered channel rather than a shared variable.
	bodies := make(chan []byte, 1)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost || r.URL.Path != "/v1/media-server/start" {
			t.Errorf("expected POST /v1/media-server/start, got %s %s", r.Method, r.URL.Path)
			http.Error(w, "unexpected request", http.StatusNotFound)
			return
		}
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read request body: %v", err)
		}
		select {
		case bodies <- body:
		default:
			t.Errorf("unexpected extra request to %s", r.URL.Path)
		}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"ok":true}`))
	}))
	defer server.Close()

	u, err := url.Parse(server.URL)
	if err != nil {
		t.Fatalf("url.Parse(%q): %v", server.URL, err)
	}
	host, portStr, err := net.SplitHostPort(u.Host)
	if err != nil {
		t.Fatalf("SplitHostPort(%q): %v", u.Host, err)
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		t.Fatalf("Atoi(%q): %v", portStr, err)
	}

	cfg := &config.Config{Concurrency: 1}
	agent := NewAgentClient(cfg)
	reg := NewRegistryService(cfg, agent)
	reg.UpdateDevice(&models.Device{DeviceID: "dev1", IP: host, AgentPort: port, Online: true})
	svc := NewTaskService(cfg, agent, reg)

	// UI default payload_json is {"config":{}}; this should be ignored for media_start.
	payload := map[string]any{"config": map[string]any{}}
	task, err := svc.CreateTask("media_start", []string{"dev1"}, payload)
	if err != nil {
		t.Fatalf("failed to create task: %v", err)
	}
	st := waitForTaskDone(t, task, 2*time.Second)
	if st != models.TaskSuccess {
		t.Fatalf("expected task success, got %s", st)
	}

	select {
	case body := <-bodies:
		if len(body) != 0 {
			t.Fatalf("expected empty body, got %q", string(body))
		}
	default:
		t.Fatal("agent never received a media-server start request")
	}
}
// TestTaskService_ConfigApplyPersistsDeviceConfigStateAndAudit verifies that a
// successful config_apply task records the applied config metadata (profile,
// config ID, task ID) in the device config state repo and writes a
// config_apply audit log entry for the target device.
func TestTaskService_ConfigApplyPersistsDeviceConfigStateAndAudit(t *testing.T) {
	// The fake agent accepts every request.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"ok":true}`))
	}))
	defer server.Close()

	u, err := url.Parse(server.URL)
	if err != nil {
		t.Fatalf("url.Parse(%q): %v", server.URL, err)
	}
	host, portStr, err := net.SplitHostPort(u.Host)
	if err != nil {
		t.Fatalf("SplitHostPort(%q): %v", u.Host, err)
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		t.Fatalf("Atoi(%q): %v", portStr, err)
	}

	cfg := &config.Config{Concurrency: 1}
	agent := NewAgentClient(cfg)
	reg := NewRegistryService(cfg, agent)
	reg.UpdateDevice(&models.Device{DeviceID: "dev1", IP: host, AgentPort: port, Online: true})

	store, err := storage.OpenSQLite(filepath.Join(t.TempDir(), "app.db"))
	if err != nil {
		t.Fatalf("OpenSQLite: %v", err)
	}
	defer store.Close()

	svc := NewTaskService(cfg, agent, reg)
	svc.SetDeviceConfigStateRepo(storage.NewDeviceConfigStateRepo(store.DB()))
	svc.SetAuditLogRepo(storage.NewAuditLogsRepo(store.DB()))

	payload := map[string]any{
		"config": map[string]any{
			"metadata": map[string]any{
				"template":       "helmet",
				"profile":        "gate_a",
				"overlays":       []any{"night_relaxed"},
				"config_id":      "cfg-001",
				"config_version": "20260427.1",
			},
		},
	}
	task, err := svc.CreateTask("config_apply", []string{"dev1"}, payload)
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}
	if st := waitForTaskDone(t, task, 2*time.Second); st != models.TaskSuccess {
		t.Fatalf("expected task success, got %s", st)
	}

	// The assertions go through freshly constructed repos over the same DB
	// handle instead of the instances handed to the service. This checks what
	// reached the database.
	state, err := storage.NewDeviceConfigStateRepo(store.DB()).Get("dev1")
	if err != nil {
		t.Fatalf("Get state: %v", err)
	}
	if state == nil || state.ProfileName != "gate_a" || state.ConfigID != "cfg-001" || state.LastAppliedTaskID != task.ID {
		t.Fatalf("unexpected state: %#v", state)
	}

	logs, err := storage.NewAuditLogsRepo(store.DB()).List()
	if err != nil {
		t.Fatalf("List audit logs: %v", err)
	}
	if len(logs) == 0 || logs[0].Action != "config_apply" || logs[0].TargetID != "dev1" {
		t.Fatalf("unexpected audit logs: %#v", logs)
	}
}
// TestTaskService_ModelSyncAllUploadsStandardModels verifies that a
// model_sync_all task uploads each standard model registered in the models
// repo (here, exactly one) to the device agent with a PUT under /v1/models/.
func TestTaskService_ModelSyncAllUploadsStandardModels(t *testing.T) {
	// Upload paths are handed from the server goroutine to the test goroutine
	// through a buffered channel instead of an unsynchronized shared slice. The
	// handler uses t.Errorf because t.Fatalf must not be called outside the
	// test goroutine.
	uploads := make(chan string, 8)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPut {
			t.Errorf("expected PUT, got %s", r.Method)
			http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
			return
		}
		select {
		case uploads <- r.URL.Path:
		default:
			t.Errorf("too many uploads; dropped %s", r.URL.Path)
		}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"ok":true}`))
	}))
	defer server.Close()

	u, err := url.Parse(server.URL)
	if err != nil {
		t.Fatalf("url.Parse(%q): %v", server.URL, err)
	}
	host, portStr, err := net.SplitHostPort(u.Host)
	if err != nil {
		t.Fatalf("SplitHostPort(%q): %v", u.Host, err)
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		t.Fatalf("Atoi(%q): %v", portStr, err)
	}

	cfg := &config.Config{Concurrency: 1}
	agent := NewAgentClient(cfg)
	reg := NewRegistryService(cfg, agent)
	reg.UpdateDevice(&models.Device{DeviceID: "dev1", IP: host, AgentPort: port, Online: true})

	store, err := storage.OpenSQLite(filepath.Join(t.TempDir(), "app.db"))
	if err != nil {
		t.Fatalf("OpenSQLite: %v", err)
	}
	defer store.Close()

	// Register one model whose file on disk matches the recorded SHA-256.
	repo := storage.NewModelsRepo(store.DB())
	modelDir := t.TempDir()
	body := []byte("model-a")
	sum := sha256SumHex(body)
	fileName := "face_det_scrfd_500m_640_rk3588.rknn"
	if err := os.WriteFile(filepath.Join(modelDir, fileName), body, 0o644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	if err := repo.Save(storage.StandardModelRecord{
		Name:     "face_det_scrfd_500m_640_rk3588",
		FileName: fileName,
		Version:  "v1.0.0",
		SHA256:   sum,
	}); err != nil {
		t.Fatalf("Save model: %v", err)
	}

	svc := NewTaskService(cfg, agent, reg)
	svc.SetStandardModels(repo, modelDir)
	task, err := svc.CreateTask("model_sync_all", []string{"dev1"}, map[string]any{})
	if err != nil {
		t.Fatalf("CreateTask: %v", err)
	}
	if st := waitForTaskDone(t, task, 2*time.Second); st != models.TaskSuccess {
		t.Fatalf("expected task success, got %s", st)
	}

	// Once the task has succeeded, the upload requests have presumably been
	// answered, so their paths are already buffered. This goroutine is the
	// only receiver, so draining by len is safe.
	var got []string
	for len(uploads) > 0 {
		got = append(got, <-uploads)
	}
	if len(got) != 1 || !strings.Contains(got[0], "/v1/models/face_det_scrfd_500m_640_rk3588") {
		t.Fatalf("unexpected uploads: %#v", got)
	}
}
// sha256SumHex returns the lowercase hex encoding of the SHA-256 digest of body.
func sha256SumHex(body []byte) string {
	digest := sha256.New()
	// hash.Hash.Write never returns an error, so the result is not checked.
	digest.Write(body)
	return hex.EncodeToString(digest.Sum(nil))
}