!8402 refactor mindir loading

From: @wangnan39
Reviewed-by: @kingxian,@guoqi1024
Signed-off-by: @kingxian
This commit is contained in:
mindspore-ci-bot 2020-12-01 15:38:55 +08:00 committed by Gitee
commit 6ecf200b49
35 changed files with 1430 additions and 1110 deletions

View File

@ -125,7 +125,8 @@ if (ENABLE_DUMP_PROTO)
"utils/lineage.proto"
"utils/checkpoint.proto"
"utils/print.proto"
"utils/node_strategy.proto"
"utils/node_strategy.proto"
"utils/mind_ir.proto"
)
ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY})
@ -156,7 +157,7 @@ endif()
## make sub objects
set(SUB_COMP
transform/graph_ir
transform/onnx
transform/express_ir
backend/optimizer
backend/kernel_compiler
backend/session
@ -344,13 +345,13 @@ if (ENABLE_MINDDATA)
endif ()
# build inference
set(LOAD_ONNX_SRC
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_converter.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc
set(LOAD_MINDIR_SRC
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc
)
add_library(inference SHARED
${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc
${LOAD_ONNX_SRC}
${LOAD_MINDIR_SRC}
)
set_target_properties(inference PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})

View File

@ -20,11 +20,11 @@
#include <fstream>
#include "include/inference.h"
#include "utils/load_onnx/anf_converter.h"
#include "backend/session/session_basic.h"
#include "backend/session/session_factory.h"
#include "backend/session/executor_manager.h"
#include "base/base_ref_utils.h"
#include "load_mindir/load_model.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "utils/context/context_extends.h"
#include "runtime/device/kernel_runtime_manager.h"
@ -58,46 +58,9 @@ std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &dev
MSInferSession::MSInferSession() = default;
MSInferSession::~MSInferSession() = default;
std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &file) {
if (file.empty()) {
MS_LOG(ERROR) << "file is nullptr";
return nullptr;
}
std::string realPath = file;
std::ifstream ifs(realPath);
if (!ifs.good()) {
MS_LOG(ERROR) << "file: " << realPath << " is not exist";
return nullptr;
}
if (!ifs.is_open()) {
MS_LOG(ERROR) << "file: " << realPath << "open failed";
return nullptr;
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size));
if (buf == nullptr) {
MS_LOG(ERROR) << "malloc buf failed, file: " << realPath;
ifs.close();
return nullptr;
}
ifs.seekg(0, std::ios::beg);
ifs.read(buf->data(), size);
ifs.close();
return buf;
}
Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
auto graphBuf = ReadFile(file_name);
if (graphBuf == nullptr) {
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
return FAILED;
}
auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
Py_Initialize();
auto graph = RunLoadMindIR(file_name);
if (graph == nullptr) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
return FAILED;
@ -213,6 +176,7 @@ Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &reques
}
inputs.push_back(input);
}
auto ret = CheckModelInputs(model_id, inputs);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
@ -250,16 +214,6 @@ Status MSInferSession::FinalizeEnv() {
return SUCCESS;
}
std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
try {
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
return anf_graph;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference LoadModel failed";
return nullptr;
}
}
void MSInferSession::RegAllOp() {
static std::mutex init_mutex;
static bool Initialized = false;

View File

@ -54,8 +54,6 @@ class MSInferSession : public InferSession {
rtContext_t context_ = nullptr;
#endif
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device);
std::shared_ptr<std::vector<char>> ReadFile(const std::string &file);
static void RegAllOp();
string AjustTargetName(const std::string &device);
Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);

View File

@ -1,7 +1,7 @@
# build mindspore_shared_lib
set(LOAD_ONNX_SRC
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_converter.cc
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc
set(LOAD_MINDIR_SRC
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc
)
file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc")
@ -18,7 +18,7 @@ set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
${API_MS_INFER_SRC}
${API_ACL_SRC}
${API_OPS_SRC}
${LOAD_ONNX_SRC})
${LOAD_MINDIR_SRC})
add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}")

View File

@ -17,9 +17,9 @@
#include "cxx_api/model/acl/model_converter.h"
#include <memory>
#include "pybind11/pybind11.h"
#include "utils/load_onnx/anf_converter.h"
#include "transform/graph_ir/convert.h"
#include "transform/graph_ir/graph_runner.h"
#include "core/load_mindir/load_model.h"
#include "mindspore/core/utils/ms_context.h"
#include "backend/kernel_compiler/oplib/oplib.h"
@ -79,8 +79,7 @@ bool CreateSessionAndGraphRunner() {
std::shared_ptr<FuncGraph> ModelConverter::ConvertMindIrToFuncGraph(const Buffer &model_data) {
try {
auto anf_graph =
lite::AnfConverter::RunAnfConverter(reinterpret_cast<const char *>(model_data.Data()), model_data.DataSize());
auto anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data.Data()), model_data.DataSize());
return anf_graph;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Load MindIR failed.";
@ -364,6 +363,7 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) {
RegAllOp();
Py_Initialize();
auto func_graph = ConvertMindIrToFuncGraph(model_data);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed.";

View File

@ -19,7 +19,7 @@
#include <algorithm>
#include <fstream>
#include "utils/load_onnx/anf_converter.h"
#include "load_mindir/load_model.h"
#include "backend/session/session_basic.h"
#include "backend/session/session_factory.h"
#include "backend/session/executor_manager.h"
@ -117,9 +117,9 @@ Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::m
return FAILED;
}
std::shared_ptr<FuncGraph> anf_graph;
Py_Initialize();
try {
anf_graph =
lite::AnfConverter::RunAnfConverter(static_cast<const char *>(model_data.Data()), model_data.DataSize());
anf_graph = ConvertStreamToFuncGraph(static_cast<const char *>(model_data.Data()), model_data.DataSize());
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference LoadModel failed";
return FAILED;
@ -290,9 +290,10 @@ Status MsModel::FinalizeEnv() {
}
std::shared_ptr<FuncGraph> MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) {
Py_Initialize();
MS_EXCEPTION_IF_NULL(model_buf);
try {
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
auto anf_graph = ConvertStreamToFuncGraph(model_buf, size);
return anf_graph;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what();

View File

@ -74,6 +74,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.")
.def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""),
py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
.def("convert_funcgraph_to_mindir", &ExecutorPy::ConvertFuncGraphToMindIR, py::arg("graph"),
"Convert FuncGraph to MindIR proto.")
.def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
.def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
@ -108,6 +110,7 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend.");
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
(py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load MindIR as Graph.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")

View File

@ -45,6 +45,7 @@
#include "debug/draw.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "load_mindir/load_model.h"
#include "pybind_api/pybind_patch.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
@ -103,6 +104,16 @@ void CheckArgIsTensor(const ValuePtr &arg, std::size_t idx) {
}
} // namespace
py::bytes ExecutorPy::ConvertFuncGraphToMindIR(const FuncGraphPtr &fg_ptr) {
std::string proto_str = GetBinaryProtoString(fg_ptr);
if (proto_str.empty()) {
MS_LOG(EXCEPTION) << "Graph proto is empty.";
}
return proto_str;
}
FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::RunLoadMindIR(file_name); }
py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) {
MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size();
abstract::AbstractBasePtrList args_spec;

View File

@ -82,6 +82,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
ResourcePtr GetResource(const std::string &phase);
FuncGraphPtr GetFuncGraph(const std::string &phase);
py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type);
py::bytes ConvertFuncGraphToMindIR(const FuncGraphPtr &fg_ptr);
compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase);
bool HasCompiled(const std::string &phase) const;
@ -138,6 +139,7 @@ void ClearResAtexit();
void ReleaseGeTsd();
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase);
FuncGraphPtr LoadMindIR(const std::string &file_name);
// init and exec dataset sub graph
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,

View File

@ -0,0 +1,3 @@
file(GLOB_RECURSE _EXPORTER_IR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX)
add_library(_mindspore_transform_express_ir_obj OBJECT ${_EXPORTER_IR_SRC_FILES})

View File

