onnx模型c++ onnxruntime推理实例
最近在学习onnx模型c++推理相关的内容。本例中已经训练了一款基于BERT的实体识别模型,分类总共有24种,pytorch模型也已经被转成onnx模型,接下来是用C++编写推理代码;其中输入数据是用python tokenizer生成的,包括token_ids、token_type_ids和attention_mask等;编译完成之后在终端中运行:./your_api ./your_onnx_model.onnx
·
最近在学习onnx模型c++推理相关的内容。本例中已经训练了一款基于BERT的实体识别模型,分类总共有24种,pytorch模型也已经被转成onnx模型,接下来是用C++编写推理代码;其中输入数据是用python tokenizer生成的,包括token_ids、token_type_ids和attention_mask等。
代码包括分类模型的定义,参数初始化,数据生成,预测以及输出后处理等;onnxruntime版本是19.0;
编译完成之后,在终端中运行(第一个参数为onnx模型路径):./your_api ./your_onnx_model.onnx
#include <algorithm>
#include <cstdint>
#include <exception>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h"
#include "cpu_provider_factory.h"
using namespace std;
class BERTCLASSIFIER
{
public:
BERTCLASSIFIER(const char* model_path){
session = Ort::Session(env, model_path, session_option);
}
void createInput(std::vector<int64_t> &token_ids,
std::vector<int64_t> &token_type_ids,
std::vector<int64_t> &attention_mask,
std::vector<int64_t> &input_shape){
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
input_tensors_.push_back(Ort::Value::CreateTensor<int64_t>(memory_info,
token_ids.data(),
token_ids.size(),
input_shape.data(),
input_shape.size()));
input_tensors_.push_back(Ort::Value::CreateTensor<int64_t>(memory_info,
token_type_ids.data(),
token_type_ids.size(),
input_shape.data(),
input_shape.size()));
input_tensors_.push_back(Ort::Value::CreateTensor<int64_t>(memory_info,
attention_mask.data(),
attention_mask.size(),
input_shape.data(),
input_shape.size()));
}
std::vector<Ort::Value> predict(){
num_input_nodes = session.GetInputCount();
num_output_nodes = session.GetOutputCount();
for (size_t i = 0; i < num_input_nodes; i++) {
auto in_name = session.GetInputName(i, allocator);
const char* in_name_cstr = in_name; // 获取字符串指针
std::cout << "Input Name: " << in_name_cstr << std::endl;
input_names.push_back(in_name);
}
for (size_t i = 0; i < num_output_nodes; i++) {
auto out_name = session.GetOutputName(i, allocator);
const char* out_name_cstr = out_name; // 获取字符串指针
std::cout << "Output Name: " << out_name_cstr << std::endl;
output_names.push_back(out_name);
}
// get the info of the input tensor before inference
auto tensorInfo = input_tensors_[0].GetTensorTypeAndShapeInfo();
int64_t* data = input_tensors_[0].GetTensorMutableData<int64_t>();
std::vector<int64_t> shape = tensorInfo.GetShape();
int seqLength = shape[1];
std::cout << "Input_Length: " << seqLength << std::endl;
std::cout<< "Input: [";
for(size_t i=0; i<seqLength; i++){
std::cout << data[i] << " ";
}
std::cout<< "]" << endl;
auto out_of_model = session.Run(
Ort::RunOptions{ nullptr },
input_names.data(),
input_tensors_.data(),
input_tensors_.size(),
output_names.data(),
output_names.size());
return out_of_model;
}
private:
Ort::SessionOptions session_option;
Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
Ort::Session session {nullptr};
Ort::AllocatorWithDefaultOptions allocator;
std::vector<Ort::Value> input_tensors_;
std::vector<Ort::Value> output_tensors_;
size_t num_input_nodes;
size_t num_output_nodes;
std::vector<const char*> input_names = {};
std::vector<const char*> output_names = {};
};
/// Post-processes the model output: for every sequence position, returns the
/// index of the class with the highest score (argmax over the last axis).
///
/// @param outputTensor float tensor indexed as [batch, sequence, classes].
/// @return one class index per sequence position. Only the first batch entry is
///         processed; an empty vector is returned if any dimension is not positive.
/// @throws std::invalid_argument if the tensor is not a rank-3 float tensor.
std::vector<int64_t> processModelOutput(Ort::Value& outputTensor) {
    auto tensorInfo = outputTensor.GetTensorTypeAndShapeInfo();
    if (tensorInfo.GetElementType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
        throw std::invalid_argument("processModelOutput: expected a float tensor");
    }
    const std::vector<int64_t> shape = tensorInfo.GetShape();
    if (shape.size() != 3) {
        throw std::invalid_argument("processModelOutput: expected a rank-3 tensor "
                                    "[batch, sequence, classes]");
    }

    const int64_t seqLength = shape[1];
    const int64_t numClasses = shape[2];
    if (shape[0] <= 0 || seqLength <= 0 || numClasses <= 0) {
        return {};  // nothing to classify; avoids reading an empty buffer
    }

    const float* data = outputTensor.GetTensorMutableData<float>();
    std::vector<int64_t> result(static_cast<size_t>(seqLength));
    for (int64_t i = 0; i < seqLength; ++i) {
        const float* row = data + i * numClasses;
        // max_element returns the first maximum, so ties resolve to the lowest index.
        result[static_cast<size_t>(i)] = std::max_element(row, row + numClasses) - row;
    }
    return result;
}
/// Demo entry point: loads the ONNX model given as the first command-line
/// argument, runs it on one hard-coded tokenized sentence and prints the
/// predicted class index for every token.
///
/// @return 0 on success, 1 on a missing argument or an inference error.
int main(int argc, char* argv[]){
    if (argc < 2) {
        std::cerr << "Usage: " << (argc > 0 ? argv[0] : "your_api")
                  << " <path/to/model.onnx>" << std::endl;
        return 1;
    }
    const char* model_path = argv[1];  // e.g. "./model_dir/my_onnx_model.onnx"

    try {
        BERTCLASSIFIER bertnet(model_path);

        // Sample tokenizer output for a single sentence of 12 tokens.
        std::vector<int64_t> input_shape = {1, 12};
        std::vector<int64_t> token_ids = {2, 21, 19, 26, 8421, 1725, 8531, 12303, 1725, 8274, 18, 3};
        std::vector<int64_t> token_type_ids = {0,0,0,0,0,0,0,0,0,0,0,0};
        std::vector<int64_t> attention_mask = {1,1,1,1,1,1,1,1,1,1,1,1};
        bertnet.createInput(token_ids,
                            token_type_ids,
                            attention_mask,
                            input_shape);

        auto model_out = bertnet.predict();
        if (model_out.empty()) {
            std::cerr << "Model produced no outputs" << std::endl;
            return 1;
        }

        const auto labels = processModelOutput(model_out[0]);
        std::cout << "Output_Length: " << labels.size() << std::endl;
        std::cout << "Output: [";
        for (const auto label : labels) {
            std::cout << label << " ";
        }
        std::cout << "]" << std::endl;
    } catch (const std::exception& e) {
        // Report ONNX Runtime and validation failures instead of terminating.
        std::cerr << "Inference failed: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
//std::vector<int64_t> input_shape = {1, 12};
//std::vector<int64_t> token_ids = {2, 21, 19, 26, 8421, 1725, 8531, 12303, 1725, 8274, 18, 3};
//std::vector<int64_t> token_type_ids = {0,0,0,0,0,0,0,0,0,0,0,0};
//std::vector<int64_t> attention_mask = {1,1,1,1,1,1,1,1,1,1,1,1};
// Input Name: input_ids
// Input Name: token_type_ids
// Input Name: attention_mask
// Output Name: res
// Input_Length: 12
// Input: [2 21 19 26 8421 1725 8531 12303 1725 8274 18 3 ]
// Output_Length: 12
// Output: [4 8 3 3 4 4 4 4 4 4 5 4 ]
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐



所有评论(0)