126 lines
3.1 KiB
C++
126 lines
3.1 KiB
C++
#pragma once
|
|
|
|
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#if defined(RK3588_ENABLE_RKNN)
#include "rknn_api.h"
#endif
|
|
|
|
namespace rk3588 {
|
|
|
|
// Opaque identifier for a model loaded by AiScheduler.
using ModelHandle = std::uint64_t;

// Sentinel handle meaning "no model"; AiScheduler::LoadModel returns it on failure.
constexpr ModelHandle kInvalidModelHandle{0};
|
|
|
|
// Static description of a loaded model, as filled in by AiScheduler::GetModelInfo().
struct ModelInfo {
  // Input geometry; all zero until populated.
  int input_width{0};
  int input_height{0};
  int input_channels{0};

  // Number of input and output tensors.
  uint32_t n_input{0};
  uint32_t n_output{0};

  // Model name (presumably derived from the loaded file - confirm in the .cpp).
  std::string name{};
};
|
|
|
|
/// Input buffer for one inference call.
///
/// The struct does not own `data`: it is a plain non-owning pointer, so the
/// caller must keep the buffer alive for as long as the scheduler may read it.
struct InferInput {
  const void* data = nullptr;  ///< Input bytes (non-owning).
  std::size_t size = 0;        ///< Length of `data` (presumably in bytes - confirm in the .cpp).
  int width = 0;               ///< Input width (presumably in pixels).
  int height = 0;              ///< Input height (presumably in pixels).
  bool is_nhwc = true;         ///< Layout of `data` - true: NHWC, false: NCHW.
};
|
|
|
|
/// One output tensor produced by AiScheduler::Infer().
///
/// NOTE(review): the tensor metadata members below exist only when
/// RK3588_ENABLE_RKNN is defined, so this struct's layout depends on that
/// macro. Every translation unit including this header must agree on it,
/// otherwise the program has an ODR violation.
struct InferOutput {
  std::vector<uint8_t> data;  ///< Raw output bytes.
  std::size_t size = 0;       ///< Output size; NOTE(review): presumably equals data.size() - confirm in the .cpp.
  int index = 0;              ///< Output index (presumably position among the model's outputs).
#if defined(RK3588_ENABLE_RKNN)
  rknn_tensor_type type = RKNN_TENSOR_UINT8;  ///< Element type (presumably of `data`).
  int32_t zp = 0;                             ///< Presumably the quantization zero point (cf. rknn_tensor_attr::zp).
  float scale = 1.0f;                         ///< Presumably the quantization scale (cf. rknn_tensor_attr::scale).
  std::vector<uint32_t> dims;                 ///< Tensor dimensions.
#endif
};
|
|
|
|
struct InferResult {
|
|
bool success = false;
|
|
std::string error;
|
|
std::vector<InferOutput> outputs;
|
|
};
|
|
|
|
// Callback for async inference (future use)
|
|
using InferCallback = std::function<void(const InferResult& result)>;
|
|
|
|
class AiScheduler {
|
|
public:
|
|
static AiScheduler& Instance();
|
|
|
|
// Prevent copy/move
|
|
AiScheduler(const AiScheduler&) = delete;
|
|
AiScheduler& operator=(const AiScheduler&) = delete;
|
|
|
|
// Load a model from file, returns handle (0 = invalid)
|
|
ModelHandle LoadModel(const std::string& model_path, std::string& err);
|
|
|
|
// Unload a model by handle
|
|
void UnloadModel(ModelHandle handle);
|
|
|
|
// Get model information
|
|
bool GetModelInfo(ModelHandle handle, ModelInfo& info) const;
|
|
|
|
// Synchronous inference
|
|
InferResult Infer(ModelHandle handle, const InferInput& input);
|
|
|
|
// Async inference (submits to internal queue, calls callback when done)
|
|
// For now, this is a simple wrapper around sync Infer
|
|
void InferAsync(ModelHandle handle, const InferInput& input, InferCallback callback);
|
|
|
|
// Get statistics
|
|
uint64_t GetTotalInferences() const { return total_inferences_.load(); }
|
|
uint64_t GetTotalErrors() const { return total_errors_.load(); }
|
|
|
|
// Shutdown scheduler (unload all models)
|
|
void Shutdown();
|
|
|
|
private:
|
|
AiScheduler();
|
|
~AiScheduler();
|
|
|
|
#if defined(RK3588_ENABLE_RKNN)
|
|
struct ModelContext {
|
|
rknn_context ctx = 0;
|
|
std::vector<uint8_t> model_data;
|
|
std::vector<rknn_tensor_attr> input_attrs;
|
|
std::vector<rknn_tensor_attr> output_attrs;
|
|
uint32_t n_input = 0;
|
|
uint32_t n_output = 0;
|
|
int input_w = 0;
|
|
int input_h = 0;
|
|
int input_c = 0;
|
|
std::string path;
|
|
std::mutex infer_mutex; // Per-model lock for inference
|
|
|
|
~ModelContext() {
|
|
if (ctx) {
|
|
rknn_destroy(ctx);
|
|
ctx = 0;
|
|
}
|
|
}
|
|
};
|
|
|
|
std::unordered_map<ModelHandle, std::shared_ptr<ModelContext>> models_;
|
|
#endif
|
|
|
|
mutable std::mutex models_mutex_; // Protects models_ map
|
|
std::atomic<ModelHandle> next_handle_{1};
|
|
std::atomic<uint64_t> total_inferences_{0};
|
|
std::atomic<uint64_t> total_errors_{0};
|
|
};
|
|
|
|
} // namespace rk3588
|