@ -25,33 +25,40 @@
#include "ir/param_info.h"
#include "ir/func_graph.h"
#include "base/core_ops.h"
#include "proto/onnx.pb.h"
#include "proto/mind_ir.pb.h"
namespace mindspore {
using FloatPtr = std::shared_ptr<Float>;
using IntPtr = std::shared_ptr<Int>;
// anf type to onnx type map
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_type_map = {
{kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, {kNumberTypeInt8, onnx::TensorProto_DataType_INT8},
{kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, {kNumberTypeInt32, onnx::TensorProto_DataType_INT32},
{kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8},
{kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32},
{kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16},
{kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE},
{kObjectTypeString, onnx::TensorProto_DataType_STRING},
// anf type to mindir type map
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_type_map = {
{kNumberTypeBool, mind_ir::TensorProto_DataType_BOOL},
{kNumberTypeInt8, mind_ir::TensorProto_DataType_INT8},
{kNumberTypeInt16, mind_ir::TensorProto_DataType_INT16},
{kNumberTypeInt32, mind_ir::TensorProto_DataType_INT32},
{kNumberTypeInt64, mind_ir::TensorProto_DataType_INT64},
{kNumberTypeUInt8, mind_ir::TensorProto_DataType_UINT8},
{kNumberTypeUInt16, mind_ir::TensorProto_DataType_UINT16},
{kNumberTypeUInt32, mind_ir::TensorProto_DataType_UINT32},
{kNumberTypeUInt64, mind_ir::TensorProto_DataType_UINT64},
{kNumberTypeFloat16, mind_ir::TensorProto_DataType_FLOAT16},
{kNumberTypeFloat32, mind_ir::TensorProto_DataType_FLOAT},
{kNumberTypeFloat64, mind_ir::TensorProto_DataType_DOUBLE},
{kObjectTypeString, mind_ir::TensorProto_DataType_STRING},
};
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_int_map = {
{8, onnx::TensorProto_DataType_INT8},
{16, onnx::TensorProto_DataType_INT16},
{32, onnx::TensorProto_DataType_INT32},
{64, onnx::TensorProto_DataType_INT64},
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_int_map = {
{8, mind_ir::TensorProto_DataType_INT8},
{16, mind_ir::TensorProto_DataType_INT16},
{32, mind_ir::TensorProto_DataType_INT32},
{64, mind_ir::TensorProto_DataType_INT64},
};
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_float_map = {
{16, onnx::TensorProto_DataType_FLOAT16},
{32, onnx::TensorProto_DataType_FLOAT},
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_float_map = {
{16, mind_ir::TensorProto_DataType_FLOAT16},
{32, mind_ir::TensorProto_DataType_FLOAT},
{64, mind_ir::TensorProto_DataType_FLOAT64},
};
// Can build different builder according to format
@ -77,34 +84,34 @@ class IrExportBuilder {
void BuildModel(const FuncGraphPtr &func_graph);
private:
void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto);
void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto);
std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto);
void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto);
void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto);
void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto);
void SetParamToTensorProto(const ParameterPtr &param, onnx::TensorProto *const tensor_proto);
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto,
void SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::ValueInfoProto *const value_proto);
void SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
void SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
void SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string);
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name);
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
void SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
void SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
void SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
void SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
void SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string);
void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
void SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string);
onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits);
onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits);
mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits);
mind_ir::TensorProto_DataType GetMindirDataBitsFloatType(int bits);
std::string GetNodeName(const AnfNodePtr &node);
std::string GetUniqueNodeName(const AnfNodePtr &node);
std::string GetOpTypeName(const AnfNodePtr &node);
@ -114,8 +121,8 @@ class IrExportBuilder {
void ResetTupleIndex() { shape_index_ = 0; }
private:
onnx::ModelProto model_;
onnx::NodeProto *last_node_{nullptr};
mind_ir::ModelProto model_;
mind_ir::NodeProto *last_node_{nullptr};
std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_;
size_t node_index_{0};
@ -144,13 +151,13 @@ std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) {
}
void IrExportBuilder::BuildModelInfo() {
model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
model_.set_ir_version("0.1.0");
model_.set_producer_name("MindSpore");
model_.set_model_version(1);
model_.set_model_version("1.1.0");
}
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
onnx::GraphProto *graph_proto = model_.mutable_graph();
mind_ir::GraphProto *graph_proto = model_.mutable_graph();
graph_proto->set_name(func_graph->ToString());
ResetNodeIndex();
todo_.clear();
@ -162,7 +169,7 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
}
}
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
// Export parameters
// 1. parameters should be mapped to ValueInfoProto
// 2. parameters with default value should be mapped to Initializer
@ -172,33 +179,31 @@ void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::Graph
BuildNodes(func_graph, graph_proto);
}
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
for (auto &item : func_graph->parameters()) {
auto param = item->cast<ParameterPtr>();
if (param == nullptr) {
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
}
onnx::ValueInfoProto *input_proto = graph_proto->add_input();
std::string param_name = GetUniqueNodeName(param);
input_proto->set_name(param_name);
SetValueInfoProto(param, input_proto);
if (!param->has_default()) {
if (param->has_default()) {
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default.";
continue;
}
// Using ONNX initializer to set parameter's default value
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
initializer_proto->set_name(param_name);
SetParamToTensorProto(param, initializer_proto);
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
if (tensor) {
initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
parameter_proto->set_name(param_name);
SetParamToTensorProto(param, parameter_proto);
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
if (tensor) {
parameter_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
}
} else {
mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
input_proto->set_name(param_name);
SetValueInfoProto(param, input_proto);
}
}
}
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) {
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) {
auto iter = g_data_type_map.find(type_id);
if (iter == g_data_type_map.end()) {
MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id;
@ -206,7 +211,7 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) {
return iter->second;
}
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) {
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) {
auto iter = g_data_bits_int_map.find(bits);
if (iter == g_data_bits_int_map.end()) {
MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits;
@ -214,7 +219,7 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) {
return iter->second;
}
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) {
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) {
auto iter = g_data_bits_float_map.find(bits);
if (iter == g_data_bits_float_map.end()) {
MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits;
@ -222,73 +227,70 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) {
return iter->second;
}
void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) {
void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) {
if (node == nullptr || value_proto == nullptr) {
MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
}
MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
SetValueInfoProto(node->Type(), node->Shape(), value_proto);
}
void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::ValueInfoProto *const value_proto) {
onnx::TypeProto *type_proto = value_proto->mutable_type();
const TypePtr &type = node->Type();
const BaseShapePtr &shape = node->Shape();
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
auto tensor = type->cast<TensorTypePtr>();
auto elem_type = tensor->element();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id()));
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id()));
if (dims.size() == 0) {
MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1.";
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
tensor_proto->add_dims(1);
} else {
for (const auto &dim : dims) {
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
tensor_proto->add_dims(dim);
}
}
} else if (type->isa<Tuple>()) {
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
} else if (type->isa<Number>() || type->isa<String>()) {
type_proto->set_denotation(type->type_name());
value_proto->set_denotation(type->type_name());
} else {
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
}
}
void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("tensor:value0");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name("value0");
auto data = value->cast<tensor::TensorPtr>();
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type();
auto shape = data->shape_c();
tensor_proto->set_data_type(GetOnnxDataType(dtype));
tensor_proto->set_data_type(GetMindirDataType(dtype));
for (const auto &dim : shape) {
tensor_proto->add_dims(dim);
}
}
void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::TensorProto *const tensor_proto) {
mind_ir::TensorProto *const tensor_proto) {
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString();
}
auto tensor = type->cast<TensorTypePtr>();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id()));
tensor_proto->set_data_type(GetMindirDataType(tensor->element()->type_id()));
for (const auto &dim : dims) {
tensor_proto->add_dims(dim);
}
}
void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, onnx::TensorProto *const tensor_proto) {
void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto) {
if (param == nullptr || tensor_proto == nullptr) {
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
}
@ -296,7 +298,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, onnx::Ten
SetTensorProto(param->Type(), param->Shape(), tensor_proto);
}
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
bool is_only_return = true;
for (const AnfNodePtr &node : nodes) {
@ -317,12 +319,12 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt
}
}
void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
if (node->size() != 2) {
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
}
AnfNodePtr arg = node->input(1);
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
std::string output_name = GetUniqueNodeName(node);
output_proto->set_name(output_name);
last_node_->set_output(0, output_name);
@ -349,7 +351,7 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
}
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::AttributeProto *const attr_proto, std::string *const seq_string) {
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
if (type->isa<Tuple>() && seq_string != nullptr) {
*seq_string += "Tuple[";
auto elements = type->cast<TuplePtr>()->elements();
@ -361,7 +363,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
} else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) {
string shape_name = "shape" + std::to_string(GetTupleIndex());
*seq_string += shape_name + ",";
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name(shape_name);
SetTensorProto(type, shape, tensor_proto);
} else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) {
@ -371,7 +373,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
}
}
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
// Get shape of cnode
// 1. need to get shape from tuple element
// 2. save shape in TensorProto
@ -381,13 +383,13 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto
auto shape = node->Shape();
ResetTupleIndex();
std::string seq_string = "shape:";
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
attr_proto->set_ref_attr_name(seq_string);
MS_LOG(DEBUG) << "CNode shape: " << seq_string;
}
void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
auto inputs_size = node->size();
if (inputs_size < 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
@ -403,7 +405,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g
}
// Build cnode
onnx::NodeProto *node_proto = graph_proto->add_node();
mind_ir::NodeProto *node_proto = graph_proto->add_node();
std::string output_name = GetUniqueNodeName(node);
node_proto->add_output(output_name);
node_proto->set_name(output_name);
@ -421,7 +423,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g
auto prim = GetValueNode<PrimitivePtr>(op);
for (auto attr : prim->attrs()) {
MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name(attr.first);
SetValueToAttributeProto(attr.second, attr_proto);
}
@ -430,11 +432,11 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g
}
}
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) {
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
std::string node_name = GetUniqueNodeName(node);
if (node->isa<ValueNode>()) {
// When node input is a ValueNode, need to create a Constant Node
onnx::NodeProto *node_proto = graph_proto->add_node();
mind_ir::NodeProto *node_proto = graph_proto->add_node();
node_proto->add_output(node_name);
SetAttributeProto(node, node_proto);
}
@ -478,44 +480,48 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
return node_name;
}
void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) {
void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
if (node == nullptr || node_proto == nullptr) {
MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
}
auto value = node->cast<ValueNodePtr>()->value();
node_proto->set_op_type("Constant");
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value");
MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
SetValueToAttributeProto(value, attr_proto);
}
void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
if (value->isa<Int>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto int_value = value->cast<IntPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
} else if (value->isa<Float>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto float_value = value->cast<FloatPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
} else if (value->isa<Bool>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
} else if (value->isa<TensorType>()) {
attr_proto->set_ref_attr_name("type:tensor0");
tensor_proto->set_name("tensor0");
auto elem_type = value->cast<TensorTypePtr>()->element();
if (elem_type->isa<Int>()) {
auto int_value = elem_type->cast<IntPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
} else if (elem_type->isa<Float>()) {
auto float_value = elem_type->cast<FloatPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
} else {
MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name();
}
@ -524,18 +530,18 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri
}
}
void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
if (value->isa<StringImm>() || value->isa<Scalar>()) {
SetScalarToAttributeProto(value, attr_proto);
SetScalarToAttributeProto_ir(value, attr_proto);
} else if (value->isa<Number>() || value->isa<TensorType>()) {
SetTypeToAttributeProto(value, attr_proto);
} else if (value->isa<ValueSequeue>() || value->isa<ValueSequeue>()) {
} else if (value->isa<ValueSequeue>()) {
ResetTupleIndex();
std::string seq_string = "scalar:";
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
attr_proto->set_ref_attr_name(seq_string);
MS_LOG(DEBUG) << "Attr string: " << seq_string;
@ -549,74 +555,102 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr
}
}
void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
attr_proto->set_ref_attr_name("scalar:value0");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
SetScalarToProto(value, tensor_proto, "value0");
}
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto,
const std::string &value_name) {
if (value == nullptr || tensor_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!";
}
tensor_proto->set_name(value_name);
if (value->isa<StringImm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING);
tensor_proto->add_string_data(GetValue<std::string>(value));
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
attr_proto->set_s(GetValue<std::string>(value));
} else if (value->isa<BoolImm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL);
tensor_proto->add_int32_data(GetValue<bool>(value));
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
attr_proto->set_i(GetValue<bool>(value));
} else if (value->isa<Int8Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8);
tensor_proto->add_int32_data(value->cast<Int8ImmPtr>()->value());
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
attr_proto->set_i(value->cast<Int8ImmPtr>()->value());
} else if (value->isa<Int16Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16);
tensor_proto->add_int32_data(value->cast<Int16ImmPtr>()->value());
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
attr_proto->set_i(value->cast<Int16ImmPtr>()->value());
} else if (value->isa<Int32Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32);
tensor_proto->add_int32_data(value->cast<Int32ImmPtr>()->value());
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
attr_proto->set_i(value->cast<Int32ImmPtr>()->value());
} else if (value->isa<Int64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value());
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
attr_proto->set_i(value->cast<Int64ImmPtr>()->value());
} else if (value->isa<UInt8Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8);
tensor_proto->add_int32_data(value->cast<UInt8ImmPtr>()->value());
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
attr_proto->set_i(value->cast<UInt8ImmPtr>()->value());
} else if (value->isa<UInt16Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16);
tensor_proto->add_int32_data(value->cast<UInt16ImmPtr>()->value());
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
attr_proto->set_i(value->cast<UInt16ImmPtr>()->value());
} else if (value->isa<UInt32Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32);
tensor_proto->add_uint64_data(value->cast<UInt32ImmPtr>()->value());
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
attr_proto->set_i(value->cast<UInt32ImmPtr>()->value());
} else if (value->isa<UInt64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64);
tensor_proto->add_uint64_data(value->cast<UInt64ImmPtr>()->value());
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
attr_proto->set_i(value->cast<UInt64ImmPtr>()->value());
} else if (value->isa<FP32Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
tensor_proto->add_float_data(GetValue<float>(value));
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
attr_proto->set_f(GetValue<float>(value));
} else if (value->isa<FP64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
tensor_proto->add_double_data(GetValue<double>(value));
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
attr_proto->set_d(GetValue<double>(value));
} else {
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
}
}
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value->isa<StringImm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
attr_proto->add_strings(GetValue<std::string>(value));
} else if (value->isa<BoolImm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
attr_proto->add_ints(GetValue<bool>(value));
} else if (value->isa<Int8Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
attr_proto->add_ints(value->cast<Int8ImmPtr>()->value());
} else if (value->isa<Int16Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
attr_proto->add_ints(value->cast<Int16ImmPtr>()->value());
} else if (value->isa<Int32Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
attr_proto->add_ints(value->cast<Int32ImmPtr>()->value());
} else if (value->isa<Int64Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
attr_proto->add_ints(value->cast<Int64ImmPtr>()->value());
} else if (value->isa<UInt8Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
attr_proto->add_ints(value->cast<UInt8ImmPtr>()->value());
} else if (value->isa<UInt16Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
attr_proto->add_ints(value->cast<UInt16ImmPtr>()->value());
} else if (value->isa<UInt32Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
attr_proto->add_ints(value->cast<UInt32ImmPtr>()->value());
} else if (value->isa<UInt64Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
attr_proto->add_ints(value->cast<UInt64ImmPtr>()->value());
} else if (value->isa<FP32Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
attr_proto->add_floats(GetValue<float>(value));
} else if (value->isa<FP64Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
attr_proto->add_doubles(GetValue<double>(value));
} else {
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
}
}
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string) {
string value_name = "value" + std::to_string(GetTupleIndex());
if (seq_string != nullptr) {
*seq_string += value_name + ",";
}
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
SetScalarToProto(value, tensor_proto, value_name);
SetScalarToAttributeProto_irs(value, attr_proto);
}
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
@ -625,6 +659,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
*seq_string += "Tuple[";
const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
if (tuple_value->value().size() == 0) {
*seq_string += "],";
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
return;
}
@ -640,6 +675,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
*seq_string += "List[";
const ValueListPtr &list_value = value->cast<ValueListPtr>();
if (list_value->value().size() == 0) {
*seq_string += "],";
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
return;
}

