yolo_standard_libray/tensorrtx-master/scaled-yolov4/mish.cu
2025-03-07 11:35:40 +08:00

197 lines
5.6 KiB
Plaintext

#include <cassert>
#include <cmath>
#include <cstring>
#include <iostream>
#include <stdio.h>
#include "mish.h"
namespace nvinfer1
{
// Default constructor. The body is empty, so members receive only whatever
// initialization their own declarations provide.
MishPlugin::MishPlugin() {}
// Destructor. The body is empty; nothing is released explicitly here.
MishPlugin::~MishPlugin() {}
// Rebuilds a plugin at runtime from a serialized byte stream.
//
// data   : buffer whose first bytes hold one int, the stored input_size_.
// length : size of `data` in bytes; must equal sizeof(input_size_). This is
//          checked with assert, so it is only enforced when NDEBUG is not set.
MishPlugin::MishPlugin(const void* data, size_t length)
{
    assert(length == sizeof(input_size_));
    // Copy the bytes out rather than dereferencing a cast pointer: nothing
    // here guarantees that `data` is suitably aligned for an int, and a
    // misaligned int load is undefined behaviour.
    int stored_size = 0;
    std::memcpy(&stored_size, data, sizeof(stored_size));
    input_size_ = stored_size;
}
// Writes the plugin state into `buffer`.
//
// The only value stored is input_size_, written as a single int at the start
// of the buffer. The caller must supply at least sizeof(int) writable bytes.
void MishPlugin::serialize(void* buffer) const
{
    // memcpy instead of a store through a cast pointer: `buffer` carries no
    // alignment guarantee, and a misaligned int store is undefined behaviour.
    const int stored_size = input_size_;
    std::memcpy(buffer, &stored_size, sizeof(stored_size));
}
// Returns sizeof(input_size_), the byte count of the single field this
// plugin serializes.
size_t MishPlugin::getSerializationSize() const
{
    const size_t num_bytes = sizeof(input_size_);
    return num_bytes;
}
// Initialization hook. No resources are set up here; the function
// unconditionally returns 0.
int MishPlugin::initialize()
{
    const int status = 0;
    return status;
}
// Computes the shape of output `index` and caches the per-sample element count.
//
// Exactly one input and output index 0 are expected (both checked with assert).
// The returned shape copies the first three extents of inputs[0]; as a side
// effect, their product is stored in input_size_.
Dims MishPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
{
    assert(nbInputDims == 1);
    assert(index == 0);
    const Dims& in_dims = inputs[0];
    input_size_ = in_dims.d[0] * in_dims.d[1] * in_dims.d[2];
    return Dims3(in_dims.d[0], in_dims.d[1], in_dims.d[2]);
}
// Records the namespace this plugin belongs to by assigning the argument to
// mPluginNamespace.
// NOTE(review): if mPluginNamespace is a raw const char*, only the pointer is
// kept, so the caller's string must outlive this plugin — confirm in mish.h.
void MishPlugin::setPluginNamespace(const char* pluginNamespace)
{
    this->mPluginNamespace = pluginNamespace;
}
// Returns the namespace previously stored in mPluginNamespace.
const char* MishPlugin::getPluginNamespace() const
{
    return this->mPluginNamespace;
}
// Reports the data type of the plugin output at the requested index.
// This implementation ignores all of its arguments and always answers
// DataType::kFLOAT.
DataType MishPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
{
    (void)index;
    (void)inputTypes;
    (void)nbInputs;
    return DataType::kFLOAT;
}
// Answers whether the output tensor is broadcast across a batch.
// The arguments are not consulted; the answer is always false.
bool MishPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
{
    (void)outputIndex;
    (void)inputIsBroadcasted;
    (void)nbInputs;
    return false;
}
// Answers whether the plugin can consume an input that is broadcast across
// the batch without replication. The index is not consulted; the answer is
// always false.
bool MishPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
{
    (void)inputIndex;
    return false;
}
// Configuration hook receiving the input and output tensor descriptions.
// The body is empty: none of the descriptors are inspected or stored.
void MishPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
{
    (void)in;
    (void)nbInput;
    (void)out;
    (void)nbOutput;
}
// Hook for attaching the plugin to an execution context and handing it access
// to context resources (cuDNN handle, cuBLAS handle, GPU allocator).
// The body is empty: none of the supplied handles are stored or used.
void MishPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) {}
// Hook for detaching the plugin from its execution context.
// The body is empty, so there is nothing to undo here.
void MishPlugin::detachFromContext()
{
}
// Returns the plugin's type name as a string literal.
const char* MishPlugin::getPluginType() const
{
    const char* type_name = "Mish_TRT";
    return type_name;
}
// Returns the plugin's version as a string literal.
const char* MishPlugin::getPluginVersion() const
{
    const char* version = "1";
    return version;
}
// Deletes this plugin object. The object must not be used after this call.
void MishPlugin::destroy()
{
    delete this;
}
// Creates a heap-allocated copy of this plugin.
// The copy is default-constructed and then receives this object's namespace
// and input_size_. The caller receives the raw pointer returned by `new`.
IPluginV2IOExt* MishPlugin::clone() const
{
    MishPlugin* copy = new MishPlugin();
    copy->setPluginNamespace(mPluginNamespace);
    copy->input_size_ = input_size_;
    return copy;
}
// Device helper: evaluates 2 / (1 + exp(-2x)) - 1 in single precision,
// which is the logistic form of tanh(x).
__device__ float tanh_activate_kernel(float x)
{
    const float denom = 1.0f + expf(-2.0f * x);
    return 2.0f / denom - 1.0f;
}
// Device helper: softplus, log(exp(x) + 1), with shortcuts at the extremes.
//
// x         : input value.
// threshold : cut-off (default 20). Above +threshold the result is x itself;
//             below -threshold the result is exp(x).
__device__ float softplus_kernel(float x, float threshold = 20) {
    if (x > threshold) {
        // Too large: return the input unchanged.
        return x;
    }
    if (x < -threshold) {
        // Too small: return exp(x) without the log.
        return expf(x);
    }
    return logf(expf(x) + 1.0f);
}
// Element-wise Mish activation: output[i] = x * tanh(softplus(x)), x = input[i].
//
// Each thread handles at most one element, indexed from the x components of
// blockIdx/blockDim/threadIdx only. Threads whose index is >= num_elem do
// nothing, so the grid may overshoot the element count.
//
// input    : source values.
// output   : destination for the results, same indexing as `input`.
// num_elem : number of elements to process.
__global__ void mish_kernel(const float *input, float *output, int num_elem) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_elem) {
        const float x = input[i];
        output[i] = x * tanh_activate_kernel(softplus_kernel(x));
    }
}
// Launches the Mish kernel over the whole batch.
//
// inputs    : array of input pointers; only inputs[0] is read.
// output    : destination buffer for the activated values.
// stream    : CUDA stream the kernel is enqueued on.
// batchSize : number of samples; input_size_ * batchSize elements are processed.
//
// The launch is asynchronous. This function does not synchronize or check for
// launch errors itself.
void MishPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
    const int total_elems = input_size_ * batchSize;
    if (total_elems <= 0) {
        // Nothing to do, and a grid of zero blocks is an invalid launch
        // configuration.
        return;
    }
    const int block_size = thread_count_;
    const int grid_size = (total_elems + block_size - 1) / block_size; // ceil-div
    // Enqueue on the caller's stream (no dynamic shared memory) so the kernel
    // is ordered with the other work the caller placed on that stream.
    mish_kernel<<<grid_size, block_size, 0, stream>>>(inputs[0], output, total_elems);
}
// Runs the plugin for one batch.
//
// batchSize : number of samples in the batch.
// inputs    : input buffers, reinterpreted as float pointers.
// outputs   : output buffers; outputs[0] is written as float.
// workspace : unused.
// stream    : CUDA stream, forwarded to forwardGpu.
//
// Returns 0 on success, or 1 if the CUDA runtime reports an error after the
// launch. No stream synchronization is performed here.
int MishPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
{
    (void)workspace;
    forwardGpu(reinterpret_cast<const float* const*>(inputs), static_cast<float*>(outputs[0]), stream, batchSize);
    // A kernel launch does not return a status itself; launch failures are
    // reported through cudaGetLastError().
    const cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
        fprintf(stderr, "MishPlugin::enqueue: CUDA error: %s\n", cudaGetErrorString(status));
        return 1;
    }
    return 0;
}
// Storage for MishPluginCreator's static members.
// mFC is the field collection handed out by getFieldNames(); it is
// value-initialized here and filled in by the MishPluginCreator constructor.
PluginFieldCollection MishPluginCreator::mFC{};
// Backing vector for mFC.fields; starts empty.
std::vector<PluginField> MishPluginCreator::mPluginAttributes;
// Constructor: empties the shared attribute list and points the shared field
// collection (mFC) at it, so mFC describes zero fields afterwards.
MishPluginCreator::MishPluginCreator()
{
    mPluginAttributes.clear();
    mFC.fields = mPluginAttributes.data();
    mFC.nbFields = mPluginAttributes.size();
}
// Returns the name of the plugin this creator produces, as a string literal.
const char* MishPluginCreator::getPluginName() const
{
    const char* plugin_name = "Mish_TRT";
    return plugin_name;
}
// Returns the version of the plugin this creator produces, as a string literal.
const char* MishPluginCreator::getPluginVersion() const
{
    const char* version = "1";
    return version;
}
// Returns the address of the static field collection mFC.
const PluginFieldCollection* MishPluginCreator::getFieldNames()
{
    const PluginFieldCollection* fields = &mFC;
    return fields;
}
// Creates a new, default-constructed MishPlugin and tags it with this
// creator's namespace. `name` and `fc` are not used.
// The caller receives the raw pointer returned by `new`.
IPluginV2IOExt* MishPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
{
    (void)name;
    (void)fc;
    MishPlugin* plugin = new MishPlugin();
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
// Recreates a MishPlugin from serialized bytes and tags it with this creator's
// namespace. `name` is not used.
//
// serialData   : serialized plugin state, passed to the MishPlugin constructor.
// serialLength : size of serialData in bytes.
//
// The returned object is heap-allocated. Per the original note, it is expected
// to be deleted when the network is destroyed, via MishPlugin::destroy().
IPluginV2IOExt* MishPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
{
    (void)name;
    MishPlugin* plugin = new MishPlugin(serialData, serialLength);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
}