39 lines
1.2 KiB
C++
39 lines
1.2 KiB
C++
#ifndef __TRT_UTILS_H_
|
|
#define __TRT_UTILS_H_
|
|
|
|
#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <vector>

#include <cudnn.h>
|
|
|
|
#ifndef CUDA_CHECK
// CUDA_CHECK(callstr): evaluates `callstr` exactly once and stores the result
// as a cudaError_t. If the result is not cudaSuccess, it writes the numeric
// error code and the call site (file:line) to std::cerr, then fires assert(0).
//
// The do { ... } while (0) wrapper makes the macro expand to a single
// statement, so `if (c) CUDA_CHECK(x); else ...` parses as intended.
// The temporary has a deliberately unusual name. An argument expression that
// mentions a caller variable (e.g. `error_code`) therefore is not shadowed by,
// or self-initialized from, the macro's own local.
//
// NOTE(review): assert(0) compiles to nothing when NDEBUG is defined. In such
// builds a failure is only logged and execution continues. Confirm that this
// is intended, or switch to std::abort().
#define CUDA_CHECK(callstr)                                                   \
    do {                                                                      \
        cudaError_t cuda_check_status_ = (callstr);                           \
        if (cuda_check_status_ != cudaSuccess) {                              \
            std::cerr << "CUDA error " << cuda_check_status_ << " at "        \
                      << __FILE__ << ":" << __LINE__ << '\n';                 \
            assert(0);                                                        \
        }                                                                     \
    } while (0)
#endif
|
|
|
|
namespace Tn
{
    /// @brief Copy the raw bytes of `val` into `buffer`, then advance `buffer`
    ///        by sizeof(T).
    ///
    /// std::memcpy is used instead of `*reinterpret_cast<T*>(buffer) = val`.
    /// The cast form breaks strict aliasing and dereferences `buffer` as a T
    /// even when it is not suitably aligned for T, which is common in packed
    /// serialization streams. Both are undefined behavior.
    ///
    /// @tparam T   Type being serialized; expected to be trivially copyable,
    ///             since only its object representation is copied.
    /// @param buffer In/out cursor into the destination storage. The caller
    ///               must guarantee at least sizeof(T) writable bytes.
    /// @param val    Value to serialize.
    template<typename T>
    void write(char*& buffer, const T& val)
    {
        std::memcpy(buffer, &val, sizeof(T));
        buffer += sizeof(T);
    }

    /// @brief Copy sizeof(T) raw bytes from `buffer` into `val`, then advance
    ///        `buffer` by sizeof(T). This is the inverse of Tn::write.
    ///
    /// std::memcpy makes the read well-defined regardless of the alignment of
    /// `buffer`. It also avoids type-punning through an unrelated pointer type.
    ///
    /// @tparam T   Type being deserialized; expected to be trivially copyable.
    /// @param buffer In/out cursor into the source storage. The caller must
    ///               guarantee at least sizeof(T) readable bytes.
    /// @param val    Receives the deserialized value.
    template<typename T>
    void read(const char*& buffer, T& val)
    {
        std::memcpy(&val, buffer, sizeof(T));
        buffer += sizeof(T);
    }
}
|
|
|
|
#endif |