View File

@ -1,3 +0,0 @@
file(GLOB_RECURSE _ONNX_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX)
add_library(_mindspore_transform_onnx_obj OBJECT ${_ONNX_SRC_FILES})

View File

@ -5,11 +5,5 @@ if (NOT ENABLE_GE)
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_GE_SRC_FILES})
endif ()
file(GLOB_RECURSE _UTILS_LITE_SRC_FILES
./load_onnx/anf_converter.cc
./load_onnx/anf_model_parser.cc
)
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_LITE_SRC_FILES})
set_property(SOURCE ${_UTILS_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_UTILS)
add_library(_mindspore_utils_obj OBJECT ${_UTILS_SRC_LIST})

View File

@ -1,134 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/load_onnx/anf_converter.h"
#include <fcntl.h>
#include <fstream>
#include <memory>
#include <vector>
#include <string>
#include "pybind11/pybind11.h"
#include "utils/load_onnx/anf_model_parser.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "proto/onnx.pb.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace lite {
const char WHITESPACE[] = "\t\n\v\f\r ";
const int FLAG_PREFIX_LEN = 2;
void AnfConverter::Trim(std::string *input) {
if (input == nullptr) {
return;
}
if (input->empty()) {
return;
}
input->erase(0, input->find_first_not_of(WHITESPACE));
input->erase(input->find_last_not_of(WHITESPACE) + 1);
}
int AnfConverter::ValidateFileStr(const std::string &modelFile, std::string fileType) {
if (modelFile.size() > fileType.size()) {
if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
return 0;
} else {
return 1;
}
} else {
return 1;
}
}
bool AnfConverter::ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) {
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
if (modelFile.size() > PATH_MAX) {
MS_LOG(DEBUG) << "file path " << modelFile << " is too long.";
return false;
}
char real_path[PATH_MAX + 1] = {0};
#if defined(_WIN32) || defined(_WIN64)
if (nullptr == _fullpath(real_path, modelFile.c_str(), PATH_MAX)) {
MS_LOG(DEBUG) << modelFile << " does not exit.";
return false;
}
#else
if (nullptr == realpath(modelFile.c_str(), real_path)) {
MS_LOG(DEBUG) << modelFile << " does not exit.";
return false;
}
#endif
int fd = open(real_path, O_RDONLY);
if (fd < 0) {
MS_LOG(EXCEPTION) << "failed to open file";
}
google::protobuf::io::FileInputStream input(fd);
google::protobuf::io::CodedInputStream code_input(&input);
code_input.SetTotalBytesLimit(INT_MAX, 536870912);
bool ret = onnx_model->ParseFromCodedStream(&code_input);
if (!ret) {
MS_LOG(ERROR) << "load onnx file failed";
return false;
}
(void)close(fd);
MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl;
return true;
}
std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const std::string &file_path) {
std::string modelFile;
std::string tmp = file_path;
Trim(&tmp);
const std::string flagItem(tmp);
size_t pos = flagItem.find_first_of("=");
if (pos == std::string::npos) {
MS_LOG(ERROR) << "Trans data not support input format!";
} else {
modelFile = flagItem.substr(pos + 1);
std::cout << "input protobuf file path is: " << modelFile << std::endl;
}
if (ValidateFileStr(modelFile, ".pb") != 0) {
MS_LOG(EXCEPTION) << "INPUT ILLEGAL: modelFile must be *.pb";
}
onnx::ModelProto model_;
ReadOnnxFromBinary(modelFile, &model_);
MSANFModelParser model_parser;
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
return dstgraph_ptr;
}
std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const char *buf, const size_t buf_size) {
Py_Initialize();
MS_EXCEPTION_IF_NULL(buf);
std::string str((const char *)buf, buf_size);
onnx::ModelProto model_;
if (!model_.ParseFromString(str)) {
MS_LOG(EXCEPTION) << "Parse model from buffer fail!";
}
MSANFModelParser model_parser;
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
return dstgraph_ptr;
}
} // namespace lite
} // namespace mindspore

