// Listing metadata carried over from the source paste: 303 lines, 11 KiB, C++.
#include "NvInfer.h"
#include "cuda_runtime_api.h"
#include "logging.h"

#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Abort with a diagnostic when a CUDA runtime call does not return 0 (success).
// `status` is evaluated exactly once. The temporary has a macro-specific name so
// it cannot capture a caller variable: with a local called `ret`, CHECK(f(ret))
// used to expand to `auto ret = (f(ret));`, which refers to itself.
#define CHECK(status)                                                        \
    do                                                                       \
    {                                                                        \
        auto check_status_ = (status);                                       \
        if (check_status_ != 0)                                              \
        {                                                                    \
            std::cerr << "Cuda failure: " << check_status_ << " at "         \
                      << __FILE__ << ":" << __LINE__ << std::endl;           \
            std::abort();                                                    \
        }                                                                    \
    } while (0)

// Fixed facts about the network and its I/O tensors. The input tensor is
// declared as {3, INPUT_H, INPUT_W} per image; the output holds OUTPUT_SIZE
// values per image.
static constexpr int INPUT_H = 227;
static constexpr int INPUT_W = 227;
static constexpr int OUTPUT_SIZE = 1000;

// Tensor names used both when building the network and when binding buffers.
const char* INPUT_BLOB_NAME{"data"};
const char* OUTPUT_BLOB_NAME{"prob"};

using namespace nvinfer1;
|
|
|
|
static Logger gLogger;
|
|
|
|
// Load weights from files shared with TensorRT samples.
|
|
// TensorRT weight files have a simple space delimited format:
|
|
// [type] [size] <data x size in hex>
|
|
std::map<std::string, Weights> loadWeights(const std::string file)
|
|
{
|
|
std::cout << "Loading weights: " << file << std::endl;
|
|
std::map<std::string, Weights> weightMap;
|
|
|
|
// Open weights file
|
|
std::ifstream input(file);
|
|
assert(input.is_open() && "Unable to load weight file.");
|
|
|
|
// Read number of weight blobs
|
|
int32_t count;
|
|
input >> count;
|
|
assert(count > 0 && "Invalid weight map file.");
|
|
|
|
while (count--)
|
|
{
|
|
Weights wt{DataType::kFLOAT, nullptr, 0};
|
|
uint32_t size;
|
|
|
|
// Read name and type of blob
|
|
std::string name;
|
|
input >> name >> std::dec >> size;
|
|
wt.type = DataType::kFLOAT;
|
|
|
|
// Load blob
|
|
uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
|
|
for (uint32_t x = 0, y = size; x < y; ++x)
|
|
{
|
|
input >> std::hex >> val[x];
|
|
}
|
|
wt.values = val;
|
|
|
|
wt.count = size;
|
|
weightMap[name] = wt;
|
|
}
|
|
|
|
return weightMap;
|
|
}
|
|
|
|
// Appends one SqueezeNet "fire" module to `network` and returns its concat layer.
//
//   input -> 1x1 squeeze (squeeze_planes) + ReLU -+-> 1x1 expand (e1x1_planes) + ReLU ----------+-> concat
//                                                 +-> 3x3 expand, pad 1 (e3x3_planes) + ReLU --+
//
// Weight keys are lname + "squeeze.weight", lname + "expand1x1.bias", and so on
// (e.g. lname "features.3." -> "features.3.squeeze.weight").
// NOTE(review): std::map::operator[] inserts an empty Weights for a missing key
// instead of failing here; only the layer asserts below would notice.
IConcatenationLayer* fire(INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor& input, std::string lname,
    int squeeze_planes, int e1x1_planes, int e3x3_planes)
{
    // kernel x kernel convolution with bias (optionally padded), then ReLU.
    auto convRelu = [&](ITensor& source, int outMaps, int kernel, int pad,
                        const char* weightKey, const char* biasKey) {
        IConvolutionLayer* conv = network->addConvolutionNd(source, outMaps, DimsHW{kernel, kernel},
                                                             weightMap[lname + weightKey], weightMap[lname + biasKey]);
        assert(conv);
        if (pad != 0)
        {
            conv->setPaddingNd(DimsHW{pad, pad});
        }
        IActivationLayer* act = network->addActivation(*conv->getOutput(0), ActivationType::kRELU);
        assert(act);
        return act;
    };

    IActivationLayer* squeezed = convRelu(input, squeeze_planes, 1, 0, "squeeze.weight", "squeeze.bias");
    ITensor& squeezeOut = *squeezed->getOutput(0);

    // Both expand branches consume the squeeze output.
    IActivationLayer* expand1 = convRelu(squeezeOut, e1x1_planes, 1, 0, "expand1x1.weight", "expand1x1.bias");
    IActivationLayer* expand3 = convRelu(squeezeOut, e3x3_planes, 3, 1, "expand3x3.weight", "expand3x3.bias");

    ITensor* branches[] = {expand1->getOutput(0), expand3->getOutput(0)};
    IConcatenationLayer* merged = network->addConcatenation(branches, 2);
    assert(merged);
    return merged;
}

