最近在学习 ONNX 模型的 C++ 推理。我已经训练好一个基于 BERT 的实体识别(NER)模型,共有 24 个类别,PyTorch 模型也已转换为 ONNX 格式,接下来用 C++ 编写推理代码。其中输入数据由 Python tokenizer 生成,包括 token_ids、token_type_ids 和 attention_mask 等。

代码包括分类模型的定义、参数初始化、数据生成、预测以及输出后处理等;onnxruntime 版本是 19.0。

编译完成之后,在终端(terminal)中直接运行编译出的可执行文件:./your_api ./your_onnx_model.onnx(不要用 bash 去执行它,bash 只能解释脚本,不能运行二进制文件)


#include <algorithm>
#include <cstdint>
#include <exception>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

#include "onnxruntime_cxx_api.h"
#include "cpu_provider_factory.h"

using namespace std;


class BERTCLASSIFIER
{

public:

    BERTCLASSIFIER(const char* model_path){
        
        session = Ort::Session(env, model_path, session_option);
        
    }

    void createInput(std::vector<int64_t> &token_ids, 
                     std::vector<int64_t> &token_type_ids,
                     std::vector<int64_t> &attention_mask,
                     std::vector<int64_t> &input_shape){

        auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
        
        input_tensors_.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, 
                                                                token_ids.data(), 
                                                                token_ids.size(), 
                                                                input_shape.data(), 
                                                                input_shape.size()));
        input_tensors_.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, 
                                                                token_type_ids.data(), 
                                                                token_type_ids.size(), 
                                                                input_shape.data(), 
                                                                input_shape.size()));
        input_tensors_.push_back(Ort::Value::CreateTensor<int64_t>(memory_info, 
                                                                attention_mask.data(), 
                                                                attention_mask.size(), 
                                                                input_shape.data(), 
                                                                input_shape.size()));

    }

    std::vector<Ort::Value> predict(){

        num_input_nodes = session.GetInputCount();
        num_output_nodes = session.GetOutputCount();

        for (size_t i = 0; i < num_input_nodes; i++) {
            auto in_name = session.GetInputName(i, allocator);
            const char* in_name_cstr = in_name; // 获取字符串指针
            std::cout << "Input Name: " << in_name_cstr << std::endl;
            input_names.push_back(in_name);
        }

        for (size_t i = 0; i < num_output_nodes; i++) {
            auto out_name = session.GetOutputName(i, allocator);
            const char* out_name_cstr = out_name; // 获取字符串指针
            std::cout << "Output Name: " << out_name_cstr << std::endl;
            output_names.push_back(out_name);
        }


        // get the info of the input tensor before inference 
        auto tensorInfo = input_tensors_[0].GetTensorTypeAndShapeInfo();
        int64_t* data = input_tensors_[0].GetTensorMutableData<int64_t>();
        std::vector<int64_t> shape = tensorInfo.GetShape();
        int seqLength = shape[1];
        std::cout << "Input_Length: " << seqLength << std::endl;
        std::cout<< "Input: [";
        for(size_t i=0; i<seqLength; i++){
            std::cout << data[i] << " "; 
        }
        std::cout<< "]" << endl;

        auto out_of_model = session.Run(
                            Ort::RunOptions{ nullptr }, 
                            input_names.data(),
                            input_tensors_.data(),
                            input_tensors_.size(),
                            output_names.data(),
                            output_names.size());
        
        return out_of_model;
    }


private:

    Ort::SessionOptions session_option;
    Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};

    Ort::Session session {nullptr};
    Ort::AllocatorWithDefaultOptions allocator;

    std::vector<Ort::Value> input_tensors_;
    std::vector<Ort::Value> output_tensors_;

    size_t num_input_nodes;
    size_t num_output_nodes;

    std::vector<const char*> input_names =  {}; 
    std::vector<const char*> output_names = {}; 
    
};

// Post-processing: turn the model's per-token class scores into predicted
// label ids (argmax over the class axis).