View File

@ -1,693 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/load_onnx/anf_model_parser.h"
#include <functional>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <vector>
#include <unordered_map>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "ir/tensor.h"
#include "ir/param_info.h"
#include "frontend/operator/ops.h"
#include "abstract/abstract_value.h"
#include "proto/onnx.pb.h"
#include "utils/log_adapter.h"
#include "utils/shape_utils.h"
using std::string;
namespace mindspore {
namespace lite {
static constexpr char kConstantValueNode[] = "Constant";
static constexpr char kCNodeShapeAttr[] = "shape";
static constexpr char kCNodeShape1Attr[] = "shape1";
static constexpr char kCNodeShape2Attr[] = "shape2";
enum ParseForm : int {
FORM_PARSE_TYPE = 0,
FORM_PARSE_SCALAR = 1,
FORM_PARSE_TENSOR = 2,
};
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}};
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
};
template <typename T, typename P>
std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<string, P> &kv) {
std::stack<std::string> rules;
std::stack<P> value;
int count = 0;
for (size_t i = 0; i < str.length(); i++) {
if (str[i] == '[') {
rules.push("[");
} else if (str[i] == ']') {
// rules
std::vector<P> vec;
while (rules.top() != "[") {
rules.pop();
vec.push_back(value.top());
value.pop();
}
// pop "["
rules.pop();
// make tuple for names
std::string res = "dummy";
// make tuple for values
reverse(vec.begin(), vec.end());
auto vt = std::make_shared<T>(vec);
if (rules.empty() && value.empty()) {
return vt;
}
rules.push(res);
value.push(vt);
} else if (str[i] == ',') {
continue;
} else {
count++;
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
auto value_name = str.substr(i - count + 1, count);
value.push(kv.at(value_name));
rules.push(value_name);
count = 0;
}
}
}
return {};
}
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
const std::unordered_map<string, ValuePtr> &kv) {
std::string str = attr_name;
auto replace = [&](const string &orgStr, const string &newStr) {
std::string::size_type pos(0);
while ((pos = str.find(orgStr)) != std::string::npos) {
str.replace(pos, orgStr.length(), newStr);
}
return str;
};
// remove "scalar:"
str = replace("scalar:", "");
// remove "Tuple"
str = replace("Tuple", "");
// remove "List"
str = replace("List", "");
auto result = ParserAttr<ValueTuple>(str, kv);
if (!result) {
return {};
}
return result;
}
std::shared_ptr<abstract::AbstractTuple> ParserAttrShape(
const std::string &attr_name, const std::unordered_map<string, abstract::AbstractBasePtr> &kv) {
std::string str = attr_name;
auto replace = [&](const string &orgStr, const string &newStr) {
std::string::size_type pos(0);
while ((pos = str.find(orgStr)) != std::string::npos) {
str.replace(pos, orgStr.length(), newStr);
}
return str;
};
// remove "scalar:"
str = replace("shape:", "");
// remove "Tuple"
str = replace("Tuple", "");
// remove "List"
str = replace("List", "");
auto result = ParserAttr<abstract::AbstractTuple>(str, kv);
if (!result) {
return {};
}
return result;
}
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
return MakeValue<valuetype>(value); \
}
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
PARSE_ONNXATTR_IN_SCALAR_FORM(string, string)
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32)
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) {
MS_EXCEPTION_IF_NULL(node);
if (!value_proto.has_type() || !value_proto.has_name()) {
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
return false;
}
node->set_name(value_proto.name());
const auto &type_proto = value_proto.type();
if (!type_proto.has_tensor_type()) {
MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! ";
return false;
}
const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type();
if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) {
MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! ";
return false;
}
const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape();
ShapeVector shape;
for (int i = 0; i < tensor_shape.dim_size(); ++i) {
shape.push_back(tensor_shape.dim(i).dim_value());
}
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!";
return false;
}
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
node->set_abstract(tensor_abstract);
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) {
const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()];
std::string initial_data = initialize_proto.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
MS_EXCEPTION_IF_NULL(tensor_data_buf);
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size());
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}
node->set_default_param(tensor_info);
}
anfnode_build_map_[value_proto.name()] = node;
return true;
}
bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size();
for (int i = 0; i < importProto.initializer_size(); ++i) {
const onnx::TensorProto &initializer_proto = importProto.initializer(i);
if (!initializer_proto.has_name()) {
MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i;
return false;
}
default_para_map_[initializer_proto.name()] = initializer_proto;
}
MS_LOG(INFO) << "all parameters size: " << importProto.input_size();
for (int i = 0; i < importProto.input_size(); ++i) {
const onnx::ValueInfoProto &input_proto = importProto.input(i);
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
return false;
}
}
return true;
}
bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
const int attr_tensor_type = attr_tensor.data_type();
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
return false;
}
prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
return true;
}
ValuePtr MSANFModelParser::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
switch (attr_tensor_type) {
case onnx::TensorProto_DataType_STRING: {
return ParseAttrInScalar_string_string(attr_tensor);
}
case onnx::TensorProto_DataType_INT32: {
return ParseAttrInScalar_int32_int32(attr_tensor);
}
case onnx::TensorProto_DataType_INT64: {
return ParseAttrInScalar_int64_int64(attr_tensor);
}
case onnx::TensorProto_DataType_UINT64: {
return ParseAttrInScalar_uint64_uint64(attr_tensor);
}
case onnx::TensorProto_DataType_FLOAT: {
return ParseAttrInScalar_float_float(attr_tensor);
}
case onnx::TensorProto_DataType_DOUBLE: {
return ParseAttrInScalar_double_double(attr_tensor);
}
case onnx::TensorProto_DataType_BOOL: {
return ParseAttrInScalar_int32_bool(attr_tensor);
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
return {};
}
return {};
}
bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
MS_LOG(ERROR) << "parse attr type don't support attr type is tensor";
return false;
}
bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
MS_EXCEPTION_IF_NULL(prim);
const std::string &attr_name = attr_proto.name();
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
string type;
std::size_t pos(0);
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("type:").length() - 1);
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
}
std::unordered_map<std::string, ValuePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
switch (kParseTypeSwitchMap[type]) {
case FORM_PARSE_TYPE: {
ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
break;
}
case FORM_PARSE_SCALAR: {
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
break;
}
case FORM_PARSE_TENSOR: {
ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
break;
}
default:
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return false;
}
}
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
if (kv.size() == 1) {
auto iter = kv.begin();
prim->AddAttr(attr_name, iter->second);
} else {
auto res = ParserScalarAttrValue(ref_attr_name, kv);
prim->AddAttr(attr_name, res);
}
}
return true;
}
bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
ShapeVector shape;
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}
auto new_value_node = NewValueNode(MakeValue(tensor_info));
MS_EXCEPTION_IF_NULL(new_value_node);
auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
new_value_node->set_abstract(tensor_abstract);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool MSANFModelParser::ObtainValueNodeInScalarForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
ValuePtr value_ptr = nullptr;
switch (attr_tensor_type) {
case onnx::TensorProto_DataType_INT32: {
std::vector<int64_t> add_data;
for (int i = 0; i < attr_tensor.int32_data_size(); ++i) {
add_data.push_back(attr_tensor.int32_data(i));
}
if (add_data.size() == 1) {
value_ptr = MakeValue(add_data[0]);
} else if (!add_data.empty()) {
value_ptr = MakeValue<std::vector<int64_t>>(add_data);
}
break;
}
case onnx::TensorProto_DataType_FLOAT: {
std::vector<float> add_data;
for (int i = 0; i < attr_tensor.float_data_size(); ++i) {
add_data.push_back(attr_tensor.float_data(i));
}
if (add_data.size() == 1) {
value_ptr = MakeValue(add_data[0]);
} else if (!add_data.empty()) {
value_ptr = MakeValue<std::vector<float>>(add_data);
}
break;
}
case onnx::TensorProto_DataType_UNDEFINED: {
std::vector<ValuePtr> elems;
value_ptr = std::make_shared<ValueTuple>(elems);
break;
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
return false;
}
auto new_value_node = NewValueNode(value_ptr);
MS_EXCEPTION_IF_NULL(new_value_node);
new_value_node->set_abstract(value_ptr->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
return false;
}
auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
new_value_node->set_abstract(abs_type);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
const onnx::AttributeProto &attr_proto) {
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
string type;
std::size_t pos(0);
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("type:").length() - 1);
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
}
std::unordered_map<std::string, ValuePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
auto attr_name = attr_tensor.name();
switch (kParseTypeSwitchMap[type]) {
case FORM_PARSE_TYPE: {
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
}
case FORM_PARSE_SCALAR: {
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
break;
}
case FORM_PARSE_TENSOR: {
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
}
default:
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return false;
}
}
ValueNodePtr new_value_node;
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
if (kv.size() == 1) {
auto iter = kv.begin();
new_value_node = NewValueNode(iter->second);
new_value_node->set_abstract(iter->second->ToAbstract());
} else {
auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv);
new_value_node = NewValueNode(value_ptr);
new_value_node->set_abstract(value_ptr->ToAbstract());
}
anfnode_build_map_[value_node_name] = new_value_node;
}
return true;
}
bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
const std::string &value_node_name = node_proto.output(0);
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
return false;
}
return GetAttrValueForValueNode(value_node_name, attr_proto);
}
std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode(
const onnx::AttributeProto &attr_proto) {
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); ++i) {
ShapeVector shape_vec;
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
for (int j = 0; j < attr_tensor.dims_size(); ++j) {
shape_vec.push_back(attr_tensor.dims(j));
}
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
MS_EXCEPTION_IF_NULL(tensor_info);
auto abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(abstract);
kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract));
}
return kv;
}
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::NodeProto &node_proto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
if (!node_proto.has_op_type()) {
MS_LOG(ERROR) << "Get CNode op_type failed!";
return nullptr;
}
const std::string &node_name = node_proto.output(0);
const std::string &fullname_with_scope = node_proto.domain();
const std::string &node_type = node_proto.op_type();
PrimitivePtr prim = std::make_shared<Primitive>(node_type);
MS_EXCEPTION_IF_NULL(prim);
prim->set_instance_name(node_type);
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
string shape_ref_attr_name;
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
shape_ref_attr_name = attr_proto.ref_attr_name();
kv = GetAbstractForCNode(attr_proto);
continue;
}
if (!GetAttrValueForCNode(prim, attr_proto)) {
MS_LOG(ERROR) << "Get CNode attr failed!";
return nullptr;
}
}
std::vector<AnfNodePtr> inputs;
inputs.clear();
inputs.push_back(NewValueNode(prim));
for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
return nullptr;
}
inputs.push_back(anfnode_build_map_[input_name]);
}
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
if (0 == kv.size()) {
AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
elem.push_back(cnode_ptr->input(index)->abstract());
}
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (1 == kv.size()) {
std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
cnode_ptr->set_abstract(iter->second);
} else {
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
cnode_ptr->set_abstract(abstract);
}
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
anfnode_build_map_[node_name] = cnode_ptr;
return cnode_ptr;
}
bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_EXCEPTION_IF_NULL(cnode_ptr);
std::vector<AnfNodePtr> inputs;
if (importProto.output_size() > 1) {
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
AbstractBasePtrList elem;
for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
const onnx::ValueInfoProto &output_node = importProto.output(out_size);
const std::string &out_tuple = output_node.name();
inputs.push_back(anfnode_build_map_[out_tuple]);
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
}
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
inputs.push_back(maketuple_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
} else {
const onnx::ValueInfoProto &output_node = importProto.output(0);
const onnx::TypeProto &output_typeproto = output_node.type();
ShapeVector output_shape;
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) {
output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value());
}
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
inputs.push_back(cnode_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
}
return true;
}
bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
CNodePtr cnode_ptr = nullptr;
for (int i = 0; i < importProto.node_size(); ++i) {
const onnx::NodeProto &node_proto = importProto.node(i);
const std::string &node_type = node_proto.op_type();
if (node_type == kConstantValueNode) {
if (!BuildValueNodeForFuncGraph(node_proto)) {
MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
return false;
}
continue;
}
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
if (cnode_ptr == nullptr) {
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
return false;
}
}
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
}
bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
MS_EXCEPTION_IF_NULL(debug_info_ptr);
if (importProto.has_name()) {
debug_info_ptr->set_name(importProto.name());
} else {
MS_LOG(ERROR) << "FuncGraph under converting has not name!";
}
if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
return false;
}
return ImportNodesForGraph(outputFuncGraph, importProto);
}
bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
if (!model_proto.has_producer_name()) {
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
return false;
}
producer_name_ = model_proto.producer_name();
MS_LOG(INFO) << "producer_name :" << producer_name_;
if (!model_proto.has_model_version()) {
MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
return false;
}
model_version_ = model_proto.model_version();
MS_LOG(INFO) << "producer_version : " << model_version_;
if (!model_proto.has_ir_version()) {
MS_LOG(ERROR) << "Parse model version from pb file failed!";
return false;
}
ir_version_ = model_proto.ir_version();
MS_LOG(INFO) << "ir_version :" << ir_version_;
return true;
}
FuncGraphPtr MSANFModelParser::Parse(const onnx::ModelProto &model_proto) {
FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
MS_EXCEPTION_IF_NULL(dstGraph);
if (!MSANFParseModelConfigureInfo(model_proto)) {
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
}
const onnx::GraphProto &graphBuild = model_proto.graph();
if (!BuildFuncGraph(dstGraph, graphBuild)) {
MS_LOG(ERROR) << "Build funcgraph failed!";
return nullptr;
}
MS_LOG(INFO) << "Parse pb to build FuncGraph Success!";
return dstGraph;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,119 @@
syntax = "proto2";
package mind_ir;
message AttributeProto {
enum AttributeType {
UNDEFINED = 0;
FLOAT = 1;
UINT8 = 2;
INT8 = 3;
UINT16 = 4;
INT16 = 5;
INT32 = 6;
INT64 = 7;
STRING = 8;
BOOL = 9;
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14;
COMPLEX128 = 15;
BFLOAT16 = 16;
TENSOR = 17;
GRAPH = 18;
TENSORS = 19;
}
optional string name = 1;
optional float f = 2;
optional int64 i = 3;
optional double d = 4;
optional bytes s = 5;
optional TensorProto t = 6;
optional GraphProto g = 7;
repeated float floats = 8;
repeated double doubles = 9;
repeated int64 ints = 10;
repeated bytes strings = 11;
repeated TensorProto tensors = 12;
repeated GraphProto graphs = 13;
optional string doc_string = 14;
optional string ref_attr_name = 15;
optional AttributeType type = 16;
}
message ValueInfoProto {
optional string name = 1;
repeated TensorProto tensor = 2;
optional string doc_string = 3;
optional string denotation = 4;
}
message NodeProto {
repeated string input = 1;
repeated string output = 2;
optional string name = 3;
optional string op_type = 4;
repeated AttributeProto attribute = 5;
optional string doc_string = 6;
optional string domain = 7;
}
message ModelProto {
optional string ir_version = 1;
optional string producer_name = 2;
optional string producer_version = 3;
optional string domain = 4;
optional string model_version = 5;
optional string doc_string = 6;
optional GraphProto graph = 7;
}
message GraphProto {
repeated NodeProto node = 1;
optional string name = 2;
repeated TensorProto parameter = 3;
optional string doc_string = 4;
repeated ValueInfoProto input = 5;
repeated ValueInfoProto output = 6;
}
message TensorProto {
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14;
COMPLEX128 = 15;
BFLOAT16 = 16;
FLOAT64 = 17;
}
repeated int64 dims = 1;
optional int32 data_type = 2;
repeated float float_data = 3;
repeated int32 int32_data = 4;
repeated bytes string_data = 5;
repeated int64 int64_data = 6;
optional string name = 7;
optional string doc_string = 8;
optional bytes raw_data = 9;
repeated double double_data = 10;
repeated uint64 uint64_data = 11;
}