// Creat the engine using only the API and not any parser.
|
|
ICudaEngine* createEngine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt)
|
|
{
|
|
INetworkDefinition* network = builder->createNetworkV2(0U);
|
|
|
|
// Create input tensor of shape { 3, INPUT_H, INPUT_W } with name INPUT_BLOB_NAME
|
|
ITensor* data = network->addInput(INPUT_BLOB_NAME, dt, Dims3{3, INPUT_H, INPUT_W});
|
|
assert(data);
|
|
|
|
std::map<std::string, Weights> weightMap = loadWeights("../squeezenet.wts");
|
|
Weights emptywts{DataType::kFLOAT, nullptr, 0};
|
|
|
|
IConvolutionLayer* conv1 = network->addConvolutionNd(*data, 64, DimsHW{3, 3}, weightMap["features.0.weight"], weightMap["features.0.bias"]);
|
|
assert(conv1);
|
|
conv1->setStrideNd(DimsHW{2, 2});
|
|
IActivationLayer* relu1 = network->addActivation(*conv1->getOutput(0), ActivationType::kRELU);
|
|
assert(relu1);
|
|
IPoolingLayer* pool1 = network->addPoolingNd(*relu1->getOutput(0), PoolingType::kMAX, DimsHW{3, 3});
|
|
assert(pool1);
|
|
pool1->setStrideNd(DimsHW{2, 2});
|
|
|
|
IConcatenationLayer* cat1 = fire(network, weightMap, *pool1->getOutput(0), "features.3.", 16, 64, 64);
|
|
cat1 = fire(network, weightMap, *cat1->getOutput(0), "features.4.", 16, 64, 64);
|
|
|
|
IPoolingLayer* pool2 = network->addPoolingNd(*cat1->getOutput(0), PoolingType::kMAX, DimsHW{3, 3});
|
|
assert(pool2);
|
|
pool2->setStrideNd(DimsHW{2, 2});
|
|
pool2->setPostPadding(DimsHW{1, 1});
|
|
|
|
cat1 = fire(network, weightMap, *pool2->getOutput(0), "features.6.", 32, 128, 128);
|
|
cat1 = fire(network, weightMap, *cat1->getOutput(0), "features.7.", 32, 128, 128);
|
|
|
|
pool2 = network->addPoolingNd(*cat1->getOutput(0), PoolingType::kMAX, DimsHW{3, 3});
|
|
assert(pool2);
|
|
pool2->setStrideNd(DimsHW{2, 2});
|
|
pool2->setPostPadding(DimsHW{1, 1});
|
|
|
|
cat1 = fire(network, weightMap, *pool2->getOutput(0), "features.9.", 48, 192, 192);
|
|
cat1 = fire(network, weightMap, *cat1->getOutput(0), "features.10.", 48, 192, 192);
|
|
cat1 = fire(network, weightMap, *cat1->getOutput(0), "features.11.", 64, 256, 256);
|
|
cat1 = fire(network, weightMap, *cat1->getOutput(0), "features.12.", 64, 256, 256);
|
|
|
|
IConvolutionLayer* conv2 = network->addConvolutionNd(*cat1->getOutput(0), 1000, DimsHW{1, 1}, weightMap["classifier.1.weight"], weightMap["classifier.1.bias"]);
|
|
assert(conv2);
|
|
IActivationLayer* relu2 = network->addActivation(*conv2->getOutput(0), ActivationType::kRELU);
|
|
assert(relu2);
|
|
IPoolingLayer* pool3 = network->addPoolingNd(*relu2->getOutput(0), PoolingType::kAVERAGE, DimsHW{14, 14});
|
|
assert(pool3);
|
|
|
|
pool3->getOutput(0)->setName(OUTPUT_BLOB_NAME);
|
|
std::cout << "set name out" << std::endl;
|
|
network->markOutput(*pool3->getOutput(0));
|
|
|
|
// Build engine
|
|
builder->setMaxBatchSize(maxBatchSize);
|
|
config->setMaxWorkspaceSize(1 << 20);
|
|
ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
|
|
std::cout << "build out" << std::endl;
|
|
|
|
// Don't need the network any more
|
|
network->destroy();
|
|
|
|
// Release host memory
|
|
for (auto& mem : weightMap)
|
|
{
|
|
free((void*) (mem.second.values));
|
|
}
|
|
|
|
return engine;
|
|
}
|
|
|
|
void APIToModel(unsigned int maxBatchSize, IHostMemory** modelStream)
|
|
{
|
|
// Create builder
|
|
IBuilder* builder = createInferBuilder(gLogger);
|
|
IBuilderConfig* config = builder->createBuilderConfig();
|
|
|
|
// Create model to populate the network, then set the outputs and create an engine
|
|
ICudaEngine* engine = createEngine(maxBatchSize, builder, config, DataType::kFLOAT);
|
|
assert(engine != nullptr);
|
|
|
|
// Serialize the engine
|
|
(*modelStream) = engine->serialize();
|
|
|
|
// Close everything down
|
|
engine->destroy();
|
|
builder->destroy();
|
|
config->destroy();
|
|
}
|
|
|
|
// Run one synchronous inference of `batchSize` images through `context`.
//
// input     : host buffer of batchSize * 3 * INPUT_H * INPUT_W floats.
// output    : host buffer that receives batchSize * OUTPUT_SIZE floats.
// batchSize : must be > 0.
// Device buffers and a CUDA stream are created for this call and released before
// it returns, so every call pays the allocation cost. Any CUDA error, TensorRT
// enqueue failure, or unexpected binding layout aborts the process.
void doInference(IExecutionContext& context, float* input, float* output, int batchSize)
{
    const ICudaEngine& engine = context.getEngine();

    // Pointers to input and output device buffers to pass to engine.
    // Engine requires exactly IEngine::getNbBindings() number of buffers.
    void* buffers[2] = {nullptr, nullptr};

    // In order to bind the buffers, we need to know the names of the input and output tensors.
    const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
    const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);

    // Explicit checks (not assert) so they still hold under NDEBUG: an unknown
    // name yields a negative index, which would index outside buffers[].
    if (engine.getNbBindings() != 2 || inputIndex < 0 || inputIndex > 1 ||
        outputIndex < 0 || outputIndex > 1 || inputIndex == outputIndex)
    {
        std::cerr << "Unexpected engine bindings: expected exactly '" << INPUT_BLOB_NAME
                  << "' and '" << OUTPUT_BLOB_NAME << "'" << std::endl;
        std::abort();
    }
    if (batchSize <= 0)
    {
        std::cerr << "Invalid batch size: " << batchSize << std::endl;
        std::abort();
    }

    // Byte counts computed in size_t so large batches cannot overflow int arithmetic.
    const size_t inputBytes = static_cast<size_t>(batchSize) * 3 * INPUT_H * INPUT_W * sizeof(float);
    const size_t outputBytes = static_cast<size_t>(batchSize) * OUTPUT_SIZE * sizeof(float);

    // Create GPU buffers on device
    CHECK(cudaMalloc(&buffers[inputIndex], inputBytes));
    CHECK(cudaMalloc(&buffers[outputIndex], outputBytes));

    // Create stream
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));

    // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
    CHECK(cudaMemcpyAsync(buffers[inputIndex], input, inputBytes, cudaMemcpyHostToDevice, stream));
    if (!context.enqueue(batchSize, buffers, stream, nullptr))
    {
        std::cerr << "TensorRT enqueue failed" << std::endl;
        std::abort();
    }
    CHECK(cudaMemcpyAsync(output, buffers[outputIndex], outputBytes, cudaMemcpyDeviceToHost, stream));
    CHECK(cudaStreamSynchronize(stream));

    // Release stream and buffers
    CHECK(cudaStreamDestroy(stream));
    CHECK(cudaFree(buffers[inputIndex]));
    CHECK(cudaFree(buffers[outputIndex]));
}

