// OrangePi3588Media/tests/test_ai_pose.cpp
//
// 233 lines
// 7.2 KiB
// C++

#include <gtest/gtest.h>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "hw/i_infer_backend.h"
#include "node.h"
#include "pose/pose_result.h"
#include "utils/simple_json.h"
#include "../plugins/ai_pose/ai_pose_node.h"
namespace rk3588 {
namespace {
// Parses a JSON snippet into a SimpleJson config object.
// A parse failure is reported as a (non-fatal) test failure and an
// empty config is returned so the calling test can proceed.
SimpleJson ParsePoseConfig(const std::string& text) {
  SimpleJson parsed;
  std::string parse_error;
  EXPECT_TRUE(ParseSimpleJson(text, parsed, parse_error));
  return parsed;
}
// In-memory IInferBackend test double: hands back canned tensor outputs
// so AiPoseNode's decode path can be exercised without NPU hardware.
class FakeInferBackend final : public IInferBackend {
 public:
  // Every model "loads" successfully under a fixed dummy handle.
  ModelHandle LoadModel(const std::string& /*model_path*/, std::string& /*err*/) override {
    return 1;
  }

  void UnloadModel(ModelHandle /*handle*/) override {}

  // Reports a fixed YOLOv8-pose-like shape: 640x640x3 input, 4 outputs.
  bool GetModelInfo(ModelHandle /*handle*/, ModelInfo& info) const override {
    info.name = "fake_yolov8_pose";
    info.n_input = 1;
    info.n_output = 4;
    info.input_channels = 3;
    info.input_height = 640;
    info.input_width = 640;
    return true;
  }

  // Always-successful inference returning whatever the test staged in outputs_.
  InferResult Infer(ModelHandle /*handle*/, const InferInput& /*input*/) override {
    InferResult res;
    res.success = true;
    res.outputs = outputs_;
    return res;
  }

  // The zero-copy path is unsupported; these tests use the copying Infer() only.
  AiScheduler::BorrowedInferResult InferBorrowed(ModelHandle /*handle*/, const InferInput& /*input*/) override {
    AiScheduler::BorrowedInferResult res;
    res.success = false;
    res.error = "not used in unit test";
    return res;
  }

  // Canned tensors served by Infer(); populated directly by each test.
  std::vector<InferOutput> outputs_;
};
// Packs a float vector into an InferOutput's raw byte buffer.
//
// Fix: skip the memcpy for an empty input — calling std::memcpy with a
// null source pointer is undefined behavior even when the count is zero,
// and values.data() may be null for an empty vector.
InferOutput MakeFloatOutput(const std::vector<float>& values) {
  InferOutput out;
  out.size = values.size() * sizeof(float);
  out.data.resize(out.size);
  if (!values.empty()) {
    std::memcpy(out.data.data(), values.data(), out.size);
  }
  return out;
}
// An enabled pose node must refuse to initialize when no model path is given.
TEST(AiPoseNodeTest, RejectsEnabledConfigWithoutModelPath) {
  AiPoseNode pose_node;
  const SimpleJson cfg = ParsePoseConfig(R"({
"id": "pose0",
"enabled": true
})");
  NodeContext ctx;
  EXPECT_FALSE(pose_node.Init(cfg, ctx));
}
// A disabled node should still start and forward frames untouched,
// without attaching any pose result.
TEST(AiPoseNodeTest, DisabledNodeStartsAndPassesFramesThrough) {
  AiPoseNode pose_node;
  const SimpleJson cfg = ParsePoseConfig(R"({
"id": "pose0",
"enabled": false
})");
  NodeContext ctx;
  auto sink = std::make_shared<SpscQueue<FramePtr>>(4, QueueDropStrategy::DropOldest);
  ctx.output_queues.push_back(sink);
  ASSERT_TRUE(pose_node.Init(cfg, ctx));
  ASSERT_TRUE(pose_node.Start());

  auto input = std::make_shared<Frame>();
  input->width = 1280;
  input->height = 720;
  EXPECT_EQ(static_cast<int>(pose_node.Process(input)), static_cast<int>(NodeStatus::OK));

