diff --git a/include/ai_scheduler.h b/include/ai_scheduler.h index 67ba4de..09bf629 100644 --- a/include/ai_scheduler.h +++ b/include/ai_scheduler.h @@ -104,9 +104,16 @@ private: int input_c = 0; std::string path; std::mutex infer_mutex; // Per-model lock for inference + + ~ModelContext() { + if (ctx) { + rknn_destroy(ctx); + ctx = 0; + } + } }; - std::unordered_map<ModelHandle, std::unique_ptr<ModelContext>> models_; + std::unordered_map<ModelHandle, std::shared_ptr<ModelContext>> models_; #endif mutable std::mutex models_mutex_; // Protects models_ map diff --git a/include/plugin_loader.h b/include/plugin_loader.h index 64ade14..7973258 100644 --- a/include/plugin_loader.h +++ b/include/plugin_loader.h @@ -16,6 +16,9 @@ public: PluginLoader(const PluginLoader&) = delete; PluginLoader& operator=(const PluginLoader&) = delete; + PluginLoader(PluginLoader&& other) noexcept; + PluginLoader& operator=(PluginLoader&& other) noexcept; + std::unique_ptr Create(const std::string& type, std::string& err); // Switch plugin directory. This will unload any cached plugins. diff --git a/src/ai_scheduler.cpp b/src/ai_scheduler.cpp index b78db2c..e20adba 100644 --- a/src/ai_scheduler.cpp +++ b/src/ai_scheduler.cpp @@ -21,14 +21,10 @@ AiScheduler::~AiScheduler() { void AiScheduler::Shutdown() { #if defined(RK3588_ENABLE_RKNN) - std::lock_guard lock(models_mutex_); - for (auto& [handle, ctx] : models_) { - if (ctx && ctx->ctx) { - rknn_destroy(ctx->ctx); - ctx->ctx = 0; - } + { + std::lock_guard lock(models_mutex_); + models_.clear(); } - models_.clear(); std::cout << "[AiScheduler] shutdown, total inferences: " << total_inferences_.load() << ", errors: " << total_errors_.load() << "\n"; #endif @@ -46,7 +42,7 @@ ModelHandle AiScheduler::LoadModel(const std::string& model_path, std::string& e size_t model_size = file.tellg(); file.seekg(0, std::ios::beg); - auto ctx = std::make_unique<ModelContext>(); + auto ctx = std::make_shared<ModelContext>(); ctx->model_data.resize(model_size); ctx->path = model_path; @@ -68,6 +64,7 @@ ModelHandle AiScheduler::LoadModel(const std::string& model_path, 
std::string& e if (ret < 0) { err = "rknn_query IO num failed"; rknn_destroy(ctx->ctx); + ctx->ctx = 0; return kInvalidModelHandle; } @@ -104,13 +101,13 @@ ModelHandle AiScheduler::LoadModel(const std::string& model_path, std::string& e { std::lock_guard lock(models_mutex_); - models_[handle] = std::move(ctx); + models_[handle] = ctx; } std::cout << "[AiScheduler] loaded model: " << model_path - << " (handle=" << handle << ", input=" << models_[handle]->input_w - << "x" << models_[handle]->input_h << "x" << models_[handle]->input_c - << ", outputs=" << models_[handle]->n_output << ")\n"; + << " (handle=" << handle << ", input=" << ctx->input_w + << "x" << ctx->input_h << "x" << ctx->input_c + << ", outputs=" << ctx->n_output << ")\n"; return handle; #else @@ -122,14 +119,17 @@ ModelHandle AiScheduler::LoadModel(const std::string& model_path, std::string& e void AiScheduler::UnloadModel(ModelHandle handle) { #if defined(RK3588_ENABLE_RKNN) - std::lock_guard lock(models_mutex_); - auto it = models_.find(handle); - if (it != models_.end()) { - if (it->second && it->second->ctx) { - rknn_destroy(it->second->ctx); + bool erased = false; + { + std::lock_guard lock(models_mutex_); + auto it = models_.find(handle); + if (it != models_.end()) { + models_.erase(it); + erased = true; } + } + if (erased) { std::cout << "[AiScheduler] unloaded model handle=" << handle << "\n"; - models_.erase(it); } #else (void)handle; @@ -138,13 +138,16 @@ void AiScheduler::UnloadModel(ModelHandle handle) { bool AiScheduler::GetModelInfo(ModelHandle handle, ModelInfo& info) const { #if defined(RK3588_ENABLE_RKNN) - std::lock_guard lock(models_mutex_); - auto it = models_.find(handle); - if (it == models_.end() || !it->second) { - return false; + std::shared_ptr<ModelContext> ctx; + { + std::lock_guard lock(models_mutex_); + auto it = models_.find(handle); + if (it == models_.end() || !it->second) { + return false; + } + ctx = it->second; } - const auto& ctx = it->second; info.input_width = ctx->input_w; 
info.input_height = ctx->input_h; info.input_channels = ctx->input_c; @@ -163,8 +166,7 @@ InferResult AiScheduler::Infer(ModelHandle handle, const InferInput& input) { InferResult result; #if defined(RK3588_ENABLE_RKNN) - ModelContext* ctx_ptr = nullptr; - + std::shared_ptr<ModelContext> ctx; { std::lock_guard lock(models_mutex_); auto it = models_.find(handle); @@ -173,11 +175,11 @@ InferResult AiScheduler::Infer(ModelHandle handle, const InferInput& input) { total_errors_.fetch_add(1); return result; } - ctx_ptr = it->second.get(); + ctx = it->second; } - // Lock this specific model for inference - std::lock_guard infer_lock(ctx_ptr->infer_mutex); + // Lock this specific model for inference. + std::lock_guard infer_lock(ctx->infer_mutex); if (!input.data || input.size == 0) { result.error = "Invalid input data"; @@ -195,7 +197,7 @@ InferResult AiScheduler::Infer(ModelHandle handle, const InferInput& input) { inputs[0].buf = const_cast<void*>(input.data); inputs[0].pass_through = 0; - int ret = rknn_inputs_set(ctx_ptr->ctx, ctx_ptr->n_input, inputs); + int ret = rknn_inputs_set(ctx->ctx, ctx->n_input, inputs); if (ret < 0) { result.error = "rknn_inputs_set failed: " + std::to_string(ret); total_errors_.fetch_add(1); @@ -203,7 +205,7 @@ InferResult AiScheduler::Infer(ModelHandle handle, const InferInput& input) { } // Run inference - ret = rknn_run(ctx_ptr->ctx, nullptr); + ret = rknn_run(ctx->ctx, nullptr); if (ret < 0) { result.error = "rknn_run failed: " + std::to_string(ret); total_errors_.fetch_add(1); @@ -211,13 +213,13 @@ InferResult AiScheduler::Infer(ModelHandle handle, const InferInput& input) { } // Get outputs - std::vector<rknn_output> outputs(ctx_ptr->n_output); - memset(outputs.data(), 0, sizeof(rknn_output) * ctx_ptr->n_output); - for (uint32_t i = 0; i < ctx_ptr->n_output; ++i) { + std::vector<rknn_output> outputs(ctx->n_output); + memset(outputs.data(), 0, sizeof(rknn_output) * ctx->n_output); + for (uint32_t i = 0; i < ctx->n_output; ++i) { outputs[i].want_float = 0; // Keep quantized output 
} - ret = rknn_outputs_get(ctx_ptr->ctx, ctx_ptr->n_output, outputs.data(), nullptr); + ret = rknn_outputs_get(ctx->ctx, ctx->n_output, outputs.data(), nullptr); if (ret < 0) { result.error = "rknn_outputs_get failed: " + std::to_string(ret); total_errors_.fetch_add(1); @@ -225,19 +227,19 @@ InferResult AiScheduler::Infer(ModelHandle handle, const InferInput& input) { } // Copy outputs to result - result.outputs.resize(ctx_ptr->n_output); - for (uint32_t i = 0; i < ctx_ptr->n_output; ++i) { + result.outputs.resize(ctx->n_output); + for (uint32_t i = 0; i < ctx->n_output; ++i) { auto& out = result.outputs[i]; out.index = i; out.size = outputs[i].size; - out.type = ctx_ptr->output_attrs[i].type; - out.zp = ctx_ptr->output_attrs[i].zp; - out.scale = ctx_ptr->output_attrs[i].scale; + out.type = ctx->output_attrs[i].type; + out.zp = ctx->output_attrs[i].zp; + out.scale = ctx->output_attrs[i].scale; // Copy dimensions - out.dims.resize(ctx_ptr->output_attrs[i].n_dims); - for (uint32_t d = 0; d < ctx_ptr->output_attrs[i].n_dims; ++d) { - out.dims[d] = ctx_ptr->output_attrs[i].dims[d]; + out.dims.resize(ctx->output_attrs[i].n_dims); + for (uint32_t d = 0; d < ctx->output_attrs[i].n_dims; ++d) { + out.dims[d] = ctx->output_attrs[i].dims[d]; } // Copy data @@ -245,7 +247,7 @@ InferResult AiScheduler::Infer(ModelHandle handle, const InferInput& input) { memcpy(out.data.data(), outputs[i].buf, outputs[i].size); } - rknn_outputs_release(ctx_ptr->ctx, ctx_ptr->n_output, outputs.data()); + rknn_outputs_release(ctx->ctx, ctx->n_output, outputs.data()); result.success = true; total_inferences_.fetch_add(1); diff --git a/src/graph_manager.cpp b/src/graph_manager.cpp index 4c29cbb..1a39365 100644 --- a/src/graph_manager.cpp +++ b/src/graph_manager.cpp @@ -458,12 +458,11 @@ bool Graph::Start() { continue; } - if (stop_requested_.load()) { - if (entry.context.input_queue->IsStopped() && - entry.context.input_queue->Size() == 0) { - for (auto& q : entry.context.output_queues) q->Stop(); 
- break; - } + // Avoid busy-spin when upstream queue is stopped. + if (entry.context.input_queue->IsStopped() && + entry.context.input_queue->Size() == 0) { + for (auto& q : entry.context.output_queues) q->Stop(); + break; } } }); @@ -792,25 +791,118 @@ bool GraphManager::ReloadFromFile(const std::string& path, std::string& err) { std::lock_guard lock(graphs_mu_); - // If plugin_dir changes, do full stop+rebuild to avoid unloading libs while nodes exist. - if (!new_plugin_path.empty() && new_plugin_path != loader_.PluginDir()) { - for (auto& g : graphs_) g->Stop(); - graphs_.clear(); - loader_.SetPluginDir(new_plugin_path); + const SimpleJson prev_last_good = last_good_root_; + const size_t prev_default_queue_size = default_queue_size_; + const QueueDropStrategy prev_default_strategy = default_strategy_; + const std::string prev_plugin_dir = loader_.PluginDir(); - for (const auto& graph_val : graphs_it->second.AsArray()) { - std::string name = graph_val.ValueOr("name", "noname"); - auto graph = std::make_unique<Graph>(name); - if (!graph->Build(graph_val, loader_, new_default_queue_size, new_default_strategy, err)) { - return false; - } - graphs_.push_back(std::move(graph)); + auto build_graphs_locked = [&](const SimpleJson& expanded_root, PluginLoader& loader, + size_t def_q, QueueDropStrategy def_s, + std::vector<std::unique_ptr<Graph>>& out_graphs, + std::string& build_err) -> bool { + out_graphs.clear(); + const SimpleJson* graphs = expanded_root.Find("graphs"); + if (!graphs || !graphs->IsArray()) { + build_err = "Root config missing 'graphs' array"; + return false; } - for (auto& g : graphs_) { - if (!g->Start()) { - err = "Failed to start graph after full rebuild"; + out_graphs.reserve(graphs->AsArray().size()); + for (const auto& gv : graphs->AsArray()) { + if (!gv.IsObject()) { + build_err = "Graph entry is not object"; return false; } + std::string name = gv.ValueOr("name", "noname"); + auto graph = std::make_unique<Graph>(name); + if (!graph->Build(gv, loader, def_q, def_s, build_err)) { + 
return false; + } + out_graphs.push_back(std::move(graph)); + } + return true; + }; + + auto start_graphs_locked = [&](std::vector<std::unique_ptr<Graph>>& gs, std::string& start_err) -> bool { + for (auto& g : gs) { + if (!g) continue; + if (!g->Start()) { + start_err = "Failed to start graph: " + g->Name(); + return false; + } + } + return true; + }; + + auto stop_all_locked = [&]() { + for (auto& g : graphs_) { + if (g) g->Stop(); + } + }; + + auto recover_locked = [&](std::string& recover_err) -> bool { + stop_all_locked(); + graphs_.clear(); + + if (!prev_plugin_dir.empty() && prev_plugin_dir != loader_.PluginDir()) { + loader_.SetPluginDir(prev_plugin_dir); + } + + std::vector<std::unique_ptr<Graph>> recovered; + std::string berr; + if (!build_graphs_locked(prev_last_good, loader_, prev_default_queue_size, prev_default_strategy, recovered, + berr)) { + recover_err = "Recovery build failed: " + berr; + return false; + } + std::string serr; + if (!start_graphs_locked(recovered, serr)) { + recover_err = "Recovery start failed: " + serr; + for (auto& gg : recovered) { + if (gg) gg->Stop(); + } + return false; + } + graphs_ = std::move(recovered); + default_queue_size_ = prev_default_queue_size; + default_strategy_ = prev_default_strategy; + // last_good_root_ remains unchanged. + return true; + }; + + const bool plugin_dir_change = (!new_plugin_path.empty() && new_plugin_path != loader_.PluginDir()); + if (plugin_dir_change) { + PluginLoader staged_loader(new_plugin_path); + std::vector<std::unique_ptr<Graph>> staged_graphs; + std::string berr; + if (!build_graphs_locked(expanded, staged_loader, new_default_queue_size, new_default_strategy, staged_graphs, + berr)) { + err = berr; + return false; + } + + // Switch window: allow short downtime, but must be recoverable. 
+ stop_all_locked(); + graphs_.clear(); + + PluginLoader old_loader = std::move(loader_); + loader_ = std::move(staged_loader); + graphs_ = std::move(staged_graphs); + + std::string serr; + if (!start_graphs_locked(graphs_, serr)) { + err = "Failed to start after plugin_path switch: " + serr; + + // Stop partially started graphs before recovery. + stop_all_locked(); + graphs_.clear(); + + loader_ = std::move(old_loader); + std::string rerr; + if (!recover_locked(rerr)) { + err += "; recovery failed: " + rerr; + return false; + } + return false; } last_good_root_ = expanded; @@ -819,16 +911,22 @@ bool GraphManager::ReloadFromFile(const std::string& path, std::string& err) { return true; } - // Index existing graphs by name. - std::map<std::string, size_t> old_index; - for (size_t i = 0; i < graphs_.size(); ++i) { - old_index[graphs_[i]->Name()] = i; - } + auto find_graph_index_locked = [&](const std::string& name, size_t& out_idx) -> bool { + for (size_t i = 0; i < graphs_.size(); ++i) { + if (graphs_[i] && graphs_[i]->Name() == name) { + out_idx = i; + return true; + } + } + return false; + }; // Track graphs referenced by new config. std::set<std::string> seen; - // Update or rebuild existing graphs. + // Stage graphs that require rebuild or are newly added. + std::map<std::string, std::unique_ptr<Graph>> staged; + for (const auto& graph_val : graphs_it->second.AsArray()) { if (!graph_val.IsObject()) { err = "Graph entry is not object"; @@ -837,22 +935,18 @@ bool GraphManager::ReloadFromFile(const std::string& path, std::string& err) { std::string name = graph_val.ValueOr("name", "noname"); seen.insert(name); - auto it = old_index.find(name); - if (it == old_index.end()) { - // New graph: build+start. + size_t idx = 0; + if (!find_graph_index_locked(name, idx)) { + // New graph: stage build (do not start until we have stopped removed graphs). 
auto graph = std::make_unique<Graph>(name); if (!graph->Build(graph_val, loader_, new_default_queue_size, new_default_strategy, err)) { return false; } - if (!graph->Start()) { - err = "Failed to start new graph: " + name; - return false; - } - graphs_.push_back(std::move(graph)); + staged[name] = std::move(graph); continue; } - auto& g = graphs_[it->second]; + auto& g = graphs_[idx]; std::string upd_err; if (g->TryUpdateInPlace(graph_val, new_default_queue_size, new_default_strategy, upd_err)) { continue; @@ -862,25 +956,16 @@ bool GraphManager::ReloadFromFile(const std::string& path, std::string& err) { return false; } - // Need rebuild. - g->Stop(); - auto new_g = std::make_unique<Graph>(name); - if (!new_g->Build(graph_val, loader_, new_default_queue_size, new_default_strategy, err)) { - // Rollback: keep old stopped graph is not acceptable; attempt restart old. - (void)g->Start(); + auto graph = std::make_unique<Graph>(name); + if (!graph->Build(graph_val, loader_, new_default_queue_size, new_default_strategy, err)) { return false; } - if (!new_g->Start()) { - err = "Failed to start rebuilt graph: " + name; - (void)g->Start(); - return false; - } - g = std::move(new_g); + staged[name] = std::move(graph); } - // Stop and remove graphs not present anymore. + // Stop and remove graphs not present anymore (may free resources needed by staged graphs). for (auto itg = graphs_.begin(); itg != graphs_.end();) { - if (!seen.count((*itg)->Name())) { + if (*itg && !seen.count((*itg)->Name())) { (*itg)->Stop(); itg = graphs_.erase(itg); } else { @@ -888,6 +973,39 @@ bool GraphManager::ReloadFromFile(const std::string& path, std::string& err) { } } + // Apply staged graphs. 
+ for (auto& kv : staged) { + const std::string& name = kv.first; + auto& new_g = kv.second; + if (!new_g) continue; + + size_t idx = 0; + if (find_graph_index_locked(name, idx)) { + graphs_[idx]->Stop(); + if (!new_g->Start()) { + err = "Failed to start rebuilt graph: " + name; + std::string rerr; + if (!recover_locked(rerr)) { + err += "; recovery failed: " + rerr; + return false; + } + return false; + } + graphs_[idx] = std::move(new_g); + } else { + if (!new_g->Start()) { + err = "Failed to start new graph: " + name; + std::string rerr; + if (!recover_locked(rerr)) { + err += "; recovery failed: " + rerr; + return false; + } + return false; + } + graphs_.push_back(std::move(new_g)); + } + } + last_good_root_ = expanded; default_queue_size_ = new_default_queue_size; default_strategy_ = new_default_strategy; diff --git a/src/plugin_loader.cpp b/src/plugin_loader.cpp index cb92f42..650fbb9 100644 --- a/src/plugin_loader.cpp +++ b/src/plugin_loader.cpp @@ -62,6 +62,25 @@ std::string SharedLibExtension() { PluginLoader::PluginLoader(std::string plugin_dir) : plugin_dir_(std::move(plugin_dir)) {} +PluginLoader::PluginLoader(PluginLoader&& other) noexcept + : plugin_dir_(std::move(other.plugin_dir_)), cache_(std::move(other.cache_)) { + other.cache_.clear(); +} + +PluginLoader& PluginLoader::operator=(PluginLoader&& other) noexcept { + if (this == &other) return *this; + + for (auto& kv : cache_) { + CloseLibraryHandle(kv.second.handle); + } + cache_.clear(); + + plugin_dir_ = std::move(other.plugin_dir_); + cache_ = std::move(other.cache_); + other.cache_.clear(); + return *this; +} + PluginLoader::~PluginLoader() { for (auto& kv : cache_) { CloseLibraryHandle(kv.second.handle);