int main(int argc, char** argv)
|
|
{
|
|
if (argc != 2) {
|
|
std::cerr << "arguments not right!" << std::endl;
|
|
std::cerr << "./squeezenet -s // serialize model to plan file" << std::endl;
|
|
std::cerr << "./squeezenet -d // deserialize plan file and run inference" << std::endl;
|
|
return -1;
|
|
}
|
|
|
|
// create a model using the API directly and serialize it to a stream
|
|
char *trtModelStream{nullptr};
|
|
size_t size{0};
|
|
|
|
if (std::string(argv[1]) == "-s") {
|
|
IHostMemory* modelStream{nullptr};
|
|
APIToModel(1, &modelStream);
|
|
assert(modelStream != nullptr);
|
|
|
|
std::ofstream p("squeezenet.engine", std::ios::binary);
|
|
if (!p) {
|
|
std::cerr << "could not open plan output file" << std::endl;
|
|
return -1;
|
|
}
|
|
p.write(reinterpret_cast<const char*>(modelStream->data()), modelStream->size());
|
|
modelStream->destroy();
|
|
return 1;
|
|
} else if (std::string(argv[1]) == "-d") {
|
|
std::ifstream file("squeezenet.engine", std::ios::binary);
|
|
if (file.good()) {
|
|
file.seekg(0, file.end);
|
|
size = file.tellg();
|
|
file.seekg(0, file.beg);
|
|
trtModelStream = new char[size];
|
|
assert(trtModelStream);
|
|
file.read(trtModelStream, size);
|
|
file.close();
|
|
}
|
|
} else {
|
|
return -1;
|
|
}
|
|
|
|
|
|
static float data[3 * INPUT_H * INPUT_W];
|
|
for (int i = 0; i < 3 * INPUT_H * INPUT_W; i++)
|
|
data[i] = 1.0;
|
|
|
|
IRuntime* runtime = createInferRuntime(gLogger);
|
|
assert(runtime != nullptr);
|
|
ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size);
|
|
assert(engine != nullptr);
|
|
IExecutionContext* context = engine->createExecutionContext();
|
|
assert(context != nullptr);
|
|
delete[] trtModelStream;
|
|
|
|
// Run inference
|
|
static float prob[OUTPUT_SIZE];
|
|
for (int i = 0; i < 10; i++) {
|
|
auto start = std::chrono::system_clock::now();
|
|
doInference(*context, data, prob, 1);
|
|
auto end = std::chrono::system_clock::now();
|
|
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() << "us" << std::endl;
|
|
}
|
|
|
|
// Destroy the engine
|
|
context->destroy();
|
|
engine->destroy();
|
|
runtime->destroy();
|
|
|
|
// Print histogram of the output distribution
|
|
std::cout << "\nOutput:\n\n";
|
|
for (unsigned int i = 0; i < OUTPUT_SIZE; i++)
|
|
{
|
|
std::cout << prob[i] << ", ";
|
|
if (i % 10 == 0) std::cout << i / 10 << std::endl;
|
|
}
|
|
std::cout << std::endl;
|
|
|
|
return 0;
|
|
}
|