diff --git a/agent/internal/httpapi/config_candidate_apply_test.go b/agent/internal/httpapi/config_candidate_apply_test.go index d53f133..2d59bbf 100644 --- a/agent/internal/httpapi/config_candidate_apply_test.go +++ b/agent/internal/httpapi/config_candidate_apply_test.go @@ -142,3 +142,108 @@ func TestApplyCandidateConfigBytes(t *testing.T) { t.Fatalf("last_good body = %s", gotLastGood) } } + +func TestHandleMediaRollbackRestoresPreviousConfig(t *testing.T) { + reloadCalls := 0 + msServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/api/config/reload" { + reloadCalls++ + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + return + } + t.Fatalf("unexpected media-server request %s %s", r.Method, r.URL.Path) + })) + defer msServer.Close() + + ms, err := mediaserver.New(msServer.URL, 3000, 1, nil) + if err != nil { + t.Fatalf("new mediaserver client: %v", err) + } + + dir := t.TempDir() + cfgPath := filepath.Join(dir, "media-server.json") + currentBody := []byte(`{"templates":{"tpl":{"nodes":[],"edges":[]}},"instances":[],"metadata":{"config_id":"current","config_version":"v2"}}`) + previousBody := []byte(`{"templates":{"tpl":{"nodes":[],"edges":[]}},"instances":[],"metadata":{"config_id":"previous","config_version":"v1"}}`) + if err := os.WriteFile(cfgPath, currentBody, 0o644); err != nil { + t.Fatalf("write current: %v", err) + } + if err := os.WriteFile(cfgPath+".last_good.json", previousBody, 0o644); err != nil { + t.Fatalf("write previous: %v", err) + } + + s := &Server{ + agentCfg: config.AgentConfig{ConfigPath: cfgPath, Token: "test-token"}, + ms: ms, + } + req := httptest.NewRequest(http.MethodPost, "/v1/media-server/rollback", nil) + req.Header.Set("X-RK-Token", "test-token") + rr := httptest.NewRecorder() + + s.handleMediaRollback(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status code: got %d body=%s", rr.Code, rr.Body.String()) + } + if reloadCalls != 1 { + t.Fatalf("reload calls = %d", reloadCalls) + } + gotCurrent, err := os.ReadFile(cfgPath) + if err != nil { + t.Fatalf("read current: %v", err) + } + if strings.TrimSpace(string(gotCurrent)) != string(previousBody) { + t.Fatalf("current body = %s", gotCurrent) + } +} + +func TestApplyRootConfigBytesRestoresPreviousWhenReloadFails(t *testing.T) { + reloadCalls := 0 + msServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && r.URL.Path == "/api/config/reload" { + reloadCalls++ + if reloadCalls == 1 { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"reload failed"}`)) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + return + } + t.Fatalf("unexpected media-server request %s %s", r.Method, r.URL.Path) + })) + defer msServer.Close() + + ms, err := mediaserver.New(msServer.URL, 3000, 1, nil) + if err != nil { + t.Fatalf("new mediaserver client: %v", err) + } + + dir := t.TempDir() + cfgPath := filepath.Join(dir, "media-server.json") + currentBody := []byte(`{"templates":{"tpl":{"nodes":[],"edges":[]}},"instances":[],"metadata":{"config_id":"current","config_version":"v1"}}`) + newBody := []byte(`{"templates":{"tpl":{"nodes":[],"edges":[]}},"instances":[],"metadata":{"config_id":"new","config_version":"v2"}}`) + if err := os.WriteFile(cfgPath, currentBody, 0o644); err != nil { + t.Fatalf("write current: %v", err) + } + + s := &Server{ + agentCfg: config.AgentConfig{ConfigPath: cfgPath}, + ms: ms, + } + err = s.applyRootConfigBytes(context.Background(), newBody) + if err == nil || !strings.Contains(err.Error(), "restored previous config") { + t.Fatalf("applyRootConfigBytes err = %v", err) + } + if reloadCalls != 2 { + t.Fatalf("reload calls = %d", reloadCalls) + } + gotCurrent, err := os.ReadFile(cfgPath) + if err != nil { + t.Fatalf("read current: %v", err) + } + if strings.TrimSpace(string(gotCurrent)) != string(currentBody) { + t.Fatalf("current body = %s", gotCurrent) + } +} diff --git a/agent/internal/httpapi/server.go b/agent/internal/httpapi/server.go index 89ab6fc..203761f 100644 --- a/agent/internal/httpapi/server.go +++ b/agent/internal/httpapi/server.go @@ -357,6 +357,18 @@ func validateRootConfigJSON(body []byte) (rootConfigDocument, error) { } func (s *Server) applyRootConfigBytes(ctx context.Context, body []byte) error { + previous, err := os.ReadFile(s.agentCfg.ConfigPath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("read current config failed: %w", err) + } + var restoreBody []byte + if err == nil && len(previous) > 0 { + restoreBody = previous + } + return s.writeConfigAndReload(ctx, body, restoreBody) +} + +func (s *Server) writeConfigAndReload(ctx context.Context, body []byte, restoreBody []byte) error { if err := files.WriteFileAtomic(s.agentCfg.ConfigPath, append(body, '\n'), 0o644); err != nil { return fmt.Errorf("write config failed: %w", err) } @@ -365,11 +377,18 @@ func (s *Server) applyRootConfigBytes(ctx context.Context, body []byte) error { defer cancel() if err := s.ms.Reload(ctx); err != nil { rerr := err - rbErr := s.ms.Rollback(ctx) - if rbErr != nil { - return fmt.Errorf("reload failed: %v; rollback failed: %v", rerr, rbErr) + if len(restoreBody) == 0 { + return fmt.Errorf("reload failed: %v", rerr) } - return fmt.Errorf("reload failed: %v; rollback ok", rerr) + if werr := files.WriteFileAtomic(s.agentCfg.ConfigPath, append(restoreBody, '\n'), 0o644); werr != nil { + return fmt.Errorf("reload failed: %v; restore write failed: %v", rerr, werr) + } + restoreCtx, restoreCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer restoreCancel() + if restoreErr := s.ms.Reload(restoreCtx); restoreErr != nil { + return fmt.Errorf("reload failed: %v; restore reload failed: %v", rerr, restoreErr) + } + return fmt.Errorf("reload failed: %v; restored previous config", rerr) } return nil } @@ -383,7 +402,7 @@ func (s *Server) applyCandidateConfigBytes(ctx context.Context, body []byte) err } else if err != nil && !os.IsNotExist(err) { return fmt.Errorf("read current config failed: %w", err) } - return s.applyRootConfigBytes(ctx, body) + return s.writeConfigAndReload(ctx, body, current) } var modelNameRE = regexp.MustCompile(`^[A-Za-z0-9._-]+$`) @@ -573,7 +592,24 @@ func (s *Server) handleMediaRollback(w http.ResponseWriter, r *http.Request) { errorJSON(w, http.StatusUnauthorized, "unauthorized") return } - if err := s.ms.Rollback(r.Context()); err != nil { + previousPath := s.agentCfg.ConfigPath + ".last_good.json" + body, err := os.ReadFile(previousPath) + if err != nil { + if os.IsNotExist(err) { + s.recordAudit(r, "media.rollback", false, "previous config not found") + errorJSON(w, http.StatusNotFound, "previous config not found") + return + } + s.recordAudit(r, "media.rollback", false, err.Error()) + errorJSON(w, http.StatusInternalServerError, "internal error: read previous config failed: "+err.Error()) + return + } + if _, err := validateRootConfigJSON(body); err != nil { + s.recordAudit(r, "media.rollback", false, err.Error()) + errorJSON(w, http.StatusBadRequest, err.Error()) + return + } + if err := s.writeConfigAndReload(r.Context(), body, nil); err != nil { s.recordAudit(r, "media.rollback", false, err.Error()) errorJSON(w, http.StatusInternalServerError, "internal error: "+err.Error()) return diff --git a/agent/internal/mediaserver/client.go b/agent/internal/mediaserver/client.go index b6754dc..1ec49ea 100644 --- a/agent/internal/mediaserver/client.go +++ b/agent/internal/mediaserver/client.go @@ -70,11 +70,6 @@ func (c *Client) Reload(ctx context.Context) error { return err } -func (c *Client) Rollback(ctx context.Context) error { - _, _, err := c.doControl(ctx, http.MethodPost, "/api/config/rollback", nil) - return err -} - func (c *Client) UpdateNodeConfig(ctx context.Context, nodeID string, graph string, patch any) error { if strings.TrimSpace(nodeID) == "" { return errors.New("node id is empty") diff --git a/agent/rk3588-agent_linux_arm64 b/agent/rk3588-agent_linux_arm64 index 897ca73..8239059 100755 Binary files a/agent/rk3588-agent_linux_arm64 and b/agent/rk3588-agent_linux_arm64 differ diff --git a/include/graph_manager.h b/include/graph_manager.h index 38014cb..d581c64 100644 --- a/include/graph_manager.h +++ b/include/graph_manager.h @@ -198,9 +198,6 @@ public: Status Reload(const std::string& path); const std::string& ConfigPath() const { return config_path_; } - const std::string& LastGoodPath() const { return last_good_path_; } - bool RollbackFromLastGood(std::string& err); - Status Rollback(); bool UpdateNodeConfig(const std::string& node_id, const std::optional& graph, const SimpleJson& new_node_cfg, std::string& err); @@ -226,7 +223,6 @@ private: // Expanded root config used for running graphs (instances expanded into graphs). SimpleJson last_good_expanded_root_; std::string config_path_; - std::string last_good_path_; size_t default_queue_size_ = 8; QueueDropStrategy default_strategy_ = QueueDropStrategy::DropOldest; std::mutex graphs_mu_; diff --git a/src/graph_manager.cpp b/src/graph_manager.cpp index bf944ff..f99d4ed 100644 --- a/src/graph_manager.cpp +++ b/src/graph_manager.cpp @@ -1137,18 +1137,11 @@ bool GraphManager::Build(const SimpleJson& root_cfg, std::string& err) { default_queue_size_ = default_queue_size; default_strategy_ = default_strategy; - if (!last_good_path_.empty()) { - std::string werr; - if (!WriteTextFileAtomic(last_good_path_, StringifySimpleJson(last_good_source_root_), werr)) { - LogWarn("[GraphManager] persist last_good failed: " + werr); - } - } return true; } bool GraphManager::BuildFromFile(const std::string& path, std::string& err) { config_path_ = path; - last_good_path_ = path + ".last_good.json"; SimpleJson root_cfg; if (!LoadConfigFile(path, root_cfg, err)) { return false; @@ -1202,7 +1195,6 @@ void GraphManager::BlockUntilStop() { bool GraphManager::ReloadFromFile(const std::string& path, std::string& err) { if (config_path_.empty()) { config_path_ = path; - last_good_path_ = path + ".last_good.json"; } SimpleJson root_cfg; @@ -1380,12 +1372,6 @@ bool GraphManager::ReloadFromFile(const std::string& path, std::string& err) { Logger::Instance().SetLevel(*new_log_level); } - if (!last_good_path_.empty()) { - std::string werr; - if (!WriteTextFileAtomic(last_good_path_, StringifySimpleJson(last_good_source_root_), werr)) { - LogWarn("[GraphManager] persist last_good failed: " + werr); - } - } return true; } @@ -1493,40 +1479,9 @@ bool GraphManager::ReloadFromFile(const std::string& path, std::string& err) { Logger::Instance().SetLevel(*new_log_level); } - if (!last_good_path_.empty()) { - std::string werr; - if (!WriteTextFileAtomic(last_good_path_, StringifySimpleJson(last_good_source_root_), werr)) { - LogWarn("[GraphManager] persist last_good failed: " + werr); - } - } return true; } -bool GraphManager::RollbackFromLastGood(std::string& err) { - err.clear(); - if (config_path_.empty()) { - err = "config_path not set"; - return false; - } - if (last_good_path_.empty()) { - err = "last_good_path not set"; - return false; - } - - SimpleJson last_good; - if (!LoadConfigFile(last_good_path_, last_good, err)) { - return false; - } - - std::string werr; - if (!WriteTextFileAtomic(config_path_, StringifySimpleJson(last_good), werr)) { - err = "failed to write rollback config: " + werr; - return false; - } - - return ReloadFromFile(config_path_, err); -} - bool GraphManager::UpdateNodeConfig(const std::string& node_id, const std::optional& graph, const SimpleJson& new_node_cfg, std::string& err) { std::lock_guard lock(graphs_mu_); @@ -1653,14 +1608,6 @@ Status GraphManager::Reload(const std::string& path) { return Status::Ok(); } -Status GraphManager::Rollback() { - std::string err; - if (!RollbackFromLastGood(err)) { - return Status::Fail(err); - } - return Status::Ok(); -} - Status GraphManager::SetNodeConfig(const std::string& node_id, const SimpleJson& new_node_cfg, const std::optional& graph) { std::string err; diff --git a/src/http_server.cpp b/src/http_server.cpp index 4b25286..7e918e0 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -572,14 +572,6 @@ void HttpServer::ServerLoop() { resp.body = OkJson(); } } - } else if (req.path == "/api/config/rollback") { - std::string rerr; - if (!gm_.RollbackFromLastGood(rerr)) { - resp.status = 500; - resp.body = ErrorJson(rerr); - } else { - resp.body = OkJson(); - } } else if (req.path == "/api/log/level") { if (req.body.empty()) { resp.status = 400;