/// Converts a [batch, seq_len, num_classes] float score tensor into one
/// predicted class index per token.
///
/// @param outputTensor  model output; must be a rank-3 float tensor.
/// @return batch * seq_len class indices in row-major (batch, token) order.
///         For a batch size of 1 (as in main) this is one label per token.
/// @throws std::invalid_argument if the tensor is not float, not rank 3, or
///         has an empty class axis.
std::vector<int64_t> processModelOutput(Ort::Value& outputTensor) {
    auto tensorInfo = outputTensor.GetTensorTypeAndShapeInfo();

    if (tensorInfo.GetElementType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
        throw std::invalid_argument("processModelOutput: expected a float output tensor");
    }

    const std::vector<int64_t> shape = tensorInfo.GetShape();
    if (shape.size() != 3) {
        throw std::invalid_argument(
            "processModelOutput: expected a rank-3 [batch, seq_len, num_classes] tensor");
    }

    const int64_t batchSize = shape[0];
    const int64_t seqLength = shape[1];
    const int64_t numClasses = shape[2];
    if (numClasses <= 0) {
        throw std::invalid_argument("processModelOutput: class axis is empty");
    }

    const float* data = outputTensor.GetTensorMutableData<float>();
    const int64_t numTokens = batchSize * seqLength;

    std::vector<int64_t> result;
    result.reserve(static_cast<size_t>(numTokens));

    for (int64_t token = 0; token < numTokens; ++token) {
        const float* row = data + token * numClasses;
        // max_element returns the FIRST maximum, i.e. ties resolve to the
        // lowest class index (same as a strict '>' scan).
        result.push_back(std::distance(row, std::max_element(row, row + numClasses)));
    }

    return result;
}




/// Demo entry point: loads the ONNX model given as the first command-line
/// argument, runs it on one pre-tokenized example sentence and prints the
/// predicted label id for every token.
///
/// Usage: ./your_api ./your_onnx_model.onnx
/// @return 0 on success, 1 on bad usage or any inference error.
int main(int argc, char* argv[]){

    if (argc < 2) {
        std::cerr << "Usage: " << (argc > 0 ? argv[0] : "your_api")
                  << " <model.onnx>" << std::endl;
        return 1;
    }
    const char* model_path = argv[1];     // e.g. "./model_dir/my_onnx_model.onnx"

    // Example encoding produced by the Python tokenizer. These buffers are
    // declared BEFORE the classifier because its input tensors borrow them,
    // so they must outlive it.
    std::vector<int64_t> token_ids =  {2, 21, 19, 26, 8421, 1725, 8531, 12303, 1725, 8274, 18, 3};
    std::vector<int64_t> token_type_ids = {0,0,0,0,0,0,0,0,0,0,0,0};
    std::vector<int64_t> attention_mask = {1,1,1,1,1,1,1,1,1,1,1,1};
    std::vector<int64_t> input_shape = {1, static_cast<int64_t>(token_ids.size())};

    try {
        BERTCLASSIFIER bertnet(model_path);

        bertnet.createInput(token_ids,
                            token_type_ids,
                            attention_mask,
                            input_shape);

        auto model_out = bertnet.predict();
        if (model_out.empty()) {
            std::cerr << "Model produced no outputs" << std::endl;
            return 1;
        }

        const std::vector<int64_t> labels = processModelOutput(model_out[0]);

        std::cout << "Output_Length: " << labels.size() << '\n';
        std::cout << "Output: [";
        for (const int64_t label : labels) {
            std::cout << label << " ";
        }
        std::cout << "]" << std::endl;
    } catch (const std::exception& e) {
        // Ort::Exception derives from std::exception, so load/run failures land here.
        std::cerr << "Inference failed: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}



//std::vector<int64_t> input_shape = {1, 12};
//std::vector<int64_t> token_ids =  {2, 21, 19, 26, 8421, 1725, 8531, 12303, 1725, 8274, 18, 3}; 
//std::vector<int64_t> token_type_ids = {0,0,0,0,0,0,0,0,0,0,0,0};
//std::vector<int64_t> attention_mask = {1,1,1,1,1,1,1,1,1,1,1,1};

// Input Name: input_ids
// Input Name: token_type_ids
// Input Name: attention_mask
// Output Name: res
// Input_Length: 12
// Input: [2 21 19 26 8421 1725 8531 12303 1725 8274 18 3 ]
// Output_Length: 12
// Output: [4 8 3 3 4 4 4 4 4 4 5 4 ]

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