> 文章列表 > 一个简单的tensorRT mnist推理案例,模型采用代码构建

一个简单的tensorRT mnist推理案例,模型采用代码构建

一个简单的tensorRT mnist推理案例,模型采用代码构建

TensorRT是NVIDIA的一个深度神经网络推理引擎,可以对深度学习模型进行优化和部署。本程序先从文件中加载已经训练好的模型权重,再用TensorRT API以代码方式搭建网络、构建引擎并进行推理。

TRTLogger是一个日志记录类,用于记录TensorRT的运行日志。

Matrix是一个矩阵结构体,用于存储模型权重和输入输出数据。Model是一个模型结构体,用于存储加载的模型。
print_image函数用于将图像的像素值打印出来,方便调试和查看。load_file函数用于以二进制方式读取整个文件,程序用它加载模型权重、输入图像数据以及序列化后的引擎文件。load_model函数用于加载模型权重,其中模型权重的文件名按照"[index].weight"的格式命名,index从0开始递增。模型权重的形状是预先定义好的,存储在weight_shapes数组中:weight_shapes[i][0]表示第i个权重文件(依次为第1层权重、第1层偏置、第2层权重、第2层偏置)的行数,weight_shapes[i][1]表示其列数。

这些函数都是为了方便程序的编写和调试,可以根据具体的应用场景进行修改和扩展。
此外,程序还包括将BMP格式的图像数据转换为适合输入神经网络的归一化矩阵的函数(bmp_data_to_normalize_matrix),以及将模型权重转换为TensorRT所需的nvinfer1::Weights格式的函数(model_weights_to_trt_weights)。

do_trt_build_engine函数使用TensorRT API逐层搭建神经网络(全连接→ReLU→全连接→Sigmoid),构建出推理引擎,然后将引擎序列化并保存到文件中。

do_trt_inference函数从文件中加载序列化的引擎,然后使用引擎在一组输入图像上执行推理。对于每个输入图像,它将BMP数据转换为矩阵,将矩阵复制到GPU,使用引擎进行推理,然后将输出概率值复制回CPU以供显示。
main函数首先调用load_model函数加载训练好的模型权重,并打印出每个权重矩阵的大小。

接下来,它调用do_trt_build_engine函数将模型转换为TensorRT引擎,并将引擎保存到文件mnist.trtmodel中。

最后,它调用do_trt_inference函数对一组输入图像执行推理,并显示每个图像的预测结果和置信度。

在推理完成后,它打印出一条消息表示程序运行完成,并返回0表示程序正常退出。

// tensorRT include
#include <NvInfer.h>
#include <NvInferRuntime.h>

// cuda include
#include <cuda_runtime.h>

// system include
#include <stdio.h>
#include <string.h>
#include <math.h>