View File

@ -15,6 +15,7 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"c_ops/*.cc"
"ir/*.cc"
"utils/*.cc"
"load_mindir/*.cc"
)
set_property(SOURCE ${CORE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_CORE)
add_library(mindspore_core STATIC ${CORE_SRC_LIST})

View File

@ -50,4 +50,5 @@ AbstractBasePtr TensorAddInfer(const abstract::AnalysisEnginePtr &, const Primit
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(TensorAdd, prim::kPrimTensorAdd, TensorAddInfer);
REGISTER_PRIMITIVE_C(TensorAdd);
} // namespace mindspore

View File

@ -102,4 +102,5 @@ AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const Primitiv
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer);
REGISTER_PRIMITIVE_C(AvgPool);
} // namespace mindspore

View File

@ -193,4 +193,5 @@ AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const Primitive
Conv2dInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
REGISTER_PRIMITIVE_C(Conv2D);
} // namespace mindspore

View File

@ -33,7 +33,7 @@ constexpr auto kPad = "pad";
constexpr auto kPads = "pads";
constexpr auto kMode = "mode";
constexpr auto kGroup = "group";
constexpr auto kOutputChannel = "output_channel";
constexpr auto kOutputChannel = "out_channel";
constexpr auto kPadList = "pad_list";
constexpr auto kAxis = "axis";

View File

@ -31,4 +31,13 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) {
auto infer_function = iter->second.impl_;
return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list);
}
OpPrimCRegister &OpPrimCRegister::GetInstance() {
static OpPrimCRegister instance;
return instance;
}
std::map<std::string, OpPrimCDefineFunc> OpPrimCRegister::GetPrimCMap() { return op_primc_fns_; }
void OpPrimCRegister::SetPrimCMap(const std::string &name, const OpPrimCDefineFunc &fn) { op_primc_fns_[name] = fn; }
} // namespace mindspore

View File