  // The very same frame object must come out, with no pose attached.
  FramePtr popped;
  ASSERT_TRUE(sink->TryPop(popped));
  EXPECT_EQ(popped.get(), input.get());
  EXPECT_EQ(input->pose, nullptr);
}
// With inference enabled, Init may accept the config but Start must fail
// when the context provides no inference backend.
TEST(AiPoseNodeTest, EnabledNodeRequiresInferBackendAtStart) {
  AiPoseNode pose_node;
  const SimpleJson cfg = ParsePoseConfig(R"({
"id": "pose0",
"enabled": true,
"model_path": "models/yolov8n-pose.rknn"
})");
  NodeContext ctx;  // no infer_backend set
  ASSERT_TRUE(pose_node.Init(cfg, ctx));
  EXPECT_FALSE(pose_node.Start());
}
// End-to-end decode test: crafts raw YOLOv8-pose-style tensors in the fake
// backend, runs a 320x320 RGB frame through the node (model input 640x640),
// and checks that the decoded box/keypoints land in original-frame
// coordinates (i.e. scaled back by 320/640 = 0.5).
TEST(AiPoseNodeTest, DecodePoseOutputsToOriginalCoordinates) {
const int num_keypoints = 17;
// 8400 = 80*80 + 40*40 + 20*20: total anchor points over the three heads.
const int total_points = 8400;
// Per-head layout is channel-major [65][H][W]; -10 logits keep every cell
// below the confidence threshold except the one we activate below.
// NOTE(review): 65 channels look like 4*16 DFL bins + 1 score — inferred
// from how the test writes them, confirm against the decoder.
std::vector<float> head0(static_cast<size_t>(65 * 80 * 80), -10.0f);
std::vector<float> head1(static_cast<size_t>(65 * 40 * 40), -10.0f);
std::vector<float> head2(static_cast<size_t>(65 * 20 * 20), -10.0f);
// Keypoint tensor layout: [num_keypoints][3 (x, y, score)][total_points].
std::vector<float> keypoints(static_cast<size_t>(num_keypoints * 3 * total_points), 0.0f);
// Single active detection at cell (10, 15) of the stride-8 (80x80) head.
const int cell_x = 10;
const int cell_y = 15;
const int grid_w = 80;
const int point_index = cell_y * grid_w + cell_x;
// Puts a sharp peak at `bin` within one group of 16 DFL logits, so the
// softmax-weighted expectation decodes to ~bin cells of distance.
auto set_dfl_bin = [&](int channel_group, int bin) {
for (int i = 0; i < 16; ++i) {
head0[static_cast<size_t>((channel_group * 16 + i) * grid_w * 80 + point_index)] = (i == bin) ? 10.0f : -10.0f;
}
};
// Distances (left, top, right, bottom) in cells. At stride 8 with cell
// center (10.5, 15.5)*8 = (84, 124) this gives a model-space box
// x=52, y=76, w=96, h=128, i.e. 26/38/48/64 after the 0.5 rescale —
// matching the EXPECT_NEAR values below.
set_dfl_bin(0, 4);
set_dfl_bin(1, 6);
set_dfl_bin(2, 8);
set_dfl_bin(3, 10);
// Channel 64 is the score logit; 10.0 sigmoids to ~1.0 confidence.
head0[static_cast<size_t>(64 * grid_w * 80 + point_index)] = 10.0f;
// Keypoint k at model coords (120+k, 140+2k) with score 0.9; after the
// 0.5 rescale: k=0 -> (60, 70), k=16 -> (68, 86), as asserted below.
for (int k = 0; k < num_keypoints; ++k) {
keypoints[static_cast<size_t>(k * 3 * total_points + point_index)] = 120.0f + static_cast<float>(k);
keypoints[static_cast<size_t>(k * 3 * total_points + total_points + point_index)] = 140.0f + static_cast<float>(k * 2);
keypoints[static_cast<size_t>(k * 3 * total_points + 2 * total_points + point_index)] = 0.9f;
}
auto backend = std::make_shared<FakeInferBackend>();
// Output order expected by the node: three detection heads, then keypoints.
backend->outputs_ = {
MakeFloatOutput(head0),
MakeFloatOutput(head1),
MakeFloatOutput(head2),
MakeFloatOutput(keypoints),
};
AiPoseNode node;
SimpleJson config = ParsePoseConfig(R"({
"id": "pose0",
"enabled": true,
"model_path": "models/yolov8n-pose.rknn",
"model_input_w": 640,
"model_input_h": 640,
"conf_thresh": 0.25,
"nms_thresh": 0.45
})");
NodeContext ctx;
ctx.infer_backend = backend;
auto out = std::make_shared<SpscQueue<FramePtr>>(4, QueueDropStrategy::DropOldest);
ctx.output_queues.push_back(out);
ASSERT_TRUE(node.Init(config, ctx));
ASSERT_TRUE(node.Start());
// 320x320 packed RGB frame; keeping the buffer alive via data_owner.
auto rgb = std::make_shared<std::vector<uint8_t>>(static_cast<size_t>(320 * 320 * 3), 127);
auto frame = std::make_shared<Frame>();
frame->width = 320;
frame->height = 320;
frame->format = PixelFormat::RGB;
frame->data = rgb->data();
frame->data_size = rgb->size();
frame->stride = 320 * 3;
frame->plane_count = 1;
frame->planes[0] = {frame->data, frame->stride, static_cast<int>(frame->data_size), 0};
frame->data_owner = rgb;
EXPECT_EQ(static_cast<int>(node.Process(frame)), static_cast<int>(NodeStatus::OK));
// Exactly one person decoded, in original 320x320 frame coordinates.
ASSERT_NE(frame->pose, nullptr);
ASSERT_EQ(frame->pose->items.size(), 1u);
const PoseItem& item = frame->pose->items[0];
EXPECT_NEAR(item.bbox.x, 26.0f, 1.0f);
EXPECT_NEAR(item.bbox.y, 38.0f, 1.0f);
EXPECT_NEAR(item.bbox.w, 48.0f, 1.0f);
EXPECT_NEAR(item.bbox.h, 64.0f, 1.0f);
ASSERT_EQ(item.keypoints.size(), static_cast<size_t>(num_keypoints));
EXPECT_NEAR(item.keypoints[0].point.x, 60.0f, 1.0f);
EXPECT_NEAR(item.keypoints[0].point.y, 70.0f, 1.0f);
EXPECT_NEAR(item.keypoints[16].point.x, 68.0f, 1.0f);
EXPECT_NEAR(item.keypoints[16].point.y, 86.0f, 1.0f);
// The annotated frame is also forwarded downstream.
FramePtr forwarded;
ASSERT_TRUE(out->TryPop(forwarded));
EXPECT_EQ(forwarded.get(), frame.get());
}
// PoseResult must carry per-item track ids, boxes, and keypoint lists.
TEST(PoseResultTest, SupportsPerTrackKeypoints) {
  PoseItem person;
  person.track_id = 42;
  person.score = 0.93f;
  person.bbox = Rect{100.0f, 200.0f, 300.0f, 400.0f};
  person.keypoints.push_back(PoseKeypoint{PosePoint2f{120.0f, 220.0f}, 0.99f});
  person.keypoints.push_back(PoseKeypoint{PosePoint2f{140.0f, 260.0f}, 0.95f});

  PoseResult result;
  result.img_w = 1920;
  result.img_h = 1080;
  result.model_name = "pose_model";
  result.items.push_back(person);

  ASSERT_EQ(result.items.size(), 1u);
  const PoseItem& stored = result.items[0];
  EXPECT_EQ(stored.track_id, 42);
  ASSERT_EQ(stored.keypoints.size(), 2u);
  EXPECT_FLOAT_EQ(stored.keypoints[0].point.x, 120.0f);
  EXPECT_FLOAT_EQ(stored.keypoints[1].score, 0.95f);
}
} // namespace
} // namespace rk3588