// Unit tests for AiPoseNode (YOLOv8-pose decode) and PoseResult.
#include <gtest/gtest.h>

#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "hw/i_infer_backend.h"
#include "node.h"
#include "pose/pose_result.h"
#include "utils/simple_json.h"

#include "../plugins/ai_pose/ai_pose_node.h"

namespace rk3588 {
namespace {

// Parses `text` into a SimpleJson config for the pose-node tests.
// The parse result is asserted via EXPECT_TRUE so a malformed literal
// fails the surrounding test instead of silently yielding an empty config.
SimpleJson ParsePoseConfig(const std::string& text) {
  SimpleJson parsed;
  std::string parse_error;
  EXPECT_TRUE(ParseSimpleJson(text, parsed, parse_error));
  return parsed;
}
class FakeInferBackend final : public IInferBackend {
|
|
public:
|
|
ModelHandle LoadModel(const std::string& /*model_path*/, std::string& /*err*/) override {
|
|
return 1;
|
|
}
|
|
|
|
void UnloadModel(ModelHandle /*handle*/) override {}
|
|
|
|
bool GetModelInfo(ModelHandle /*handle*/, ModelInfo& info) const override {
|
|
info.input_width = 640;
|
|
info.input_height = 640;
|
|
info.input_channels = 3;
|
|
info.n_input = 1;
|
|
info.n_output = 4;
|
|
info.name = "fake_yolov8_pose";
|
|
return true;
|
|
}
|
|
|
|
InferResult Infer(ModelHandle /*handle*/, const InferInput& /*input*/) override {
|
|
InferResult result;
|
|
result.success = true;
|
|
result.outputs = outputs_;
|
|
return result;
|
|
}
|
|
|
|
AiScheduler::BorrowedInferResult InferBorrowed(ModelHandle /*handle*/, const InferInput& /*input*/) override {
|
|
AiScheduler::BorrowedInferResult result;
|
|
result.success = false;
|
|
result.error = "not used in unit test";
|
|
return result;
|
|
}
|
|
|
|
std::vector<InferOutput> outputs_;
|
|
};
// Packs `values` into an InferOutput byte buffer (raw float32, host byte
// order), as the backend would deliver a model output tensor.
InferOutput MakeFloatOutput(const std::vector<float>& values) {
  InferOutput out;
  out.size = values.size() * sizeof(float);
  out.data.resize(out.size);
  // Guard the empty case: vector::data() may be nullptr for an empty
  // vector, and std::memcpy with a null pointer is UB even when the
  // byte count is zero.
  if (!values.empty()) {
    std::memcpy(out.data.data(), values.data(), out.size);
  }
  return out;
}
// Init() must reject an enabled pose config that names no model file.
TEST(AiPoseNodeTest, RejectsEnabledConfigWithoutModelPath) {
  NodeContext ctx;
  AiPoseNode node;
  SimpleJson config = ParsePoseConfig(R"({
"id": "pose0",
"enabled": true
})");

  EXPECT_FALSE(node.Init(config, ctx));
}
// With "enabled": false the node must still Init/Start, forward frames to
// its output queue untouched (same shared pointer), and attach no pose data.
TEST(AiPoseNodeTest, DisabledNodeStartsAndPassesFramesThrough) {
  AiPoseNode node;
  SimpleJson config = ParsePoseConfig(R"({
"id": "pose0",
"enabled": false
})");

  auto downstream = std::make_shared<SpscQueue<FramePtr>>(4, QueueDropStrategy::DropOldest);
  NodeContext ctx;
  ctx.output_queues.push_back(downstream);

  ASSERT_TRUE(node.Init(config, ctx));
  ASSERT_TRUE(node.Start());

  auto input = std::make_shared<Frame>();
  input->width = 1280;
  input->height = 720;

  EXPECT_EQ(static_cast<int>(node.Process(input)), static_cast<int>(NodeStatus::OK));