@ -18,6 +18,8 @@
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
#include <string>
#include <vector>
#include <map>
#include <memory>
#include "ir/primitive.h"
#include "abstract/primitive_infer_map.h"
#include "ir/value.h"
@ -32,5 +34,33 @@ class PrimitiveC : public Primitive {
protected:
void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name);
};
using OpPrimCDefineFunc = std::function<std::shared_ptr<PrimitiveC>()>;
class OpPrimCRegister {
public:
~OpPrimCRegister() {}
static OpPrimCRegister &GetInstance();
std::map<std::string, OpPrimCDefineFunc> GetPrimCMap();
void SetPrimCMap(const std::string &name, const OpPrimCDefineFunc &fn);
private:
OpPrimCRegister() {}
std::map<std::string, OpPrimCDefineFunc> op_primc_fns_;
};
class OpPrimCRegisterHelper {
public:
OpPrimCRegisterHelper(const std::string &name, const OpPrimCDefineFunc &fn) {
OpPrimCRegister::GetInstance().SetPrimCMap(name, fn);
}
~OpPrimCRegisterHelper() = default;
};
#define REGISTER_PRIMITIVE_C(name) \
std::shared_ptr<PrimitiveC> GetDefaultPrimC##name() { \
auto out = std::make_shared<name>(); \
return out; \
} \
OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name);
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_

View File

@ -49,4 +49,5 @@ AbstractBasePtr Relu6Infer(const abstract::AnalysisEnginePtr &, const PrimitiveP
Relu6InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Relu6, prim::kPrimRelu6, Relu6Infer);
REGISTER_PRIMITIVE_C(Relu6);
} // namespace mindspore

View File

@ -41,4 +41,5 @@ AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const Primitiv
}
REGISTER_PRIMITIVE_EVAL_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer);
REGISTER_PRIMITIVE_C(Reshape);
} // namespace mindspore

View File

@ -36,7 +36,7 @@ class Reshape : public PrimitiveC {
AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimTensorAddPtr = std::shared_ptr<Reshape>;
using PrimReshapePtr = std::shared_ptr<Reshape>;
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_RESHAPE_H_

View File

@ -75,4 +75,5 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv
}
REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer);
REGISTER_PRIMITIVE_C(Softmax);
} // namespace mindspore

View File

@ -76,4 +76,5 @@ AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const Primitiv
}
REGISTER_PRIMITIVE_EVAL_IMPL(Squeeze, prim::kPrimSqueeze, SqueezeInfer);
REGISTER_PRIMITIVE_C(Squeeze);
} // namespace mindspore

View File

