rtsp_tensorrt/tests/test_cuda_helper.cpp
sladro e13cb3659c feat: 初始化项目结构
- 创建基本项目结构和目录
- 添加CMake构建系统
- 实现基础的配置解析功能
- 添加YOLO推理框架支持
- 集成RTSP和视频流处理功能
- 添加性能监控和日志系统
2024-12-24 16:25:03 +08:00

168 lines
5.5 KiB
C++

#include <gtest/gtest.h>
#include "pipeline/inference/cuda_helper.hpp"
#include <atomic>
#include <chrono>
#include <cstring>
#include <iostream>
#include <thread>
#include <vector>
using namespace pipeline;
using namespace pipeline::detail;
class CudaHelperTest : public ::testing::Test {
protected:
void SetUp() override {
// 检查CUDA设备
int device_count;
cudaError_t err = cudaGetDeviceCount(&device_count);
if (err != cudaSuccess) {
std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
}
ASSERT_EQ(err, cudaSuccess);
ASSERT_GT(device_count, 0);
// 设置设备
err = cudaSetDevice(0);
if (err != cudaSuccess) {
std::cerr << "CUDA Error: " << cudaGetErrorString(err) << std::endl;
}
ASSERT_EQ(err, cudaSuccess);
}
void TearDown() override {
cudaDeviceSynchronize();
cudaDeviceReset();
}
};
// 测试CUDA流
// Verify that a CudaStream wraps a valid stream handle and can synchronize.
TEST_F(CudaHelperTest, CudaStream) {
    CudaStream stream;
    // The underlying handle must be non-null after construction.
    EXPECT_NE(stream.get(), nullptr);
    // A freshly created stream has no pending work, so sync must succeed.
    EXPECT_TRUE(stream.sync());
}
// 测试CUDA内存缓冲区
// Round-trip a ramp pattern of floats through device memory and verify the
// contents survive host -> device -> host copies unchanged.
TEST_F(CudaHelperTest, CudaBuffer) {
    constexpr size_t kBytes = 1024;  // 1 KiB
    constexpr size_t kCount = kBytes / sizeof(float);
    try {
        // Allocate the paired host/device buffer.
        CudaBuffer buffer(kBytes);
        ASSERT_NE(buffer.devicePtr(), nullptr) << "Device memory allocation failed";
        ASSERT_NE(buffer.hostPtr(), nullptr) << "Host memory allocation failed";
        ASSERT_EQ(buffer.size(), kBytes);

        // Fill the host side with a recognizable ramp pattern.
        float* host = static_cast<float*>(buffer.hostPtr());
        for (size_t i = 0; i < kCount; ++i) {
            host[i] = static_cast<float>(i);
        }

        // Push the pattern to the device.
        ASSERT_TRUE(buffer.copyH2D()) << "Host to Device copy failed";

        // Wipe the host copy and prove it is really gone, so the read-back
        // below can only succeed if the device actually held the data.
        memset(host, 0, kBytes);
        for (size_t i = 0; i < kCount; ++i) {
            ASSERT_FLOAT_EQ(host[i], 0.0f) << "Host memory clear failed at index " << i;
        }

        // Pull the data back from the device.
        ASSERT_TRUE(buffer.copyD2H()) << "Device to Host copy failed";

        // Verify the ramp survived the round trip, logging any mismatch
        // before the assertion fires.
        for (size_t i = 0; i < kCount; ++i) {
            if (host[i] != static_cast<float>(i)) {
                std::cerr << "Data mismatch at index " << i
                          << ": expected " << static_cast<float>(i)
                          << ", got " << host[i] << std::endl;
            }
            ASSERT_FLOAT_EQ(host[i], static_cast<float>(i))
                << "Data verification failed at index " << i;
        }

        // Drain the device to surface any deferred async errors.
        ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess) << "Device synchronization failed";
    } catch (const std::exception& e) {
        FAIL() << "Exception caught: " << e.what();
    }
}
// 测试异步操作
// Exercise asynchronous host<->device copies on a non-default stream and
// verify the data genuinely round-trips through device memory.
TEST_F(CudaHelperTest, AsyncOperations) {
    const size_t size = 1024;  // small payload keeps the test fast
    const size_t num_elements = size / sizeof(float);
    try {
        CudaStream stream;
        CudaBuffer buffer(size);

        // Fill the host buffer with a ramp pattern.
        float* host_data = static_cast<float*>(buffer.hostPtr());
        for (size_t i = 0; i < num_elements; ++i) {
            host_data[i] = static_cast<float>(i);
        }

        // Asynchronously push the pattern to the device, then wait for the
        // copy to finish before touching the host buffer again (modifying
        // the source of an in-flight async copy would be a race).
        ASSERT_TRUE(buffer.copyH2D(stream.get())) << "Async H2D copy failed";
        ASSERT_TRUE(stream.sync()) << "Stream synchronization after H2D failed";

        // Clear the host buffer so the verification below can only pass if
        // the D2H copy actually brought the data back from the device. (The
        // original test skipped this step, so it passed even if both copies
        // were no-ops.)
        memset(host_data, 0, size);

        // Asynchronous read-back, then synchronize before verifying.
        ASSERT_TRUE(buffer.copyD2H(stream.get())) << "Async D2H copy failed";
        ASSERT_TRUE(stream.sync()) << "Stream synchronization failed";

        for (size_t i = 0; i < num_elements; ++i) {
            ASSERT_FLOAT_EQ(host_data[i], static_cast<float>(i))
                << "Async data verification failed at index " << i;
        }
    } catch (const std::exception& e) {
        FAIL() << "Exception caught: " << e.what();
    }
}
// 测试TensorRT日志器
// Smoke test for the TensorRT ILogger implementation: emitting one message
// at each severity level must complete without crashing or throwing.
TEST_F(CudaHelperTest, Logger) {
    pipeline::detail::Logger logger;
    logger.log(nvinfer1::ILogger::Severity::kINFO, "Info message");
    logger.log(nvinfer1::ILogger::Severity::kWARNING, "Warning message");
    logger.log(nvinfer1::ILogger::Severity::kERROR, "Error message");
}
// 测试多线程场景
// Concurrent use of the helpers: each worker thread owns its own stream and
// buffer, round-trips a ramp pattern, and must not interfere with the others.
TEST_F(CudaHelperTest, MultiThread) {
    const size_t num_threads = 2;  // keep the test lightweight
    const size_t size = 1024;      // per-thread payload

    // NOTE: an exception escaping a std::thread body calls std::terminate,
    // so a try/catch around emplace_back/join (as the original had) can never
    // observe worker failures. Catch inside each worker instead and report
    // after joining.
    std::atomic<int> failures{0};
    std::vector<std::thread> threads;
    threads.reserve(num_threads);  // avoid reallocation while threads run

    for (size_t i = 0; i < num_threads; ++i) {
        threads.emplace_back([size, &failures]() {
            try {
                CudaStream stream;
                CudaBuffer buffer(size);
                float* host_data = static_cast<float*>(buffer.hostPtr());
                for (size_t j = 0; j < size / sizeof(float); ++j) {
                    host_data[j] = static_cast<float>(j);
                }
                EXPECT_TRUE(buffer.copyH2D(stream.get()));
                EXPECT_TRUE(buffer.copyD2H(stream.get()));
                EXPECT_TRUE(stream.sync());
            } catch (const std::exception& e) {
                std::cerr << "Exception in worker thread: " << e.what() << std::endl;
                failures.fetch_add(1);
            } catch (...) {
                std::cerr << "Unknown exception in worker thread" << std::endl;
                failures.fetch_add(1);
            }
        });
    }

    // Always join every thread: destroying a joinable std::thread terminates
    // the process.
    for (auto& thread : threads) {
        thread.join();
    }
    EXPECT_EQ(failures.load(), 0) << "One or more worker threads threw";
}