  FramePtr received;
  ASSERT_TRUE(downstream->TryPop(received));
  EXPECT_EQ(received.get(), input.get());  // pass-through, not a copy
  EXPECT_EQ(input->pose, nullptr);         // disabled node adds no result
}
// Init() accepts a complete enabled config, but Start() must fail when the
// NodeContext provides no inference backend.
TEST(AiPoseNodeTest, EnabledNodeRequiresInferBackendAtStart) {
  NodeContext ctx;  // ctx.infer_backend deliberately left unset
  AiPoseNode node;
  SimpleJson config = ParsePoseConfig(R"({
"id": "pose0",
"enabled": true,
"model_path": "models/yolov8n-pose.rknn"
})");

  ASSERT_TRUE(node.Init(config, ctx));
  EXPECT_FALSE(node.Start());
}
// End-to-end decode test: fabricate raw YOLOv8-pose output tensors holding
// exactly one detection, run a 320x320 frame through the node (model input
// is 640x640), and verify the box/keypoints are mapped back to frame
// coordinates — i.e. everything scales by 320/640 = 0.5.
TEST(AiPoseNodeTest, DecodePoseOutputsToOriginalCoordinates) {
const int num_keypoints = 17;
const int total_points = 8400;  // 80*80 + 40*40 + 20*20 anchor points

// Three detection heads, 65 channels each (presumably 64 DFL box channels +
// 1 score channel — consistent with the indices used below). Fill with -10
// so the score sigmoid is ~0 everywhere except the cell we activate.
std::vector<float> head0(static_cast<size_t>(65 * 80 * 80), -10.0f);
std::vector<float> head1(static_cast<size_t>(65 * 40 * 40), -10.0f);
std::vector<float> head2(static_cast<size_t>(65 * 20 * 20), -10.0f);
// Keypoint tensor: 17 keypoints x (x, y, score) planes x 8400 points.
std::vector<float> keypoints(static_cast<size_t>(num_keypoints * 3 * total_points), 0.0f);

// Activate a single cell on the 80x80 head (stride 8 for a 640 input).
const int cell_x = 10;
const int cell_y = 15;
const int grid_w = 80;
const int point_index = cell_y * grid_w + cell_x;

// DFL: each box side is a 16-bin distribution over distances; a large logit
// at `bin` (and -10 elsewhere) makes the softmax expectation ~= bin.
auto set_dfl_bin = [&](int channel_group, int bin) {
for (int i = 0; i < 16; ++i) {
head0[static_cast<size_t>((channel_group * 16 + i) * grid_w * 80 + point_index)] = (i == bin) ? 10.0f : -10.0f;
}
};

// Side distances in cells: left=4, top=6, right=8, bottom=10.
set_dfl_bin(0, 4);
set_dfl_bin(1, 6);
set_dfl_bin(2, 8);
set_dfl_bin(3, 10);
// Score channel (index 64) for the chosen cell: sigmoid(10) ~= 1.0.
head0[static_cast<size_t>(64 * grid_w * 80 + point_index)] = 10.0f;

// Keypoints in model-input pixels: x = 120 + k, y = 140 + 2k, score 0.9.
for (int k = 0; k < num_keypoints; ++k) {
keypoints[static_cast<size_t>(k * 3 * total_points + point_index)] = 120.0f + static_cast<float>(k);
keypoints[static_cast<size_t>(k * 3 * total_points + total_points + point_index)] = 140.0f + static_cast<float>(k * 2);
keypoints[static_cast<size_t>(k * 3 * total_points + 2 * total_points + point_index)] = 0.9f;
}

// Backend hands the four tensors back verbatim from Infer().
auto backend = std::make_shared<FakeInferBackend>();
backend->outputs_ = {
MakeFloatOutput(head0),
MakeFloatOutput(head1),
MakeFloatOutput(head2),
MakeFloatOutput(keypoints),
};

AiPoseNode node;
SimpleJson config = ParsePoseConfig(R"({
"id": "pose0",
"enabled": true,
"model_path": "models/yolov8n-pose.rknn",
"model_input_w": 640,
"model_input_h": 640,
"conf_thresh": 0.25,
"nms_thresh": 0.45
})");