@ -0,0 +1,854 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "load_mindir/anf_model_parser.h"
#include <functional>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <vector>
#include <unordered_map>
#include <utility>
#include "ir/tensor.h"
#include "ir/param_info.h"
#include "c_ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/log_adapter.h"
#include "utils/shape_utils.h"
using std::string;
namespace mindspore {
static constexpr char kConstantValueNode[] = "Constant";
static constexpr char kCNodeShapeAttr[] = "shape";
static constexpr char kCNodeShape1Attr[] = "shape1";
static constexpr char kCNodeShape2Attr[] = "shape2";
enum ParseForm : int {
FORM_PARSE_TYPE = 0,
FORM_PARSE_SCALAR = 1,
FORM_PARSE_TENSOR = 2,
FORM_PARSE_NONE = 3,
FORM_PARSE_UNDEFINE = 4,
};
static std::map<std::string, ParseForm> kParseTypeSwitchMap{{"type", FORM_PARSE_TYPE},
{"scalar", FORM_PARSE_SCALAR},
{"tensor", FORM_PARSE_TENSOR},
{"none", FORM_PARSE_NONE},
{"", FORM_PARSE_UNDEFINE}};
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
{mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool},
{mind_ir::TensorProto_DataType_INT8, kNumberTypeInt8},
{mind_ir::TensorProto_DataType_INT16, kNumberTypeInt16},
{mind_ir::TensorProto_DataType_INT32, kNumberTypeInt32},
{mind_ir::TensorProto_DataType_INT64, kNumberTypeInt64},
{mind_ir::TensorProto_DataType_UINT8, kNumberTypeUInt8},
{mind_ir::TensorProto_DataType_UINT16, kNumberTypeUInt16},
{mind_ir::TensorProto_DataType_UINT32, kNumberTypeUInt32},
{mind_ir::TensorProto_DataType_UINT64, kNumberTypeUInt64},
{mind_ir::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
{mind_ir::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
{mind_ir::TensorProto_DataType_FLOAT64, kNumberTypeFloat64},
{mind_ir::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
{mind_ir::TensorProto_DataType_STRING, kObjectTypeString},
};
template <typename T, typename P>
std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<string, P> &kv) {
std::stack<std::string> rules;
std::stack<P> value;
int count = 0;
for (size_t i = 0; i < str.length(); i++) {
if (str[i] == '[') {
rules.push("[");
} else if (str[i] == ']') {
// rules
std::vector<P> vec;
while (rules.top() != "[") {
rules.pop();
vec.push_back(value.top());
value.pop();
}
// pop "["
rules.pop();
// make tuple for names
std::string res = "dummy";
// make tuple for values
reverse(vec.begin(), vec.end());
auto vt = std::make_shared<T>(vec);
if (rules.empty() && value.empty()) {
return vt;
}
rules.push(res);
value.push(vt);
} else if (str[i] == ',') {
continue;
} else {
count++;
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
auto value_name = str.substr(i - count + 1, count);
value.push(kv.at(value_name));
rules.push(value_name);
count = 0;
}
}
}
return {};
}
template <typename T>
std::shared_ptr<T> ParserScalarAttrValue(const std::string &attr_name, const std::unordered_map<string, ValuePtr> &kv) {
std::string str = attr_name;
auto replace = [&](const string &orgStr, const string &newStr) {
std::string::size_type pos(0);
while ((pos = str.find(orgStr)) != std::string::npos) {
str.replace(pos, orgStr.length(), newStr);
}
return str;
};
// remove "scalar:"
str = replace("scalar:", "");
// remove "Tuple"
str = replace("Tuple", "");
// remove "List"
str = replace("List", "");
auto result = ParserAttr<T>(str, kv);
return result;
}
std::shared_ptr<abstract::AbstractTuple> ParserAttrShape(
const std::string &attr_name, const std::unordered_map<string, abstract::AbstractBasePtr> &kv) {
std::string str = attr_name;
auto replace = [&](const string &orgStr, const string &newStr) {
std::string::size_type pos(0);
while ((pos = str.find(orgStr)) != std::string::npos) {
str.replace(pos, orgStr.length(), newStr);
}
return str;
};
// remove "scalar:"
str = replace("shape:", "");
// remove "Tuple"
str = replace("Tuple", "");
// remove "List"
str = replace("List", "");
auto result = ParserAttr<abstract::AbstractTuple>(str, kv);
return result;
}
std::string ParseParameterName(const string &name) {
string delimiter = ":";
size_t pos(0);
if ((pos = name.find(delimiter)) != string::npos) {
return name.substr(pos + 1, string::npos - (pos + 1));
}
return name;
}
std::string ParseCNodeName(const string &name) {
string delimiter = ":";
size_t pos = name.find(delimiter);
size_t end_pos = name.find_last_of(delimiter);
if (pos != string::npos && end_pos != string::npos && pos != end_pos) {
return name.substr(pos + 1, end_pos - (pos + 1));
}
return name;
}
#define PARSE_MINDIR_ATTR_IN_INT_FORM(type, valuetype) \
ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
auto value = static_cast<valuetype>(attr_proto.ints(index)); \
return MakeValue<valuetype>(value); \
} \
ValuePtr ParseAttrInSingleScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto) { \
auto value = static_cast<valuetype>(attr_proto.i()); \
return MakeValue<valuetype>(value); \
}
#define PARSE_MINDIR_ATTR_IN_SCALAR_FORM(type, valuetype) \
ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
auto value = static_cast<valuetype>(attr_proto.type##s(index)); \
return MakeValue<valuetype>(value); \
}
PARSE_MINDIR_ATTR_IN_INT_FORM(int8_t, int8_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(int16_t, int16_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, int32_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(int64_t, int64_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(uint8_t, uint8_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(uint16_t, uint16_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(uint32_t, uint32_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(uint64_t, uint64_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, bool)
PARSE_MINDIR_ATTR_IN_SCALAR_FORM(double, double)
PARSE_MINDIR_ATTR_IN_SCALAR_FORM(float, float)
PARSE_MINDIR_ATTR_IN_SCALAR_FORM(string, string)
ValuePtr ParseAttrInSingleScalar_string_string(const mind_ir::AttributeProto &attr_proto) {
auto value = static_cast<string>(attr_proto.s());
return MakeValue<string>(value);
}
ValuePtr ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto &attr_proto) {
auto value = static_cast<float>(attr_proto.f());
return MakeValue<float>(value);
}
ValuePtr ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto &attr_proto) {
auto value = static_cast<double>(attr_proto.d());
return MakeValue<double>(value);
}
tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto) {
ShapeVector shape;
for (int i = 0; i < tensor_proto.dims_size(); ++i) {
shape.push_back(tensor_proto.dims(i));
}
if (!tensor_proto.has_data_type()) {
MS_LOG(ERROR) << "mind_ir TensorProto has no data_type or name!";
return nullptr;
}
if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "mind_ir TensorProto data_type is not support yet!";
return nullptr;
}
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
return tensor_info;
}
bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
const mind_ir::TensorProto &parameter_proto) {
MS_EXCEPTION_IF_NULL(node);
if (!parameter_proto.has_name()) {
MS_LOG(ERROR) << "mind_ir TensorProto has no name!";
return false;
}
string debug_info_name = ParseParameterName(parameter_proto.name());
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
node->set_debug_info(debug_info_ptr);
node->set_name(parameter_proto.name());
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
node->set_abstract(tensor_abstract);
std::string initial_data = parameter_proto.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
MS_EXCEPTION_IF_NULL(tensor_data_buf);
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size());
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error for build parameter, errorno " << ret;
}
node->set_default_param(tensor_info);
anfnode_build_map_[parameter_proto.name()] = node;
return true;
}
bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto) {
MS_EXCEPTION_IF_NULL(node);
if (!value_proto.has_name()) {
MS_LOG(ERROR) << "mind_ir ValueInfoProto has no name!";
return false;
}
string debug_info_name = ParseParameterName(value_proto.name());
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
node->set_debug_info(debug_info_ptr);
node->set_name(value_proto.name());
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
node->set_abstract(tensor_abstract);
anfnode_build_map_[value_proto.name()] = node;
return true;
}
bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
for (int i = 0; i < importProto.parameter_size(); ++i) {
const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
return false;
}
}
MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
for (int i = 0; i < importProto.input_size(); ++i) {
const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
if (!BuildInputForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
MS_LOG(ERROR) << "Build input for funcgraph fail at index: " << i;
return false;
}
}
return true;
}
bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
MS_EXCEPTION_IF_NULL(prim);
const int attr_tensor_type = attr_proto.tensors(0).data_type();
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
return false;
}
prim->AddAttr(attr_proto.name(), TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
return true;
}
ValuePtr MSANFModelParser::ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index) {
const int attr_type = attr_proto.type();
switch (attr_type) {
case mind_ir::AttributeProto_AttributeType_STRING: {
return ParseAttrInScalar_string_string(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_INT8: {
return ParseAttrInScalar_int8_t_int8_t(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_INT16: {
return ParseAttrInScalar_int16_t_int16_t(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_INT32: {
return ParseAttrInScalar_int32_t_int32_t(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_INT64: {
return ParseAttrInScalar_int64_t_int64_t(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_UINT8: {
return ParseAttrInScalar_uint8_t_uint8_t(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_UINT16: {
return ParseAttrInScalar_uint16_t_uint16_t(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_UINT32: {
return ParseAttrInScalar_uint32_t_uint32_t(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_UINT64: {
return ParseAttrInScalar_uint64_t_uint64_t(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_FLOAT: {
return ParseAttrInScalar_float_float(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_DOUBLE: {
return ParseAttrInScalar_double_double(attr_proto, index);
}
case mind_ir::AttributeProto_AttributeType_BOOL: {
return ParseAttrInScalar_int32_t_bool(attr_proto, index);
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
return {};
}
return {};
}
void MSANFModelParser::ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
std::unordered_map<std::string, ValuePtr> *multi_value_map) {
string name;
for (int i = 0; i < attr_proto.ints_size(); i++) {
auto res = ParseAttrInScalarForm(attr_proto, i);
name = "value" + std::to_string(i + 1);
multi_value_map->insert(std::pair<string, ValuePtr>(name, res));
}
for (int i = 0; i < attr_proto.doubles_size(); i++) {
auto res = ParseAttrInScalarForm(attr_proto, i);
name = "value" + std::to_string(i + 1);
multi_value_map->insert(std::pair<string, ValuePtr>(name, res));
}
for (int i = 0; i < attr_proto.floats_size(); i++) {
auto res = ParseAttrInScalarForm(attr_proto, i);
name = "value" + std::to_string(i + 1);
multi_value_map->insert(std::pair<string, ValuePtr>(name, res));
}
for (int i = 0; i < attr_proto.strings_size(); i++) {
auto res = ParseAttrInScalarForm(attr_proto, i);
name = "value" + std::to_string(i + 1);
multi_value_map->insert(std::pair<string, ValuePtr>(name, res));
}
}
ValuePtr MSANFModelParser::ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) {
const int attr_type = attr_proto.type();
switch (attr_type) {
case mind_ir::AttributeProto_AttributeType_STRING: {
return ParseAttrInSingleScalar_string_string(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_INT8: {
return ParseAttrInSingleScalar_int8_t_int8_t(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_INT16: {
return ParseAttrInSingleScalar_int16_t_int16_t(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_INT32: {
return ParseAttrInSingleScalar_int32_t_int32_t(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_INT64: {
return ParseAttrInSingleScalar_int64_t_int64_t(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_UINT8: {
return ParseAttrInSingleScalar_uint8_t_uint8_t(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_UINT16: {
return ParseAttrInSingleScalar_uint16_t_uint16_t(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_UINT32: {
return ParseAttrInSingleScalar_uint32_t_uint32_t(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_UINT64: {
return ParseAttrInSingleScalar_uint64_t_uint64_t(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_FLOAT: {
return ParseAttrInSingleScalar_float_float(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_DOUBLE: {
return ParseAttrInSingleScalar_double_double(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_BOOL: {
return ParseAttrInSingleScalar_int32_t_bool(attr_proto);
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
return {};
}
return {};
}
bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
const mind_ir::AttributeProto &attr_proto) {
MS_EXCEPTION_IF_NULL(prim);
const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0);
const int attr_tensor_type = attr_tensor.data_type();
ShapeVector shape;
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}
prim->AddAttr(attr_proto.name(), MakeValue(tensor_info));
return true;
}
bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
MS_EXCEPTION_IF_NULL(prim);
const std::string &attr_name = attr_proto.name();
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
string type = "";
std::size_t pos(0);
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("type:").length() - 1);
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
}
std::unordered_map<std::string, ValuePtr> multi_value_map;
switch (kParseTypeSwitchMap[type]) {
case FORM_PARSE_TYPE: {
ObtainCNodeAttrInTypeForm(prim, attr_proto);
break;
}
case FORM_PARSE_SCALAR: {
std::size_t value_pos(0);
if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
prim->AddAttr(attr_name, res);
break;
}
ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
break;
}
case FORM_PARSE_TENSOR: {
ObtainCNodeAttrInTensorForm(prim, attr_proto);
break;
}
default:
MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
return false;
}
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) {
auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
prim->AddAttr(attr_name, value_tuple_ptr);
} else {
auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
prim->AddAttr(attr_name, value_list_ptr);
}
}
return true;
}
bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
const mind_ir::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
ShapeVector shape;
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}
auto new_value_node = NewValueNode(MakeValue(tensor_info));
MS_EXCEPTION_IF_NULL(new_value_node);
auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
new_value_node->set_abstract(tensor_abstract);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name,
const mind_ir::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
return false;
}
auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
new_value_node->set_abstract(abs_type);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_name,
const mind_ir::AttributeProto &attr_proto) {
auto new_value_node = NewValueNode(kNone);
MS_EXCEPTION_IF_NULL(new_value_node);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
const mind_ir::AttributeProto &attr_proto) {
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
string type = "";
std::size_t pos(0);
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("type:").length() - 1);
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
} else if (ref_attr_name == "none") {
type = ref_attr_name;
}
ValueNodePtr new_value_node;
std::unordered_map<std::string, ValuePtr> multi_value_map;
switch (kParseTypeSwitchMap[type]) {
case FORM_PARSE_TYPE: {
ObtainValueNodeInTypeForm(value_node_name, attr_proto.tensors(0));
break;
}
case FORM_PARSE_SCALAR: {
std::size_t value_pos(0);
if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
new_value_node = NewValueNode(res);
new_value_node->set_abstract(res->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
break;
}
ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
break;
}
case FORM_PARSE_TENSOR: {
ObtainValueNodeInTensorForm(value_node_name, attr_proto.tensors(0));
break;
}
case FORM_PARSE_NONE: {
ObtainValueNodeInNoneForm(value_node_name, attr_proto);
break;
}
default:
MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
return false;
}
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) {
auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
new_value_node = NewValueNode(value_tuple_ptr);
new_value_node->set_abstract(value_tuple_ptr->ToAbstract());
} else {
auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
new_value_node = NewValueNode(value_list_ptr);
new_value_node->set_abstract(value_list_ptr->ToAbstract());
}
anfnode_build_map_[value_node_name] = new_value_node;
}
return true;
}
bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto) {
const std::string &value_node_name = node_proto.output(0);
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(0);
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
return false;
}
return GetAttrValueForValueNode(value_node_name, attr_proto);
}
std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode(
const mind_ir::AttributeProto &attr_proto) {
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); ++i) {
ShapeVector shape_vec;
const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i);
for (int j = 0; j < attr_tensor.dims_size(); ++j) {
shape_vec.push_back(attr_tensor.dims(j));
}
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
MS_EXCEPTION_IF_NULL(tensor_info);
auto abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(abstract);
kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract));
}
return kv;
}
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::NodeProto &node_proto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
if (!node_proto.has_op_type()) {
MS_LOG(ERROR) << "Get CNode op_type failed!";
return nullptr;
}
const std::string &node_name = node_proto.output(0);
const std::string &fullname_with_scope = node_proto.domain();
const std::string &node_type = node_proto.op_type();
std::shared_ptr<Primitive> prim;
auto op_primc_fns = OpPrimCRegister::GetInstance().GetPrimCMap();
if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
prim = op_primc_fns[node_type]();
} else {
prim = std::make_shared<Primitive>(node_type);
prim->set_instance_name(node_type);
}
MS_EXCEPTION_IF_NULL(prim);
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
string shape_ref_attr_name;
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
shape_ref_attr_name = attr_proto.ref_attr_name();
kv = GetAbstractForCNode(attr_proto);
continue;
}
if (!GetAttrValueForCNode(prim, attr_proto)) {
MS_LOG(ERROR) << "Get CNode attr failed!";
return nullptr;
}
}
std::vector<AnfNodePtr> inputs;
inputs.clear();
for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
return nullptr;
}
inputs.push_back(anfnode_build_map_[input_name]);
}
auto cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
if (0 == kv.size()) {
AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
elem.push_back(cnode_ptr->input(index)->abstract());
}
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (1 == kv.size()) {
std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
cnode_ptr->set_abstract(iter->second);
} else {
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
cnode_ptr->set_abstract(abstract);
}
string debug_info_name = ParseCNodeName(node_name);
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
cnode_ptr->set_debug_info(debug_info_ptr);
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
anfnode_build_map_[node_name] = cnode_ptr;
return cnode_ptr;
}
bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::GraphProto &importProto, const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_EXCEPTION_IF_NULL(cnode_ptr);
std::vector<AnfNodePtr> inputs;
if (importProto.output_size() > 1) {
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
AbstractBasePtrList elem;
for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
const mind_ir::ValueInfoProto &output_node = importProto.output(out_size);
const std::string &out_tuple = output_node.name();
inputs.push_back(anfnode_build_map_[out_tuple]);
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
}
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
inputs.push_back(maketuple_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
} else {
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
inputs.push_back(cnode_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
}
return true;
}
bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
CNodePtr cnode_ptr = nullptr;
for (int i = 0; i < importProto.node_size(); ++i) {
const mind_ir::NodeProto &node_proto = importProto.node(i);
const std::string &node_type = node_proto.op_type();
if (node_type == kConstantValueNode) {
if (!BuildValueNodeForFuncGraph(node_proto)) {
MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
return false;
}
continue;
}
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
if (cnode_ptr == nullptr) {
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
return false;
}
}
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
}
bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
MS_EXCEPTION_IF_NULL(debug_info_ptr);
if (importProto.has_name()) {
debug_info_ptr->set_name(importProto.name());
} else {
MS_LOG(ERROR) << "FuncGraph under converting has not name!";
}
if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
MS_LOG(ERROR) << "import parameters for graph fail!";
return false;
}
return ImportNodesForGraph(outputFuncGraph, importProto);
}
bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto) {
if (!model_proto.has_producer_name()) {
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
return false;
}
producer_name_ = model_proto.producer_name();
MS_LOG(INFO) << "producer_name :" << producer_name_;
if (!model_proto.has_model_version()) {
MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
return false;
}
model_version_ = model_proto.model_version();
MS_LOG(INFO) << "producer_version : " << model_version_;
if (!model_proto.has_ir_version()) {
MS_LOG(ERROR) << "Parse model version from pb file failed!";
return false;
}
ir_version_ = model_proto.ir_version();
MS_LOG(INFO) << "ir_version :" << ir_version_;
return true;
}
FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
MS_EXCEPTION_IF_NULL(dstGraph);
if (!MSANFParseModelConfigureInfo(model_proto)) {
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
}
const mind_ir::GraphProto &graphBuild = model_proto.graph();
if (!BuildFuncGraph(dstGraph, graphBuild)) {
MS_LOG(ERROR) << "Build funcgraph failed!";
return nullptr;
}
MS_LOG(INFO) << "Parse pb to build FuncGraph Success!";
return dstGraph;
}
} // namespace mindspore

View File

@ -14,63 +14,62 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H
#define MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H
#ifndef MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H
#define MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H
#include <string>
#include <map>
#include <unordered_map>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "ir/func_graph.h"
#include "proto/onnx.pb.h"
#include "proto/mind_ir.pb.h"
namespace mindspore {
namespace lite {
using int32 = int32_t;
using int64 = int64_t;
using uint64 = uint64_t;
class MSANFModelParser {
public:
MSANFModelParser() : producer_name_(""), model_version_(0), ir_version_(0) {}
MSANFModelParser() : producer_name_(""), model_version_(""), ir_version_("") {}
~MSANFModelParser() = default;
FuncGraphPtr Parse(const onnx::ModelProto &model_proto);
bool MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto);
FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto);
bool MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto);
std::string GetProducerName() { return producer_name_; }
int GetProducerVersion() { return model_version_; }
int GetIrVersion() { return ir_version_; }
std::string GetProducerVersion() { return model_version_; }
std::string GetIrVersion() { return ir_version_; }
private:
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto);
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &tensor_proto);
bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto);
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto);
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto,
const CNodePtr &cnode_ptr);
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_tensor);
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
void ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
std::unordered_map<std::string, ValuePtr> *multi_value_map);
ValuePtr ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index);
ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor);
bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
bool ObtainValueNodeInNoneForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
std::unordered_map<std::string, abstract::AbstractBasePtr> GetAbstractForCNode(
const onnx::AttributeProto &attr_proto);
const mind_ir::AttributeProto &attr_proto);
std::string producer_name_;
int model_version_;
int ir_version_;
std::string model_version_;
std::string ir_version_;
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
std::map<std::string, onnx::TensorProto> default_para_map_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H
#endif // MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H