#include <vector>
#include <string>
#include <fstream>
#include <algorithm>
#include <memory>
#include <algorithm>using namespace std;#define SIMLOG(type, ...)                        \\do{                                          \\printf("[%s:%d]%s: ", __FILE__, __LINE__, type); \\printf(__VA_ARGS__);                     \\printf("\\n");                            \\}while(0)#define INFO(...)   SIMLOG("info", __VA_ARGS__)inline const char* severity_string(nvinfer1::ILogger::Severity t){switch(t){case nvinfer1::ILogger::Severity::kINTERNAL_ERROR: return "internal_error";case nvinfer1::ILogger::Severity::kERROR:   return "error";case nvinfer1::ILogger::Severity::kWARNING: return "warning";case nvinfer1::ILogger::Severity::kINFO:    return "info";case nvinfer1::ILogger::Severity::kVERBOSE: return "verbose";default: return "unknow";}
}class TRTLogger : public nvinfer1::ILogger{
public:virtual void log(Severity severity, nvinfer1::AsciiChar const* msg) noexcept override{if(severity <= Severity::kINFO){SIMLOG(severity_string(severity), "%s", msg);}}
};struct Matrix{vector<float> data;int rows = 0, cols = 0;void resize(int rows, int cols){this->rows = rows;this->cols = cols;this->data.resize(rows * cols * sizeof(float));}bool empty() const{return data.empty();}int size() const{ return rows * cols; }float* ptr() const{return (float*)this->data.data();}
};struct Model{vector<Matrix> weights;
};void print_image(const vector<unsigned char>& a, int rows, int cols, const char* format = "%3d"){INFO("Matrix[%p], %d x %d", &a, rows, cols);char fmt[20];sprintf(fmt, "%s,", format);for(int i = 0; i < rows; ++i){printf("row[%02d]: ", i);for(int j = 0; j < cols; ++j){int index = (rows - i - 1) * cols + j;printf(fmt, a.data()[index * 3 + 0]);}printf("\\n");}
}vector<unsigned char> load_file(const string& file){ifstream in(file, ios::in | ios::binary);if (!in.is_open())return {};in.seekg(0, ios::end);size_t length = in.tellg();std::vector<uint8_t> data;if (length > 0){in.seekg(0, ios::beg);data.resize(length);in.read((char*)&data[0], length);}in.close();return data;
}bool load_model(Model& model){model.weights.resize(4);const int weight_shapes[][2] = {{1024, 784},{1024, 1},{10, 1024},{10, 1}};for(int i = 0; i < model.weights.size(); ++i){char weight_name[100];sprintf(weight_name, "%d.weight", i);auto data = load_file(weight_name);if(data.empty()){INFO("Load %s failed.", weight_name);return false;}auto& w = model.weights[i];int rows = weight_shapes[i][0];int cols = weight_shapes[i][1];if(data.size() != rows * cols * sizeof(float)){INFO("Invalid weight file: %s", weight_name);return false;}w.resize(rows, cols);memcpy(w.ptr(), data.data(), data.size());}return true;
}Matrix bmp_data_to_normalize_matrix(const vector<unsigned char>& data){Matrix output;const int std_w = 28;const int std_h = 28;if(data.size() != std_w * std_h * 3){INFO("Invalid bmp file, must be %d x %d @ rgb 3 channels image", std_w, std_h);return output;}output.resize(1, std_w * std_h);const unsigned char* begin_ptr = data.data();float* output_ptr = output.ptr();for(int i = 0; i < std_h; ++i){const unsigned char* image_row_ptr = begin_ptr + (std_h - i - 1) * std_w * 3;float* output_row_ptr = output_ptr + i * std_w;for(int j = 0; j < std_w; ++j){// normalizeoutput_row_ptr[j] = (image_row_ptr[j * 3 + 0] / 255.0f - 0.1307f) / 0.3081f;;}}return output;
}nvinfer1::Weights model_weights_to_trt_weights(const Matrix& model_weights){nvinfer1::Weights output;output.type = nvinfer1::DataType::kFLOAT;output.values = model_weights.ptr();output.count = model_weights.size();return output;
}TRTLogger logger;
void do_trt_build_engine(const Model& model, const string& save_file){/*Network is:image|linear (fully connected)  input = 784, output = 1024, bias = True|relu|linear (fully connected)  input = 1024, output = 10, bias = True|sigmoid|prob*/nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();nvinfer1::INetworkDefinition* network = builder->createNetworkV2(1);nvinfer1::ITensor* input = network->addInput("image", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4(1, 784, 1, 1));nvinfer1::Weights layer1_weight = model_weights_to_trt_weights(model.weights[0]);nvinfer1::Weights layer1_bias = model_weights_to_trt_weights(model.weights[1]);auto layer1 = network->addFullyConnected(*input, model.weights[0].rows, layer1_weight, layer1_bias);auto relu1 = network->addActivation(*layer1->getOutput(0), nvinfer1::ActivationType::kRELU);nvinfer1::Weights layer2_weight = model_weights_to_trt_weights(model.weights[2]);nvinfer1::Weights layer2_bias = model_weights_to_trt_weights(model.weights[3]);auto layer2 = network->addFullyConnected(*relu1->getOutput(0), model.weights[2].rows, layer2_weight, layer2_bias);auto prob = network->addActivation(*layer2->getOutput(0), nvinfer1::ActivationType::kSIGMOID);network->markOutput(*prob->getOutput(0));config->setMaxWorkspaceSize(1 << 28);builder->setMaxBatchSize(1);nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);if(engine == nullptr){INFO("Build engine failed.");return;}nvinfer1::IHostMemory* model_data = engine->serialize();ofstream outf(save_file, ios::binary | ios::out);if(outf.is_open()){outf.write((const char*)model_data->data(), model_data->size());outf.close();}else{INFO("Open %s failed", save_file.c_str());}model_data->destroy();engine->destroy();network->destroy();config->destroy();builder->destroy();
}void do_trt_inference(const string& model_file){auto engine_data = load_file(model_file);if(engine_data.empty()){INFO("engine_data is empty");return;}nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engine_data.data(), engine_data.size());if(engine == nullptr){INFO("Deserialize cuda engine failed.");return;}nvinfer1::IExecutionContext* execution_context = engine->createExecutionContext();cudaStream_t stream = nullptr;cudaStreamCreate(&stream);const char* image_list[] = {"5.bmp", "6.bmp"};int num_image = sizeof(image_list) / sizeof(image_list[0]);const int num_classes = 10;for(int i = 0; i < num_image; ++i){const int bmp_file_head_size = 54;auto file_name  = image_list[i];auto image_data = load_file(file_name);if(image_data.empty() || image_data.size() != bmp_file_head_size + 28*28*3){INFO("Load image failed: %s", file_name);continue;}image_data.erase(image_data.begin(), image_data.begin() + bmp_file_head_size);auto image = bmp_data_to_normalize_matrix(image_data);float* image_device_ptr = nullptr;cudaMalloc(&image_device_ptr, image.size() * sizeof(float));cudaMemcpyAsync(image_device_ptr, image.ptr(), image.size() * sizeof(float), cudaMemcpyHostToDevice, stream);float* output_device_ptr = nullptr;cudaMalloc(&output_device_ptr, num_classes * sizeof(float));float* bindings[] = {image_device_ptr, output_device_ptr};bool success      = execution_context->enqueueV2((void**)bindings, stream, nullptr);float predict_proba[num_classes];cudaMemcpyAsync(predict_proba, output_device_ptr, num_classes * sizeof(float), cudaMemcpyDeviceToHost, stream);cudaStreamSynchronize(stream);// release memorycudaFree(image_device_ptr);cudaFree(output_device_ptr);int predict_label  = std::max_element(predict_proba, predict_proba + num_classes) - predict_proba;float predict_prob = predict_proba[predict_label];print_image(image_data, 28, 28);INFO("image matrix: %d x %d", image.rows, image.cols);INFO("%s predict: 
%d, confidence: %f", file_name, predict_label, predict_prob);printf("Press 'Enter' to next, Press 'q' to quit: ");int c = getchar();if(c == 'q')break;}INFO("Clean memory");cudaStreamDestroy(stream);execution_context->destroy();engine->destroy();runtime->destroy();
}   int main(){Model model;if(!load_model(model))return 0;for(int i = 0; i < model.weights.size(); ++i){INFO("weight.%d shape = %d x %d", i, model.weights[i].rows, model.weights[i].cols);}auto trtmodel = "mnist.trtmodel";do_trt_build_engine(model, trtmodel);do_trt_inference(trtmodel);INFO("done.");return 0;
}