NodeContext ctx;
ctx.infer_backend = backend;
auto out = std::make_shared<SpscQueue<FramePtr>>(4, QueueDropStrategy::DropOldest);
ctx.output_queues.push_back(out);

ASSERT_TRUE(node.Init(config, ctx));
ASSERT_TRUE(node.Start());

// 320x320 packed-RGB frame, data kept alive via data_owner.
auto rgb = std::make_shared<std::vector<uint8_t>>(static_cast<size_t>(320 * 320 * 3), 127);
auto frame = std::make_shared<Frame>();
frame->width = 320;
frame->height = 320;
frame->format = PixelFormat::RGB;
frame->data = rgb->data();
frame->data_size = rgb->size();
frame->stride = 320 * 3;
frame->plane_count = 1;
frame->planes[0] = {frame->data, frame->stride, static_cast<int>(frame->data_size), 0};
frame->data_owner = rgb;

EXPECT_EQ(static_cast<int>(node.Process(frame)), static_cast<int>(NodeStatus::OK));

// Exactly one detection should survive thresholding/NMS.
ASSERT_NE(frame->pose, nullptr);
ASSERT_EQ(frame->pose->items.size(), 1u);
const PoseItem& item = frame->pose->items[0];
// Model-space box from cell (10,15): x1=(10.5-4)*8=52, y1=(15.5-6)*8=76,
// x2=(10.5+8)*8=148, y2=(15.5+10)*8=204 -> (52,76,96,128), then *0.5.
EXPECT_NEAR(item.bbox.x, 26.0f, 1.0f);
EXPECT_NEAR(item.bbox.y, 38.0f, 1.0f);
EXPECT_NEAR(item.bbox.w, 48.0f, 1.0f);
EXPECT_NEAR(item.bbox.h, 64.0f, 1.0f);
ASSERT_EQ(item.keypoints.size(), static_cast<size_t>(num_keypoints));
// Keypoints scale by 0.5: (120,140) -> (60,70); (136,172) -> (68,86).
EXPECT_NEAR(item.keypoints[0].point.x, 60.0f, 1.0f);
EXPECT_NEAR(item.keypoints[0].point.y, 70.0f, 1.0f);
EXPECT_NEAR(item.keypoints[16].point.x, 68.0f, 1.0f);
EXPECT_NEAR(item.keypoints[16].point.y, 86.0f, 1.0f);

// The annotated frame is also forwarded downstream by the same pointer.
FramePtr forwarded;
ASSERT_TRUE(out->TryPop(forwarded));
EXPECT_EQ(forwarded.get(), frame.get());
}
// PoseResult is a plain value type: track id, box, and per-keypoint scores
// stored into an item must read back unchanged.
TEST(PoseResultTest, SupportsPerTrackKeypoints) {
  PoseResult result;
  result.img_w = 1920;
  result.img_h = 1080;
  result.model_name = "pose_model";

  PoseItem person;
  person.track_id = 42;
  person.score = 0.93f;
  person.bbox = Rect{100.0f, 200.0f, 300.0f, 400.0f};
  person.keypoints.push_back(PoseKeypoint{PosePoint2f{120.0f, 220.0f}, 0.99f});
  person.keypoints.push_back(PoseKeypoint{PosePoint2f{140.0f, 260.0f}, 0.95f});
  result.items.push_back(person);

  ASSERT_EQ(result.items.size(), 1u);
  const PoseItem& stored = result.items[0];
  EXPECT_EQ(stored.track_id, 42);
  ASSERT_EQ(stored.keypoints.size(), 2u);
  EXPECT_FLOAT_EQ(stored.keypoints[0].point.x, 120.0f);
  EXPECT_FLOAT_EQ(stored.keypoints[1].score, 0.95f);
}

} // namespace
} // namespace rk3588