View File

@ -0,0 +1,102 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "load_mindir/load_model.h"
#include <memory>
#include <algorithm>
#include <fstream>
#include "load_mindir/anf_model_parser.h"
using std::string;
using std::vector;
namespace mindspore {
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
if (file.empty()) {
MS_LOG(ERROR) << "file is nullptr";
return nullptr;
}
char real_path[PATH_MAX] = {0};
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(real_path, file.c_str(), PATH_MAX) == nullptr) {
MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
return nullptr;
}
#else
if (realpath(file.c_str(), real_path) == nullptr) {
MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
return nullptr;
}
#endif
std::ifstream ifs(real_path);
if (!ifs.good()) {
MS_LOG(ERROR) << "file: " << real_path << " is not exist";
return nullptr;
}
if (!ifs.is_open()) {
MS_LOG(ERROR) << "file: " << real_path << "open failed";
return nullptr;
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size));
if (buf == nullptr) {
MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
ifs.close();
return nullptr;
}
ifs.seekg(0, std::ios::beg);
ifs.read(buf->data(), size);
ifs.close();
return buf;
}
std::shared_ptr<FuncGraph> RunLoadMindIR(const std::string &file_name) {
auto graphBuf = ReadProtoFile(file_name);
if (graphBuf == nullptr) {
MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str();
return nullptr;
}
try {
auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size());
return graph;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
return nullptr;
}
}
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size) {
MS_EXCEPTION_IF_NULL(buf);
std::string str((const char *)buf, buf_size);
mind_ir::ModelProto model_;
if (!model_.ParseFromString(str)) {
MS_LOG(ERROR) << "Parse model from buffer fail!";
}
MSANFModelParser model_parser;
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
return dstgraph_ptr;
}
} // namespace mindspore

View File

@ -13,27 +13,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_LOAD_MODEL_H
#define MINDSPORE_CORE_LOAD_MODEL_H
#ifndef MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_CONVERTER_H
#define MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_CONVERTER_H
#include <vector>
#include <string>
#include <memory>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "proto/onnx.pb.h"
#include "proto/mind_ir.pb.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace lite {
class AnfConverter {
public:
static std::shared_ptr<FuncGraph> RunAnfConverter(const std::string &file_path);
static std::shared_ptr<FuncGraph> RunAnfConverter(const char *buf, const size_t buf_size);
private:
static void Trim(std::string *input);
static int ValidateFileStr(const std::string &modelFile, std::string fileType);
static bool ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model);
};
} // namespace lite
std::shared_ptr<FuncGraph> RunLoadMindIR(const std::string &file_name);
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size);
} // namespace mindspore
#endif
#endif // MINDSPORE_CORE_LOAD_MODEL_H

View File

@ -116,7 +116,9 @@ endif ()
file(GLOB PROTO_FILE ""
${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto
${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/proto/*.proto
${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto)
${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto
${CCSRC_DIR}/utils/mind_ir.proto)
ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
add_library(proto_mid OBJECT ${PROTO_SRCS})
set(TFLITE_FBS_FILES

View File

@ -23,9 +23,12 @@ from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
from mindspore.common import dtype as mstype
from mindspore import log as logger
from mindspore.common.api import _executor
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.anf_ir_pb2 import ModelProto as anf_model
from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
def _convert_type(types):
"""
Convert from numpy type to tensor type.
@ -203,3 +206,32 @@ def check_value_type(arg_name, arg_value, valid_types):
if not is_valid:
raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
f'bug got {type(arg_value).__name__}.')
def read_proto(file_name, proto_format="MINDIR"):
"""
Read protobuf file.
Args:
file_name (str): File name.
proto_format (str): Proto format.
Returns:
Object, proto object.
"""
if proto_format == "MINDIR":
model = mindir_model()
elif model_format == "ANF":
model = anf_model()
else:
raise ValueError("Unsupported proto format.")
try:
with open(file_name, "rb") as f:
pb_content = f.read()
model.ParseFromString(pb_content)
except BaseException as e:
logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name)
raise ValueError(e.__str__())
return model