forked from mindspore-Ecosystem/mindspore
!8402 refactor mindir loading
From: @wangnan39 Reviewed-by: @kingxian,@guoqi1024 Signed-off-by: @kingxian
This commit is contained in:
commit
6ecf200b49
|
@ -125,7 +125,8 @@ if (ENABLE_DUMP_PROTO)
|
|||
"utils/lineage.proto"
|
||||
"utils/checkpoint.proto"
|
||||
"utils/print.proto"
|
||||
"utils/node_strategy.proto"
|
||||
"utils/node_strategy.proto"
|
||||
"utils/mind_ir.proto"
|
||||
)
|
||||
ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY})
|
||||
|
||||
|
@ -156,7 +157,7 @@ endif()
|
|||
## make sub objects
|
||||
set(SUB_COMP
|
||||
transform/graph_ir
|
||||
transform/onnx
|
||||
transform/express_ir
|
||||
backend/optimizer
|
||||
backend/kernel_compiler
|
||||
backend/session
|
||||
|
@ -344,13 +345,13 @@ if (ENABLE_MINDDATA)
|
|||
endif ()
|
||||
|
||||
# build inference
|
||||
set(LOAD_ONNX_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_converter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc
|
||||
set(LOAD_MINDIR_SRC
|
||||
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc
|
||||
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc
|
||||
)
|
||||
add_library(inference SHARED
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc
|
||||
${LOAD_ONNX_SRC}
|
||||
${LOAD_MINDIR_SRC}
|
||||
)
|
||||
|
||||
set_target_properties(inference PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
|
||||
|
|
|
@ -20,11 +20,11 @@
|
|||
#include <fstream>
|
||||
|
||||
#include "include/inference.h"
|
||||
#include "utils/load_onnx/anf_converter.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "backend/session/executor_manager.h"
|
||||
#include "base/base_ref_utils.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#include "utils/context/context_extends.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
|
@ -58,46 +58,9 @@ std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &dev
|
|||
MSInferSession::MSInferSession() = default;
|
||||
MSInferSession::~MSInferSession() = default;
|
||||
|
||||
std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "file is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
std::string realPath = file;
|
||||
std::ifstream ifs(realPath);
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "file: " << realPath << " is not exist";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "file: " << realPath << "open failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size));
|
||||
if (buf == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc buf failed, file: " << realPath;
|
||||
ifs.close();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(buf->data(), size);
|
||||
ifs.close();
|
||||
|
||||
return buf;
|
||||
}
|
||||
|
||||
Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
||||
auto graphBuf = ReadFile(file_name);
|
||||
if (graphBuf == nullptr) {
|
||||
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
|
||||
return FAILED;
|
||||
}
|
||||
auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
|
||||
Py_Initialize();
|
||||
auto graph = RunLoadMindIR(file_name);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||
return FAILED;
|
||||
|
@ -213,6 +176,7 @@ Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &reques
|
|||
}
|
||||
inputs.push_back(input);
|
||||
}
|
||||
|
||||
auto ret = CheckModelInputs(model_id, inputs);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
|
||||
|
@ -250,16 +214,6 @@ Status MSInferSession::FinalizeEnv() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
||||
try {
|
||||
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
|
||||
return anf_graph;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference LoadModel failed";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void MSInferSession::RegAllOp() {
|
||||
static std::mutex init_mutex;
|
||||
static bool Initialized = false;
|
||||
|
|
|
@ -54,8 +54,6 @@ class MSInferSession : public InferSession {
|
|||
rtContext_t context_ = nullptr;
|
||||
#endif
|
||||
|
||||
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device);
|
||||
std::shared_ptr<std::vector<char>> ReadFile(const std::string &file);
|
||||
static void RegAllOp();
|
||||
string AjustTargetName(const std::string &device);
|
||||
Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# build mindspore_shared_lib
|
||||
set(LOAD_ONNX_SRC
|
||||
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_converter.cc
|
||||
${CMAKE_SOURCE_DIR}/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc
|
||||
set(LOAD_MINDIR_SRC
|
||||
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc
|
||||
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc
|
||||
)
|
||||
file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc")
|
||||
|
||||
|
@ -18,7 +18,7 @@ set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
|
|||
${API_MS_INFER_SRC}
|
||||
${API_ACL_SRC}
|
||||
${API_OPS_SRC}
|
||||
${LOAD_ONNX_SRC})
|
||||
${LOAD_MINDIR_SRC})
|
||||
|
||||
add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
|
||||
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}")
|
||||
|
|
|
@ -17,9 +17,9 @@
|
|||
#include "cxx_api/model/acl/model_converter.h"
|
||||
#include <memory>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "utils/load_onnx/anf_converter.h"
|
||||
#include "transform/graph_ir/convert.h"
|
||||
#include "transform/graph_ir/graph_runner.h"
|
||||
#include "core/load_mindir/load_model.h"
|
||||
#include "mindspore/core/utils/ms_context.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
|
||||
|
@ -79,8 +79,7 @@ bool CreateSessionAndGraphRunner() {
|
|||
|
||||
std::shared_ptr<FuncGraph> ModelConverter::ConvertMindIrToFuncGraph(const Buffer &model_data) {
|
||||
try {
|
||||
auto anf_graph =
|
||||
lite::AnfConverter::RunAnfConverter(reinterpret_cast<const char *>(model_data.Data()), model_data.DataSize());
|
||||
auto anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data.Data()), model_data.DataSize());
|
||||
return anf_graph;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Load MindIR failed.";
|
||||
|
@ -364,6 +363,7 @@ Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
|
|||
|
||||
Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) {
|
||||
RegAllOp();
|
||||
Py_Initialize();
|
||||
auto func_graph = ConvertMindIrToFuncGraph(model_data);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed.";
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <algorithm>
|
||||
#include <fstream>
|
||||
|
||||
#include "utils/load_onnx/anf_converter.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "backend/session/executor_manager.h"
|
||||
|
@ -117,9 +117,9 @@ Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::m
|
|||
return FAILED;
|
||||
}
|
||||
std::shared_ptr<FuncGraph> anf_graph;
|
||||
Py_Initialize();
|
||||
try {
|
||||
anf_graph =
|
||||
lite::AnfConverter::RunAnfConverter(static_cast<const char *>(model_data.Data()), model_data.DataSize());
|
||||
anf_graph = ConvertStreamToFuncGraph(static_cast<const char *>(model_data.Data()), model_data.DataSize());
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference LoadModel failed";
|
||||
return FAILED;
|
||||
|
@ -290,9 +290,10 @@ Status MsModel::FinalizeEnv() {
|
|||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
||||
Py_Initialize();
|
||||
MS_EXCEPTION_IF_NULL(model_buf);
|
||||
try {
|
||||
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
|
||||
auto anf_graph = ConvertStreamToFuncGraph(model_buf, size);
|
||||
return anf_graph;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what();
|
||||
|
|
|
@ -74,6 +74,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.")
|
||||
.def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""),
|
||||
py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
|
||||
.def("convert_funcgraph_to_mindir", &ExecutorPy::ConvertFuncGraphToMindIR, py::arg("graph"),
|
||||
"Convert FuncGraph to MindIR proto.")
|
||||
.def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
|
||||
py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
|
||||
.def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
|
||||
|
@ -108,6 +110,7 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
(void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend.");
|
||||
|
||||
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
|
||||
(py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load MindIR as Graph.");
|
||||
|
||||
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
||||
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
||||
|
|
|
@ -45,6 +45,7 @@
|
|||
#include "debug/draw.h"
|
||||
#include "pipeline/pynative/pynative_execute.h"
|
||||
#include "frontend/optimizer/py_pass_manager.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
#include "pybind_api/pybind_patch.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "utils/info.h"
|
||||
|
@ -103,6 +104,16 @@ void CheckArgIsTensor(const ValuePtr &arg, std::size_t idx) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
py::bytes ExecutorPy::ConvertFuncGraphToMindIR(const FuncGraphPtr &fg_ptr) {
|
||||
std::string proto_str = GetBinaryProtoString(fg_ptr);
|
||||
if (proto_str.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Graph proto is empty.";
|
||||
}
|
||||
return proto_str;
|
||||
}
|
||||
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::RunLoadMindIR(file_name); }
|
||||
|
||||
py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) {
|
||||
MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size();
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
|
|
|
@ -82,6 +82,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
|
|||
ResourcePtr GetResource(const std::string &phase);
|
||||
FuncGraphPtr GetFuncGraph(const std::string &phase);
|
||||
py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type);
|
||||
py::bytes ConvertFuncGraphToMindIR(const FuncGraphPtr &fg_ptr);
|
||||
compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase);
|
||||
bool HasCompiled(const std::string &phase) const;
|
||||
|
||||
|
@ -138,6 +139,7 @@ void ClearResAtexit();
|
|||
void ReleaseGeTsd();
|
||||
|
||||
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase);
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name);
|
||||
|
||||
// init and exec dataset sub graph
|
||||
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
file(GLOB_RECURSE _EXPORTER_IR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX)
|
||||
add_library(_mindspore_transform_express_ir_obj OBJECT ${_EXPORTER_IR_SRC_FILES})
|
|
@ -25,33 +25,40 @@
|
|||
#include "ir/param_info.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "proto/mind_ir.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
using FloatPtr = std::shared_ptr<Float>;
|
||||
using IntPtr = std::shared_ptr<Int>;
|
||||
|
||||
// anf type to onnx type map
|
||||
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_type_map = {
|
||||
{kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, {kNumberTypeInt8, onnx::TensorProto_DataType_INT8},
|
||||
{kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, {kNumberTypeInt32, onnx::TensorProto_DataType_INT32},
|
||||
{kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8},
|
||||
{kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32},
|
||||
{kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16},
|
||||
{kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE},
|
||||
{kObjectTypeString, onnx::TensorProto_DataType_STRING},
|
||||
// anf type to mindir type map
|
||||
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_type_map = {
|
||||
{kNumberTypeBool, mind_ir::TensorProto_DataType_BOOL},
|
||||
{kNumberTypeInt8, mind_ir::TensorProto_DataType_INT8},
|
||||
{kNumberTypeInt16, mind_ir::TensorProto_DataType_INT16},
|
||||
{kNumberTypeInt32, mind_ir::TensorProto_DataType_INT32},
|
||||
{kNumberTypeInt64, mind_ir::TensorProto_DataType_INT64},
|
||||
{kNumberTypeUInt8, mind_ir::TensorProto_DataType_UINT8},
|
||||
{kNumberTypeUInt16, mind_ir::TensorProto_DataType_UINT16},
|
||||
{kNumberTypeUInt32, mind_ir::TensorProto_DataType_UINT32},
|
||||
{kNumberTypeUInt64, mind_ir::TensorProto_DataType_UINT64},
|
||||
{kNumberTypeFloat16, mind_ir::TensorProto_DataType_FLOAT16},
|
||||
{kNumberTypeFloat32, mind_ir::TensorProto_DataType_FLOAT},
|
||||
{kNumberTypeFloat64, mind_ir::TensorProto_DataType_DOUBLE},
|
||||
{kObjectTypeString, mind_ir::TensorProto_DataType_STRING},
|
||||
};
|
||||
|
||||
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_int_map = {
|
||||
{8, onnx::TensorProto_DataType_INT8},
|
||||
{16, onnx::TensorProto_DataType_INT16},
|
||||
{32, onnx::TensorProto_DataType_INT32},
|
||||
{64, onnx::TensorProto_DataType_INT64},
|
||||
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_int_map = {
|
||||
{8, mind_ir::TensorProto_DataType_INT8},
|
||||
{16, mind_ir::TensorProto_DataType_INT16},
|
||||
{32, mind_ir::TensorProto_DataType_INT32},
|
||||
{64, mind_ir::TensorProto_DataType_INT64},
|
||||
};
|
||||
|
||||
static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_float_map = {
|
||||
{16, onnx::TensorProto_DataType_FLOAT16},
|
||||
{32, onnx::TensorProto_DataType_FLOAT},
|
||||
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_float_map = {
|
||||
{16, mind_ir::TensorProto_DataType_FLOAT16},
|
||||
{32, mind_ir::TensorProto_DataType_FLOAT},
|
||||
{64, mind_ir::TensorProto_DataType_FLOAT64},
|
||||
};
|
||||
|
||||
// Can build different builder according to format
|
||||
|
@ -77,34 +84,34 @@ class IrExportBuilder {
|
|||
void BuildModel(const FuncGraphPtr &func_graph);
|
||||
|
||||
private:
|
||||
void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
|
||||
void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
|
||||
void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
|
||||
void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto);
|
||||
void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto);
|
||||
std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto);
|
||||
void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
|
||||
void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
|
||||
void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
|
||||
void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
||||
void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
||||
std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
||||
|
||||
void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto);
|
||||
void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto);
|
||||
void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto);
|
||||
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
|
||||
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
|
||||
void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
|
||||
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto,
|
||||
void SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
|
||||
void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::ValueInfoProto *const value_proto);
|
||||
void SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto);
|
||||
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
|
||||
void SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
|
||||
void SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
|
||||
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string);
|
||||
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
||||
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
||||
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
||||
void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
||||
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name);
|
||||
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
|
||||
void SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
void SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
void SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
void SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
void SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string);
|
||||
void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
|
||||
void SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string);
|
||||
|
||||
onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
|
||||
onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits);
|
||||
onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits);
|
||||
mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
|
||||
mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits);
|
||||
mind_ir::TensorProto_DataType GetMindirDataBitsFloatType(int bits);
|
||||
std::string GetNodeName(const AnfNodePtr &node);
|
||||
std::string GetUniqueNodeName(const AnfNodePtr &node);
|
||||
std::string GetOpTypeName(const AnfNodePtr &node);
|
||||
|
@ -114,8 +121,8 @@ class IrExportBuilder {
|
|||
void ResetTupleIndex() { shape_index_ = 0; }
|
||||
|
||||
private:
|
||||
onnx::ModelProto model_;
|
||||
onnx::NodeProto *last_node_{nullptr};
|
||||
mind_ir::ModelProto model_;
|
||||
mind_ir::NodeProto *last_node_{nullptr};
|
||||
std::list<FuncGraphPtr> todo_;
|
||||
std::map<AnfNodePtr, size_t> node_index_map_;
|
||||
size_t node_index_{0};
|
||||
|
@ -144,13 +151,13 @@ std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
void IrExportBuilder::BuildModelInfo() {
|
||||
model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
|
||||
model_.set_ir_version("0.1.0");
|
||||
model_.set_producer_name("MindSpore");
|
||||
model_.set_model_version(1);
|
||||
model_.set_model_version("1.1.0");
|
||||
}
|
||||
|
||||
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
|
||||
onnx::GraphProto *graph_proto = model_.mutable_graph();
|
||||
mind_ir::GraphProto *graph_proto = model_.mutable_graph();
|
||||
graph_proto->set_name(func_graph->ToString());
|
||||
ResetNodeIndex();
|
||||
todo_.clear();
|
||||
|
@ -162,7 +169,7 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
|
||||
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
|
||||
// Export parameters
|
||||
// 1. parameters should be mapped to ValueInfoProto
|
||||
// 2. parameters with default value should be mapped to Initializer
|
||||
|
@ -172,33 +179,31 @@ void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::Graph
|
|||
BuildNodes(func_graph, graph_proto);
|
||||
}
|
||||
|
||||
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
|
||||
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
|
||||
for (auto &item : func_graph->parameters()) {
|
||||
auto param = item->cast<ParameterPtr>();
|
||||
if (param == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
|
||||
}
|
||||
onnx::ValueInfoProto *input_proto = graph_proto->add_input();
|
||||
std::string param_name = GetUniqueNodeName(param);
|
||||
input_proto->set_name(param_name);
|
||||
SetValueInfoProto(param, input_proto);
|
||||
if (!param->has_default()) {
|
||||
if (param->has_default()) {
|
||||
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default.";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Using ONNX initializer to set parameter's default value
|
||||
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
|
||||
initializer_proto->set_name(param_name);
|
||||
SetParamToTensorProto(param, initializer_proto);
|
||||
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
|
||||
if (tensor) {
|
||||
initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
|
||||
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
|
||||
parameter_proto->set_name(param_name);
|
||||
SetParamToTensorProto(param, parameter_proto);
|
||||
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
|
||||
if (tensor) {
|
||||
parameter_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
|
||||
}
|
||||
} else {
|
||||
mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
|
||||
input_proto->set_name(param_name);
|
||||
SetValueInfoProto(param, input_proto);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) {
|
||||
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) {
|
||||
auto iter = g_data_type_map.find(type_id);
|
||||
if (iter == g_data_type_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id;
|
||||
|
@ -206,7 +211,7 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) {
|
|||
return iter->second;
|
||||
}
|
||||
|
||||
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) {
|
||||
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) {
|
||||
auto iter = g_data_bits_int_map.find(bits);
|
||||
if (iter == g_data_bits_int_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits;
|
||||
|
@ -214,7 +219,7 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) {
|
|||
return iter->second;
|
||||
}
|
||||
|
||||
onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) {
|
||||
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) {
|
||||
auto iter = g_data_bits_float_map.find(bits);
|
||||
if (iter == g_data_bits_float_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits;
|
||||
|
@ -222,73 +227,70 @@ onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) {
|
|||
return iter->second;
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) {
|
||||
void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) {
|
||||
if (node == nullptr || value_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
|
||||
}
|
||||
MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
|
||||
SetValueInfoProto(node->Type(), node->Shape(), value_proto);
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape,
|
||||
onnx::ValueInfoProto *const value_proto) {
|
||||
onnx::TypeProto *type_proto = value_proto->mutable_type();
|
||||
const TypePtr &type = node->Type();
|
||||
const BaseShapePtr &shape = node->Shape();
|
||||
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
|
||||
auto tensor = type->cast<TensorTypePtr>();
|
||||
auto elem_type = tensor->element();
|
||||
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
|
||||
type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id()));
|
||||
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
|
||||
tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id()));
|
||||
if (dims.size() == 0) {
|
||||
MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1.";
|
||||
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
tensor_proto->add_dims(1);
|
||||
} else {
|
||||
for (const auto &dim : dims) {
|
||||
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
|
||||
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
|
||||
tensor_proto->add_dims(dim);
|
||||
}
|
||||
}
|
||||
} else if (type->isa<Tuple>()) {
|
||||
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
|
||||
type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
|
||||
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
|
||||
} else if (type->isa<Number>() || type->isa<String>()) {
|
||||
type_proto->set_denotation(type->type_name());
|
||||
value_proto->set_denotation(type->type_name());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
|
||||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
|
||||
void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
||||
}
|
||||
attr_proto->set_ref_attr_name("tensor:value0");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
||||
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
tensor_proto->set_name("value0");
|
||||
auto data = value->cast<tensor::TensorPtr>();
|
||||
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
|
||||
auto dtype = data->data_type();
|
||||
auto shape = data->shape_c();
|
||||
tensor_proto->set_data_type(GetOnnxDataType(dtype));
|
||||
tensor_proto->set_data_type(GetMindirDataType(dtype));
|
||||
for (const auto &dim : shape) {
|
||||
tensor_proto->add_dims(dim);
|
||||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
|
||||
onnx::TensorProto *const tensor_proto) {
|
||||
mind_ir::TensorProto *const tensor_proto) {
|
||||
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
|
||||
MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString();
|
||||
}
|
||||
auto tensor = type->cast<TensorTypePtr>();
|
||||
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
|
||||
tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id()));
|
||||
tensor_proto->set_data_type(GetMindirDataType(tensor->element()->type_id()));
|
||||
for (const auto &dim : dims) {
|
||||
tensor_proto->add_dims(dim);
|
||||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) {
|
||||
void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) {
|
||||
if (param == nullptr || tensor_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
|
||||
}
|
||||
|
@ -296,7 +298,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::Ten
|
|||
SetTensorProto(param->Type(), param->Shape(), tensor_proto);
|
||||
}
|
||||
|
||||
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
|
||||
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
|
||||
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
||||
bool is_only_return = true;
|
||||
for (const AnfNodePtr &node : nodes) {
|
||||
|
@ -317,12 +319,12 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt
|
|||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
|
||||
void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
||||
if (node->size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
|
||||
}
|
||||
AnfNodePtr arg = node->input(1);
|
||||
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
|
||||
mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
|
||||
std::string output_name = GetUniqueNodeName(node);
|
||||
output_proto->set_name(output_name);
|
||||
last_node_->set_output(0, output_name);
|
||||
|
@ -349,7 +351,7 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
|
||||
onnx::AttributeProto *const attr_proto, std::string *const seq_string) {
|
||||
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
|
||||
if (type->isa<Tuple>() && seq_string != nullptr) {
|
||||
*seq_string += "Tuple[";
|
||||
auto elements = type->cast<TuplePtr>()->elements();
|
||||
|
@ -361,7 +363,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
|
|||
} else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) {
|
||||
string shape_name = "shape" + std::to_string(GetTupleIndex());
|
||||
*seq_string += shape_name + ",";
|
||||
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
tensor_proto->set_name(shape_name);
|
||||
SetTensorProto(type, shape, tensor_proto);
|
||||
} else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) {
|
||||
|
@ -371,7 +373,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
|
|||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
|
||||
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
|
||||
// Get shape of cnode
|
||||
// 1. need to get shape from tuple element
|
||||
// 2. save shape in TensorProto
|
||||
|
@ -381,13 +383,13 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto
|
|||
auto shape = node->Shape();
|
||||
ResetTupleIndex();
|
||||
std::string seq_string = "shape:";
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
|
||||
attr_proto->set_ref_attr_name(seq_string);
|
||||
MS_LOG(DEBUG) << "CNode shape: " << seq_string;
|
||||
}
|
||||
|
||||
void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
|
||||
void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
||||
auto inputs_size = node->size();
|
||||
if (inputs_size < 1) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||
|
@ -403,7 +405,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g
|
|||
}
|
||||
|
||||
// Build cnode
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
mind_ir::NodeProto *node_proto = graph_proto->add_node();
|
||||
std::string output_name = GetUniqueNodeName(node);
|
||||
node_proto->add_output(output_name);
|
||||
node_proto->set_name(output_name);
|
||||
|
@ -421,7 +423,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g
|
|||
auto prim = GetValueNode<PrimitivePtr>(op);
|
||||
for (auto attr : prim->attrs()) {
|
||||
MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name(attr.first);
|
||||
SetValueToAttributeProto(attr.second, attr_proto);
|
||||
}
|
||||
|
@ -430,11 +432,11 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g
|
|||
}
|
||||
}
|
||||
|
||||
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) {
|
||||
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
|
||||
std::string node_name = GetUniqueNodeName(node);
|
||||
if (node->isa<ValueNode>()) {
|
||||
// When node input is a ValueNode, need to create a Constant Node
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
mind_ir::NodeProto *node_proto = graph_proto->add_node();
|
||||
node_proto->add_output(node_name);
|
||||
SetAttributeProto(node, node_proto);
|
||||
}
|
||||
|
@ -478,44 +480,48 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
|
|||
return node_name;
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) {
|
||||
void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
|
||||
if (node == nullptr || node_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
|
||||
}
|
||||
auto value = node->cast<ValueNodePtr>()->value();
|
||||
node_proto->set_op_type("Constant");
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name("value");
|
||||
MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
|
||||
SetValueToAttributeProto(value, attr_proto);
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
|
||||
void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
||||
}
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
||||
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
if (value->isa<Int>()) {
|
||||
attr_proto->set_ref_attr_name("type:value0");
|
||||
tensor_proto->set_name("value0");
|
||||
auto int_value = value->cast<IntPtr>();
|
||||
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
|
||||
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
|
||||
} else if (value->isa<Float>()) {
|
||||
attr_proto->set_ref_attr_name("type:value0");
|
||||
tensor_proto->set_name("value0");
|
||||
auto float_value = value->cast<FloatPtr>();
|
||||
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
|
||||
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
|
||||
} else if (value->isa<Bool>()) {
|
||||
attr_proto->set_ref_attr_name("type:value0");
|
||||
tensor_proto->set_name("value0");
|
||||
tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
|
||||
} else if (value->isa<TensorType>()) {
|
||||
attr_proto->set_ref_attr_name("type:tensor0");
|
||||
tensor_proto->set_name("tensor0");
|
||||
auto elem_type = value->cast<TensorTypePtr>()->element();
|
||||
if (elem_type->isa<Int>()) {
|
||||
auto int_value = elem_type->cast<IntPtr>();
|
||||
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
|
||||
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
|
||||
} else if (elem_type->isa<Float>()) {
|
||||
auto float_value = elem_type->cast<FloatPtr>();
|
||||
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
|
||||
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name();
|
||||
}
|
||||
|
@ -524,18 +530,18 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri
|
|||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
|
||||
void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
||||
}
|
||||
if (value->isa<StringImm>() || value->isa<Scalar>()) {
|
||||
SetScalarToAttributeProto(value, attr_proto);
|
||||
SetScalarToAttributeProto_ir(value, attr_proto);
|
||||
} else if (value->isa<Number>() || value->isa<TensorType>()) {
|
||||
SetTypeToAttributeProto(value, attr_proto);
|
||||
} else if (value->isa<ValueSequeue>() || value->isa<ValueSequeue>()) {
|
||||
} else if (value->isa<ValueSequeue>()) {
|
||||
ResetTupleIndex();
|
||||
std::string seq_string = "scalar:";
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
||||
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
|
||||
attr_proto->set_ref_attr_name(seq_string);
|
||||
MS_LOG(DEBUG) << "Attr string: " << seq_string;
|
||||
|
@ -549,74 +555,102 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr
|
|||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
|
||||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
||||
}
|
||||
void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
attr_proto->set_ref_attr_name("scalar:value0");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
SetScalarToProto(value, tensor_proto, "value0");
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto,
|
||||
const std::string &value_name) {
|
||||
if (value == nullptr || tensor_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!";
|
||||
}
|
||||
tensor_proto->set_name(value_name);
|
||||
if (value->isa<StringImm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING);
|
||||
tensor_proto->add_string_data(GetValue<std::string>(value));
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
|
||||
attr_proto->set_s(GetValue<std::string>(value));
|
||||
} else if (value->isa<BoolImm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL);
|
||||
tensor_proto->add_int32_data(GetValue<bool>(value));
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
|
||||
attr_proto->set_i(GetValue<bool>(value));
|
||||
} else if (value->isa<Int8Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8);
|
||||
tensor_proto->add_int32_data(value->cast<Int8ImmPtr>()->value());
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
|
||||
attr_proto->set_i(value->cast<Int8ImmPtr>()->value());
|
||||
} else if (value->isa<Int16Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16);
|
||||
tensor_proto->add_int32_data(value->cast<Int16ImmPtr>()->value());
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
|
||||
attr_proto->set_i(value->cast<Int16ImmPtr>()->value());
|
||||
} else if (value->isa<Int32Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32);
|
||||
tensor_proto->add_int32_data(value->cast<Int32ImmPtr>()->value());
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
|
||||
attr_proto->set_i(value->cast<Int32ImmPtr>()->value());
|
||||
} else if (value->isa<Int64Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
|
||||
tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value());
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
|
||||
attr_proto->set_i(value->cast<Int64ImmPtr>()->value());
|
||||
} else if (value->isa<UInt8Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8);
|
||||
tensor_proto->add_int32_data(value->cast<UInt8ImmPtr>()->value());
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
|
||||
attr_proto->set_i(value->cast<UInt8ImmPtr>()->value());
|
||||
} else if (value->isa<UInt16Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16);
|
||||
tensor_proto->add_int32_data(value->cast<UInt16ImmPtr>()->value());
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
|
||||
attr_proto->set_i(value->cast<UInt16ImmPtr>()->value());
|
||||
} else if (value->isa<UInt32Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32);
|
||||
tensor_proto->add_uint64_data(value->cast<UInt32ImmPtr>()->value());
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
|
||||
attr_proto->set_i(value->cast<UInt32ImmPtr>()->value());
|
||||
} else if (value->isa<UInt64Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64);
|
||||
tensor_proto->add_uint64_data(value->cast<UInt64ImmPtr>()->value());
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
|
||||
attr_proto->set_i(value->cast<UInt64ImmPtr>()->value());
|
||||
} else if (value->isa<FP32Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
|
||||
tensor_proto->add_float_data(GetValue<float>(value));
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
|
||||
attr_proto->set_f(GetValue<float>(value));
|
||||
} else if (value->isa<FP64Imm>()) {
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
|
||||
tensor_proto->add_double_data(GetValue<double>(value));
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
|
||||
attr_proto->set_d(GetValue<double>(value));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
|
||||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
|
||||
void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
if (value->isa<StringImm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
|
||||
attr_proto->add_strings(GetValue<std::string>(value));
|
||||
} else if (value->isa<BoolImm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
|
||||
attr_proto->add_ints(GetValue<bool>(value));
|
||||
} else if (value->isa<Int8Imm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
|
||||
attr_proto->add_ints(value->cast<Int8ImmPtr>()->value());
|
||||
} else if (value->isa<Int16Imm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
|
||||
attr_proto->add_ints(value->cast<Int16ImmPtr>()->value());
|
||||
} else if (value->isa<Int32Imm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
|
||||
attr_proto->add_ints(value->cast<Int32ImmPtr>()->value());
|
||||
} else if (value->isa<Int64Imm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
|
||||
attr_proto->add_ints(value->cast<Int64ImmPtr>()->value());
|
||||
} else if (value->isa<UInt8Imm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
|
||||
attr_proto->add_ints(value->cast<UInt8ImmPtr>()->value());
|
||||
} else if (value->isa<UInt16Imm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
|
||||
attr_proto->add_ints(value->cast<UInt16ImmPtr>()->value());
|
||||
} else if (value->isa<UInt32Imm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
|
||||
attr_proto->add_ints(value->cast<UInt32ImmPtr>()->value());
|
||||
} else if (value->isa<UInt64Imm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
|
||||
attr_proto->add_ints(value->cast<UInt64ImmPtr>()->value());
|
||||
} else if (value->isa<FP32Imm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
|
||||
attr_proto->add_floats(GetValue<float>(value));
|
||||
} else if (value->isa<FP64Imm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
|
||||
attr_proto->add_doubles(GetValue<double>(value));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
|
||||
}
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string) {
|
||||
string value_name = "value" + std::to_string(GetTupleIndex());
|
||||
if (seq_string != nullptr) {
|
||||
*seq_string += value_name + ",";
|
||||
}
|
||||
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
SetScalarToProto(value, tensor_proto, value_name);
|
||||
SetScalarToAttributeProto_irs(value, attr_proto);
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
|
||||
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
||||
mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string) {
|
||||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
|
||||
|
@ -625,6 +659,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
|||
*seq_string += "Tuple[";
|
||||
const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
|
||||
if (tuple_value->value().size() == 0) {
|
||||
*seq_string += "],";
|
||||
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
|
||||
return;
|
||||
}
|
||||
|
@ -640,6 +675,7 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
|||
*seq_string += "List[";
|
||||
const ValueListPtr &list_value = value->cast<ValueListPtr>();
|
||||
if (list_value->value().size() == 0) {
|
||||
*seq_string += "],";
|
||||
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
|
||||
return;
|
||||
}
|
|
@ -1,3 +0,0 @@
|
|||
file(GLOB_RECURSE _ONNX_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX)
|
||||
add_library(_mindspore_transform_onnx_obj OBJECT ${_ONNX_SRC_FILES})
|
|
@ -5,11 +5,5 @@ if (NOT ENABLE_GE)
|
|||
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_GE_SRC_FILES})
|
||||
endif ()
|
||||
|
||||
file(GLOB_RECURSE _UTILS_LITE_SRC_FILES
|
||||
./load_onnx/anf_converter.cc
|
||||
./load_onnx/anf_model_parser.cc
|
||||
)
|
||||
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_LITE_SRC_FILES})
|
||||
|
||||
set_property(SOURCE ${_UTILS_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_UTILS)
|
||||
add_library(_mindspore_utils_obj OBJECT ${_UTILS_SRC_LIST})
|
||||
|
|
|
@ -1,134 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "utils/load_onnx/anf_converter.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
#include "utils/load_onnx/anf_model_parser.h"
|
||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
const char WHITESPACE[] = "\t\n\v\f\r ";
|
||||
const int FLAG_PREFIX_LEN = 2;
|
||||
|
||||
void AnfConverter::Trim(std::string *input) {
|
||||
if (input == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (input->empty()) {
|
||||
return;
|
||||
}
|
||||
input->erase(0, input->find_first_not_of(WHITESPACE));
|
||||
input->erase(input->find_last_not_of(WHITESPACE) + 1);
|
||||
}
|
||||
|
||||
int AnfConverter::ValidateFileStr(const std::string &modelFile, std::string fileType) {
|
||||
if (modelFile.size() > fileType.size()) {
|
||||
if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
|
||||
return 0;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
bool AnfConverter::ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) {
|
||||
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
|
||||
|
||||
if (modelFile.size() > PATH_MAX) {
|
||||
MS_LOG(DEBUG) << "file path " << modelFile << " is too long.";
|
||||
return false;
|
||||
}
|
||||
char real_path[PATH_MAX + 1] = {0};
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
if (nullptr == _fullpath(real_path, modelFile.c_str(), PATH_MAX)) {
|
||||
MS_LOG(DEBUG) << modelFile << " does not exit.";
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
if (nullptr == realpath(modelFile.c_str(), real_path)) {
|
||||
MS_LOG(DEBUG) << modelFile << " does not exit.";
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
int fd = open(real_path, O_RDONLY);
|
||||
if (fd < 0) {
|
||||
MS_LOG(EXCEPTION) << "failed to open file";
|
||||
}
|
||||
google::protobuf::io::FileInputStream input(fd);
|
||||
google::protobuf::io::CodedInputStream code_input(&input);
|
||||
code_input.SetTotalBytesLimit(INT_MAX, 536870912);
|
||||
bool ret = onnx_model->ParseFromCodedStream(&code_input);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "load onnx file failed";
|
||||
return false;
|
||||
}
|
||||
(void)close(fd);
|
||||
MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const std::string &file_path) {
|
||||
std::string modelFile;
|
||||
|
||||
std::string tmp = file_path;
|
||||
Trim(&tmp);
|
||||
const std::string flagItem(tmp);
|
||||
|
||||
size_t pos = flagItem.find_first_of("=");
|
||||
if (pos == std::string::npos) {
|
||||
MS_LOG(ERROR) << "Trans data not support input format!";
|
||||
} else {
|
||||
modelFile = flagItem.substr(pos + 1);
|
||||
std::cout << "input protobuf file path is: " << modelFile << std::endl;
|
||||
}
|
||||
|
||||
if (ValidateFileStr(modelFile, ".pb") != 0) {
|
||||
MS_LOG(EXCEPTION) << "INPUT ILLEGAL: modelFile must be *.pb";
|
||||
}
|
||||
|
||||
onnx::ModelProto model_;
|
||||
ReadOnnxFromBinary(modelFile, &model_);
|
||||
MSANFModelParser model_parser;
|
||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
|
||||
return dstgraph_ptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const char *buf, const size_t buf_size) {
|
||||
Py_Initialize();
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
std::string str((const char *)buf, buf_size);
|
||||
onnx::ModelProto model_;
|
||||
if (!model_.ParseFromString(str)) {
|
||||
MS_LOG(EXCEPTION) << "Parse model from buffer fail!";
|
||||
}
|
||||
MSANFModelParser model_parser;
|
||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
|
||||
return dstgraph_ptr;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,693 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "utils/load_onnx/anf_model_parser.h"
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/shape_utils.h"
|
||||
|
||||
using std::string;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
static constexpr char kConstantValueNode[] = "Constant";
|
||||
static constexpr char kCNodeShapeAttr[] = "shape";
|
||||
static constexpr char kCNodeShape1Attr[] = "shape1";
|
||||
static constexpr char kCNodeShape2Attr[] = "shape2";
|
||||
enum ParseForm : int {
|
||||
FORM_PARSE_TYPE = 0,
|
||||
FORM_PARSE_SCALAR = 1,
|
||||
FORM_PARSE_TENSOR = 2,
|
||||
};
|
||||
|
||||
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
|
||||
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}};
|
||||
|
||||
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
|
||||
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
|
||||
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
|
||||
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
|
||||
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
|
||||
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
|
||||
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
|
||||
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
|
||||
};
|
||||
|
||||
template <typename T, typename P>
|
||||
std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<string, P> &kv) {
|
||||
std::stack<std::string> rules;
|
||||
std::stack<P> value;
|
||||
int count = 0;
|
||||
for (size_t i = 0; i < str.length(); i++) {
|
||||
if (str[i] == '[') {
|
||||
rules.push("[");
|
||||
} else if (str[i] == ']') {
|
||||
// rules
|
||||
std::vector<P> vec;
|
||||
while (rules.top() != "[") {
|
||||
rules.pop();
|
||||
vec.push_back(value.top());
|
||||
value.pop();
|
||||
}
|
||||
// pop "["
|
||||
rules.pop();
|
||||
// make tuple for names
|
||||
std::string res = "dummy";
|
||||
// make tuple for values
|
||||
reverse(vec.begin(), vec.end());
|
||||
auto vt = std::make_shared<T>(vec);
|
||||
if (rules.empty() && value.empty()) {
|
||||
return vt;
|
||||
}
|
||||
rules.push(res);
|
||||
value.push(vt);
|
||||
} else if (str[i] == ',') {
|
||||
continue;
|
||||
} else {
|
||||
count++;
|
||||
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
|
||||
auto value_name = str.substr(i - count + 1, count);
|
||||
value.push(kv.at(value_name));
|
||||
rules.push(value_name);
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
|
||||
const std::unordered_map<string, ValuePtr> &kv) {
|
||||
std::string str = attr_name;
|
||||
auto replace = [&](const string &orgStr, const string &newStr) {
|
||||
std::string::size_type pos(0);
|
||||
while ((pos = str.find(orgStr)) != std::string::npos) {
|
||||
str.replace(pos, orgStr.length(), newStr);
|
||||
}
|
||||
return str;
|
||||
};
|
||||
// remove "scalar:"
|
||||
str = replace("scalar:", "");
|
||||
// remove "Tuple"
|
||||
str = replace("Tuple", "");
|
||||
// remove "List"
|
||||
str = replace("List", "");
|
||||
auto result = ParserAttr<ValueTuple>(str, kv);
|
||||
if (!result) {
|
||||
return {};
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::shared_ptr<abstract::AbstractTuple> ParserAttrShape(
|
||||
const std::string &attr_name, const std::unordered_map<string, abstract::AbstractBasePtr> &kv) {
|
||||
std::string str = attr_name;
|
||||
auto replace = [&](const string &orgStr, const string &newStr) {
|
||||
std::string::size_type pos(0);
|
||||
while ((pos = str.find(orgStr)) != std::string::npos) {
|
||||
str.replace(pos, orgStr.length(), newStr);
|
||||
}
|
||||
return str;
|
||||
};
|
||||
// remove "scalar:"
|
||||
str = replace("shape:", "");
|
||||
// remove "Tuple"
|
||||
str = replace("Tuple", "");
|
||||
// remove "List"
|
||||
str = replace("List", "");
|
||||
|
||||
auto result = ParserAttr<abstract::AbstractTuple>(str, kv);
|
||||
if (!result) {
|
||||
return {};
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
|
||||
return MakeValue<valuetype>(value); \
|
||||
}
|
||||
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(string, string)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
|
||||
|
||||
// Builds a Parameter node from an onnx ValueInfoProto: sets its name, derives a
// tensor abstract from the proto's elem_type/shape, and — when an initializer
// with the same name was recorded in default_para_map_ — copies that raw data
// in as the parameter's default value. Registers the node in anfnode_build_map_.
// Returns false on any malformed or unsupported proto field.
bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) {
  MS_EXCEPTION_IF_NULL(node);
  if (!value_proto.has_type() || !value_proto.has_name()) {
    MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
    return false;
  }
  node->set_name(value_proto.name());
  const auto &type_proto = value_proto.type();
  if (!type_proto.has_tensor_type()) {
    // Fixed typo in the original message ("tesor_type").
    MS_LOG(ERROR) << "onnx TypeProto has no tensor_type! ";
    return false;
  }
  const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type();
  if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) {
    MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! ";
    return false;
  }
  const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape();
  ShapeVector shape;
  shape.reserve(tensor_shape.dim_size());
  for (int i = 0; i < tensor_shape.dim_size(); ++i) {
    shape.push_back(tensor_shape.dim(i).dim_value());
  }

  // Single lookup instead of find() followed by operator[].
  auto type_iter = kDefaultValueSwitchMap.find(tensor_typeproto.elem_type());
  if (type_iter == kDefaultValueSwitchMap.end()) {
    MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!";
    return false;
  }

  tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(type_iter->second, shape);
  MS_EXCEPTION_IF_NULL(tensor_info);
  auto tensor_abstract = tensor_info->ToAbstract();
  MS_EXCEPTION_IF_NULL(tensor_abstract);
  node->set_abstract(tensor_abstract);

  auto default_iter = default_para_map_.find(value_proto.name());
  if (default_iter != default_para_map_.end()) {
    // Avoid the original's by-value copies of the TensorProto and of its raw
    // byte buffer; bind const references instead.
    const std::string &initial_data = default_iter->second.raw_data();
    auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data_buf);
    auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size());
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
    }
    node->set_default_param(tensor_info);
  }
  anfnode_build_map_[value_proto.name()] = node;
  return true;
}
|
||||
|
||||
bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size();
|
||||
|
||||
for (int i = 0; i < importProto.initializer_size(); ++i) {
|
||||
const onnx::TensorProto &initializer_proto = importProto.initializer(i);
|
||||
if (!initializer_proto.has_name()) {
|
||||
MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i;
|
||||
return false;
|
||||
}
|
||||
default_para_map_[initializer_proto.name()] = initializer_proto;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "all parameters size: " << importProto.input_size();
|
||||
for (int i = 0; i < importProto.input_size(); ++i) {
|
||||
const onnx::ValueInfoProto &input_proto = importProto.input(i);
|
||||
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
|
||||
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
return true;
|
||||
}
|
||||
|
||||
// Dispatches a scalar-form attribute tensor to the parser generated by
// PARSE_ONNXATTR_IN_SCALAR_FORM for its data_type; empty ValuePtr on
// unsupported types.
ValuePtr MSANFModelParser::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
  const int attr_tensor_type = attr_tensor.data_type();
  switch (attr_tensor_type) {
    case onnx::TensorProto_DataType_STRING:
      return ParseAttrInScalar_string_string(attr_tensor);
    case onnx::TensorProto_DataType_INT32:
      return ParseAttrInScalar_int32_int32(attr_tensor);
    case onnx::TensorProto_DataType_INT64:
      return ParseAttrInScalar_int64_int64(attr_tensor);
    case onnx::TensorProto_DataType_UINT64:
      return ParseAttrInScalar_uint64_uint64(attr_tensor);
    case onnx::TensorProto_DataType_FLOAT:
      return ParseAttrInScalar_float_float(attr_tensor);
    case onnx::TensorProto_DataType_DOUBLE:
      return ParseAttrInScalar_double_double(attr_tensor);
    case onnx::TensorProto_DataType_BOOL:
      // BOOL payloads arrive in int32_data; the int32->bool parser handles it.
      return ParseAttrInScalar_int32_bool(attr_tensor);
    default:
      MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
      return {};
  }
}
|
||||
|
||||
bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
MS_LOG(ERROR) << "parse attr type don't support attr type is tensor";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Parses one CNode attribute. ref_attr_name embeds the parse form as a marker
// ("scalar:", "type:" or "tensor:"); the packed tensors are interpreted
// accordingly. Scalar entries are accumulated and, when nested, merged through
// ParserScalarAttrValue before being attached to the primitive.
bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
  MS_EXCEPTION_IF_NULL(prim);
  const std::string &attr_name = attr_proto.name();
  if (!attr_proto.has_ref_attr_name()) {
    MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
    return false;
  }
  const std::string &ref_attr_name = attr_proto.ref_attr_name();

  // Extract the form marker (without the trailing ':') from ref_attr_name.
  string type;
  std::size_t pos(0);
  if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
  } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("type:").length() - 1);
  } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
  }

  std::unordered_map<std::string, ValuePtr> kv;
  for (int idx = 0; idx < attr_proto.tensors_size(); idx++) {
    const onnx::TensorProto &attr_tensor = attr_proto.tensors(idx);
    const ParseForm form = kParseTypeSwitchMap[type];
    if (form == FORM_PARSE_TYPE) {
      ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
    } else if (form == FORM_PARSE_SCALAR) {
      kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), ObtainCNodeAttrInScalarForm(attr_tensor)));
    } else if (form == FORM_PARSE_TENSOR) {
      ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
    } else {
      MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
      return false;
    }
  }

  // Scalar form: a single entry is attached directly, nested structures are
  // reconstructed from ref_attr_name first.
  if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
    if (kv.size() == 1) {
      prim->AddAttr(attr_name, kv.begin()->second);
    } else {
      prim->AddAttr(attr_name, ParserScalarAttrValue(ref_attr_name, kv));
    }
  }
  return true;
}
|
||||
bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
ShapeVector shape;
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
|
||||
auto new_value_node = NewValueNode(MakeValue(tensor_info));
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(tensor_abstract);
|
||||
new_value_node->set_abstract(tensor_abstract);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ObtainValueNodeInScalarForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
ValuePtr value_ptr = nullptr;
|
||||
switch (attr_tensor_type) {
|
||||
case onnx::TensorProto_DataType_INT32: {
|
||||
std::vector<int64_t> add_data;
|
||||
for (int i = 0; i < attr_tensor.int32_data_size(); ++i) {
|
||||
add_data.push_back(attr_tensor.int32_data(i));
|
||||
}
|
||||
if (add_data.size() == 1) {
|
||||
value_ptr = MakeValue(add_data[0]);
|
||||
} else if (!add_data.empty()) {
|
||||
value_ptr = MakeValue<std::vector<int64_t>>(add_data);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_FLOAT: {
|
||||
std::vector<float> add_data;
|
||||
for (int i = 0; i < attr_tensor.float_data_size(); ++i) {
|
||||
add_data.push_back(attr_tensor.float_data(i));
|
||||
}
|
||||
|
||||
if (add_data.size() == 1) {
|
||||
value_ptr = MakeValue(add_data[0]);
|
||||
} else if (!add_data.empty()) {
|
||||
value_ptr = MakeValue<std::vector<float>>(add_data);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_UNDEFINED: {
|
||||
std::vector<ValuePtr> elems;
|
||||
value_ptr = std::make_shared<ValueTuple>(elems);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
auto new_value_node = NewValueNode(value_ptr);
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
new_value_node->set_abstract(value_ptr->ToAbstract());
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
|
||||
new_value_node->set_abstract(abs_type);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
|
||||
const onnx::AttributeProto &attr_proto) {
|
||||
if (!attr_proto.has_ref_attr_name()) {
|
||||
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||
string type;
|
||||
std::size_t pos(0);
|
||||
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
|
||||
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
|
||||
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
|
||||
type = ref_attr_name.substr(pos, string("type:").length() - 1);
|
||||
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
|
||||
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
|
||||
}
|
||||
std::unordered_map<std::string, ValuePtr> kv;
|
||||
for (int i = 0; i < attr_proto.tensors_size(); i++) {
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
|
||||
auto attr_name = attr_tensor.name();
|
||||
switch (kParseTypeSwitchMap[type]) {
|
||||
case FORM_PARSE_TYPE: {
|
||||
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_SCALAR: {
|
||||
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
|
||||
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
|
||||
break;
|
||||
}
|
||||
case FORM_PARSE_TENSOR: {
|
||||
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
ValueNodePtr new_value_node;
|
||||
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
|
||||
if (kv.size() == 1) {
|
||||
auto iter = kv.begin();
|
||||
new_value_node = NewValueNode(iter->second);
|
||||
new_value_node->set_abstract(iter->second->ToAbstract());
|
||||
} else {
|
||||
auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv);
|
||||
new_value_node = NewValueNode(value_ptr);
|
||||
new_value_node->set_abstract(value_ptr->ToAbstract());
|
||||
}
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parses a "Constant" NodeProto into a ValueNode keyed by its first output
// name; the node's first attribute carries the encoded value.
bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
  const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
  if (!attr_proto.has_ref_attr_name()) {
    MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
    return false;
  }
  return GetAttrValueForValueNode(node_proto.output(0), attr_proto);
}
|
||||
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode(
|
||||
const onnx::AttributeProto &attr_proto) {
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
|
||||
for (int i = 0; i < attr_proto.tensors_size(); ++i) {
|
||||
ShapeVector shape_vec;
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
|
||||
for (int j = 0; j < attr_tensor.dims_size(); ++j) {
|
||||
shape_vec.push_back(attr_tensor.dims(j));
|
||||
}
|
||||
tensor::TensorPtr tensor_info =
|
||||
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
auto abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract));
|
||||
}
|
||||
return kv;
|
||||
}
|
||||
|
||||
// Builds a CNode from a NodeProto: creates a Primitive named after op_type,
// attaches parsed attributes, resolves the inputs from anfnode_build_map_,
// and sets the node's abstract from the "shape:" attribute (or, when absent,
// from the tuple of its input abstracts). Returns nullptr on failure.
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
                                                  const onnx::NodeProto &node_proto) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  if (!node_proto.has_op_type()) {
    MS_LOG(ERROR) << "Get CNode op_type failed!";
    return nullptr;
  }
  const std::string &node_name = node_proto.output(0);
  const std::string &fullname_with_scope = node_proto.domain();
  const std::string &node_type = node_proto.op_type();

  PrimitivePtr prim = std::make_shared<Primitive>(node_type);
  MS_EXCEPTION_IF_NULL(prim);
  prim->set_instance_name(node_type);

  // Attribute pass: the "shape:" attribute yields output abstracts; all other
  // attributes land on the primitive.
  std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
  string shape_ref_attr_name;
  for (int idx = 0; idx < node_proto.attribute_size(); ++idx) {
    const onnx::AttributeProto &attr_proto = node_proto.attribute(idx);
    if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
      shape_ref_attr_name = attr_proto.ref_attr_name();
      kv = GetAbstractForCNode(attr_proto);
      continue;
    }
    if (!GetAttrValueForCNode(prim, attr_proto)) {
      MS_LOG(ERROR) << "Get CNode attr failed!";
      return nullptr;
    }
  }

  // Input pass: every referenced input must already be parsed.
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim));
  for (int idx = 0; idx < node_proto.input_size(); ++idx) {
    const std::string &input_name = node_proto.input(idx);
    auto found = anfnode_build_map_.find(input_name);
    if (found == anfnode_build_map_.end()) {
      MS_LOG(ERROR) << node_name << " input " << idx << input_name << "can't find in nodes have parsed";
      return nullptr;
    }
    inputs.push_back(found->second);
  }

  CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  if (kv.empty()) {
    // No shape attribute: fall back to a tuple of the input abstracts.
    AbstractBasePtrList elem;
    for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
      elem.push_back(cnode_ptr->input(index)->abstract());
    }
    cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
  } else if (kv.size() == 1) {
    cnode_ptr->set_abstract(kv.begin()->second);
  } else {
    cnode_ptr->set_abstract(ParserAttrShape(shape_ref_attr_name, kv));
  }

  cnode_ptr->set_fullname_with_scope(fullname_with_scope);
  anfnode_build_map_[node_name] = cnode_ptr;
  return cnode_ptr;
}
|
||||
|
||||
// Builds the graph's Return node. Multiple graph outputs are first packed into
// a MakeTuple whose abstract is the tuple of the output abstracts; a single
// output returns cnode_ptr (the last CNode built) directly.
bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
                                               const CNodePtr &cnode_ptr) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  std::vector<AnfNodePtr> inputs;
  if (importProto.output_size() > 1) {
    inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
    AbstractBasePtrList elem;
    for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
      const onnx::ValueInfoProto &output_node = importProto.output(out_size);
      const std::string &out_tuple = output_node.name();
      inputs.push_back(anfnode_build_map_[out_tuple]);
      elem.push_back(anfnode_build_map_[out_tuple]->abstract());
    }
    auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
    // Bug fix: the original dereferenced maketuple_ptr without a null check.
    MS_EXCEPTION_IF_NULL(maketuple_ptr);
    maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
    inputs.clear();
    inputs.push_back(NewValueNode(prim::kPrimReturn));
    inputs.push_back(maketuple_ptr);
    auto return_node = outputFuncGraph->NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(return_node);
    outputFuncGraph->set_return(return_node);
    MS_LOG(INFO) << "Construct funcgraph finined, all success.";
  } else {
    // Dead code removed: the original also computed an output_shape vector
    // from importProto.output(0)'s tensor type that was never used.
    inputs.push_back(NewValueNode(prim::kPrimReturn));
    inputs.push_back(cnode_ptr);
    auto return_node = outputFuncGraph->NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(return_node);
    outputFuncGraph->set_return(return_node);
    MS_LOG(INFO) << "Construct funcgraph finined, all success!";
  }
  return true;
}
|
||||
|
||||
// Builds every node of the graph: "Constant" nodes become ValueNodes, all
// others become CNodes, then the Return node is attached. The last built
// CNode is treated as the single-output graph result.
bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
  CNodePtr cnode_ptr = nullptr;
  for (int i = 0; i < importProto.node_size(); ++i) {
    const onnx::NodeProto &node_proto = importProto.node(i);
    const std::string &node_type = node_proto.op_type();
    if (node_type == kConstantValueNode) {
      if (!BuildValueNodeForFuncGraph(node_proto)) {
        MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
        return false;
      }
      continue;
    }
    cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
    if (cnode_ptr == nullptr) {
      MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
      return false;
    }
  }

  // Bug fix: the original discarded BuildReturnForFuncGraph's result and
  // unconditionally returned true; propagate the actual outcome.
  return BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
}
|
||||
|
||||
// Populates outputFuncGraph from a GraphProto: debug name, then parameters,
// then nodes (including the return node).
bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  auto debug_info = outputFuncGraph->debug_info();
  MS_EXCEPTION_IF_NULL(debug_info);
  if (importProto.has_name()) {
    debug_info->set_name(importProto.name());
  } else {
    // A missing name is logged but does not abort the conversion.
    MS_LOG(ERROR) << "FuncGraph under converting has not name!";
  }
  if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
    return false;
  }
  return ImportNodesForGraph(outputFuncGraph, importProto);
}
|
||||
|
||||
// Reads the producer name, model version and ir version out of the ModelProto
// into the parser's members; each field is mandatory.
bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
  // Producer name.
  if (!model_proto.has_producer_name()) {
    MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
    return false;
  }
  producer_name_ = model_proto.producer_name();
  MS_LOG(INFO) << "producer_name :" << producer_name_;

  // Model (producer) version.
  if (!model_proto.has_model_version()) {
    MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
    return false;
  }
  model_version_ = model_proto.model_version();
  MS_LOG(INFO) << "producer_version : " << model_version_;

  // IR version.
  if (!model_proto.has_ir_version()) {
    MS_LOG(ERROR) << "Parse model version from pb file failed!";
    return false;
  }
  ir_version_ = model_proto.ir_version();
  MS_LOG(INFO) << "ir_version :" << ir_version_;
  return true;
}
|
||||
|
||||
// Entry point: converts a ModelProto into a FuncGraph. Configuration-info
// failures are logged but non-fatal; graph construction failures yield nullptr.
FuncGraphPtr MSANFModelParser::Parse(const onnx::ModelProto &model_proto) {
  FuncGraphPtr dst_graph = std::make_shared<FuncGraph>();
  MS_EXCEPTION_IF_NULL(dst_graph);
  if (!MSANFParseModelConfigureInfo(model_proto)) {
    MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
  }
  if (!BuildFuncGraph(dst_graph, model_proto.graph())) {
    MS_LOG(ERROR) << "Build funcgraph failed!";
    return nullptr;
  }
  MS_LOG(INFO) << "Parse pb to build FuncGraph Success!";
  return dst_graph;
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,119 @@
|
|||
// MindIR serialization schema (proto2), structurally modelled after ONNX.
syntax = "proto2";

package mind_ir;

// A named attribute attached to a node; exactly one of the value fields is
// meaningful, selected by `type`.
message AttributeProto {
  enum AttributeType {
    UNDEFINED = 0;
    FLOAT = 1;
    UINT8 = 2;
    INT8 = 3;
    UINT16 = 4;
    INT16 = 5;
    INT32 = 6;
    INT64 = 7;
    STRING = 8;
    BOOL = 9;
    FLOAT16 = 10;
    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;
    COMPLEX128 = 15;
    BFLOAT16 = 16;
    TENSOR = 17;
    GRAPH = 18;
    TENSORS = 19;
  }

  optional string name = 1;
  // Scalar payloads.
  optional float f = 2;
  optional int64 i = 3;
  optional double d = 4;
  optional bytes s = 5;
  // Structured payloads.
  optional TensorProto t = 6;
  optional GraphProto g = 7;
  // Repeated payloads.
  repeated float floats = 8;
  repeated double doubles = 9;
  repeated int64 ints = 10;
  repeated bytes strings = 11;
  repeated TensorProto tensors = 12;
  repeated GraphProto graphs = 13;
  optional string doc_string = 14;
  optional string ref_attr_name = 15;
  optional AttributeType type = 16;
}

// A named value (graph input/output), described by zero or more tensors.
message ValueInfoProto {
  optional string name = 1;
  repeated TensorProto tensor = 2;
  optional string doc_string = 3;
  optional string denotation = 4;
}

// One operator invocation: named inputs/outputs plus attributes.
message NodeProto {
  repeated string input = 1;
  repeated string output = 2;
  optional string name = 3;
  optional string op_type = 4;
  repeated AttributeProto attribute = 5;
  optional string doc_string = 6;
  optional string domain = 7;
}

// Top-level container: versioning metadata and the main graph.
message ModelProto {
  optional string ir_version = 1;
  optional string producer_name = 2;
  optional string producer_version = 3;
  optional string domain = 4;
  optional string model_version = 5;
  optional string doc_string = 6;
  optional GraphProto graph = 7;
}

// A computation graph: nodes, parameters (weights) and its interface.
message GraphProto {
  repeated NodeProto node = 1;
  optional string name = 2;
  repeated TensorProto parameter = 3;
  optional string doc_string = 4;
  repeated ValueInfoProto input = 5;
  repeated ValueInfoProto output = 6;
}

// A dense tensor; data is stored in the typed repeated field matching
// data_type, or opaquely in raw_data.
message TensorProto {
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;   // float
    UINT8 = 2;   // uint8_t
    INT8 = 3;    // int8_t
    UINT16 = 4;  // uint16_t
    INT16 = 5;   // int16_t
    INT32 = 6;   // int32_t
    INT64 = 7;   // int64_t
    STRING = 8;  // string
    BOOL = 9;    // bool
    FLOAT16 = 10;
    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;
    COMPLEX128 = 15;
    BFLOAT16 = 16;
    FLOAT64 = 17;
  }

  repeated int64 dims = 1;
  optional int32 data_type = 2;
  repeated float float_data = 3;
  repeated int32 int32_data = 4;
  repeated bytes string_data = 5;
  repeated int64 int64_data = 6;
  optional string name = 7;
  optional string doc_string = 8;
  optional bytes raw_data = 9;
  repeated double double_data = 10;
  repeated uint64 uint64_data = 11;
}
|
|
@ -15,6 +15,7 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"c_ops/*.cc"
|
||||
"ir/*.cc"
|
||||
"utils/*.cc"
|
||||
"load_mindir/*.cc"
|
||||
)
|
||||
set_property(SOURCE ${CORE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_CORE)
|
||||
add_library(mindspore_core STATIC ${CORE_SRC_LIST})
|
||||
|
|
|
@ -50,4 +50,5 @@ AbstractBasePtr TensorAddInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorAdd, prim::kPrimTensorAdd, TensorAddInfer);
|
||||
REGISTER_PRIMITIVE_C(TensorAdd);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -102,4 +102,5 @@ AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer);
|
||||
REGISTER_PRIMITIVE_C(AvgPool);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -193,4 +193,5 @@ AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
Conv2dInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
|
||||
REGISTER_PRIMITIVE_C(Conv2D);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,7 +33,7 @@ constexpr auto kPad = "pad";
|
|||
constexpr auto kPads = "pads";
|
||||
constexpr auto kMode = "mode";
|
||||
constexpr auto kGroup = "group";
|
||||
constexpr auto kOutputChannel = "output_channel";
|
||||
constexpr auto kOutputChannel = "out_channel";
|
||||
constexpr auto kPadList = "pad_list";
|
||||
constexpr auto kAxis = "axis";
|
||||
|
||||
|
|
|
@ -31,4 +31,13 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) {
|
|||
auto infer_function = iter->second.impl_;
|
||||
return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list);
|
||||
}
|
||||
|
||||
OpPrimCRegister &OpPrimCRegister::GetInstance() {
|
||||
static OpPrimCRegister instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
std::map<std::string, OpPrimCDefineFunc> OpPrimCRegister::GetPrimCMap() { return op_primc_fns_; }
|
||||
void OpPrimCRegister::SetPrimCMap(const std::string &name, const OpPrimCDefineFunc &fn) { op_primc_fns_[name] = fn; }
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "ir/primitive.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "ir/value.h"
|
||||
|
@ -32,5 +34,33 @@ class PrimitiveC : public Primitive {
|
|||
protected:
|
||||
void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name);
|
||||
};
|
||||
|
||||
using OpPrimCDefineFunc = std::function<std::shared_ptr<PrimitiveC>()>;
|
||||
class OpPrimCRegister {
|
||||
public:
|
||||
~OpPrimCRegister() {}
|
||||
static OpPrimCRegister &GetInstance();
|
||||
std::map<std::string, OpPrimCDefineFunc> GetPrimCMap();
|
||||
void SetPrimCMap(const std::string &name, const OpPrimCDefineFunc &fn);
|
||||
|
||||
private:
|
||||
OpPrimCRegister() {}
|
||||
std::map<std::string, OpPrimCDefineFunc> op_primc_fns_;
|
||||
};
|
||||
|
||||
class OpPrimCRegisterHelper {
|
||||
public:
|
||||
OpPrimCRegisterHelper(const std::string &name, const OpPrimCDefineFunc &fn) {
|
||||
OpPrimCRegister::GetInstance().SetPrimCMap(name, fn);
|
||||
}
|
||||
~OpPrimCRegisterHelper() = default;
|
||||
};
|
||||
|
||||
#define REGISTER_PRIMITIVE_C(name) \
|
||||
std::shared_ptr<PrimitiveC> GetDefaultPrimC##name() { \
|
||||
auto out = std::make_shared<name>(); \
|
||||
return out; \
|
||||
} \
|
||||
OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name);
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
|
||||
|
|
|
@ -49,4 +49,5 @@ AbstractBasePtr Relu6Infer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
Relu6InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Relu6, prim::kPrimRelu6, Relu6Infer);
|
||||
REGISTER_PRIMITIVE_C(Relu6);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -41,4 +41,5 @@ AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer);
|
||||
REGISTER_PRIMITIVE_C(Reshape);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,7 +36,7 @@ class Reshape : public PrimitiveC {
|
|||
|
||||
AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimTensorAddPtr = std::shared_ptr<Reshape>;
|
||||
using PrimReshapePtr = std::shared_ptr<Reshape>;
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_RESHAPE_H_
|
||||
|
|
|
@ -75,4 +75,5 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer);
|
||||
REGISTER_PRIMITIVE_C(Softmax);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -76,4 +76,5 @@ AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Squeeze, prim::kPrimSqueeze, SqueezeInfer);
|
||||
REGISTER_PRIMITIVE_C(Squeeze);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,854 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "load_mindir/anf_model_parser.h"
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/shape_utils.h"
|
||||
|
||||
using std::string;
|
||||
|
||||
namespace mindspore {
|
||||
// op_type marking a NodeProto that carries a constant (becomes a ValueNode).
static constexpr char kConstantValueNode[] = "Constant";
// Attribute keys used for shape bookkeeping on CNodes.
static constexpr char kCNodeShapeAttr[] = "shape";
static constexpr char kCNodeShape1Attr[] = "shape1";
static constexpr char kCNodeShape2Attr[] = "shape2";
// Decode category derived from an attribute's ref_attr_name prefix.
enum ParseForm : int {
  FORM_PARSE_TYPE = 0,      // "type:"   -> attribute stores a TypeId
  FORM_PARSE_SCALAR = 1,    // "scalar:" -> attribute stores scalar value(s)
  FORM_PARSE_TENSOR = 2,    // "tensor:" -> attribute stores a tensor payload
  FORM_PARSE_NONE = 3,      // "none"    -> attribute carries no payload
  FORM_PARSE_UNDEFINE = 4,  // unrecognized prefix
};

// Maps the textual prefix (without the trailing ':') to its ParseForm.
static std::map<std::string, ParseForm> kParseTypeSwitchMap{{"type", FORM_PARSE_TYPE},
                                                            {"scalar", FORM_PARSE_SCALAR},
                                                            {"tensor", FORM_PARSE_TENSOR},
                                                            {"none", FORM_PARSE_NONE},
                                                            {"", FORM_PARSE_UNDEFINE}};

// Maps mind_ir proto dtypes to MindSpore TypeIds. Callers that look up a key
// before indexing reject any dtype not listed here.
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
  {mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool},
  {mind_ir::TensorProto_DataType_INT8, kNumberTypeInt8},
  {mind_ir::TensorProto_DataType_INT16, kNumberTypeInt16},
  {mind_ir::TensorProto_DataType_INT32, kNumberTypeInt32},
  {mind_ir::TensorProto_DataType_INT64, kNumberTypeInt64},
  {mind_ir::TensorProto_DataType_UINT8, kNumberTypeUInt8},
  {mind_ir::TensorProto_DataType_UINT16, kNumberTypeUInt16},
  {mind_ir::TensorProto_DataType_UINT32, kNumberTypeUInt32},
  {mind_ir::TensorProto_DataType_UINT64, kNumberTypeUInt64},
  {mind_ir::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
  {mind_ir::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
  {mind_ir::TensorProto_DataType_FLOAT64, kNumberTypeFloat64},
  {mind_ir::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
  {mind_ir::TensorProto_DataType_STRING, kObjectTypeString},
};
|
||||
|
||||
// Stack-based parser for bracketed attribute strings such as
// "[value1,value2,[value3,value4]]". Each leaf token is resolved through `kv`
// (kv.at() throws std::out_of_range for an unknown token) and every bracket
// pair becomes one T container (T is e.g. ValueTuple or AbstractTuple; P is
// its element pointer type). Returns the outermost container, or an empty
// shared_ptr when no top-level bracket closes.
template <typename T, typename P>
std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<string, P> &kv) {
  std::stack<std::string> rules;  // token/bracket markers, mirrors `value`
  std::stack<P> value;            // parsed values pending aggregation
  int count = 0;                  // length of the leaf token being scanned
  for (size_t i = 0; i < str.length(); i++) {
    if (str[i] == '[') {
      rules.push("[");
    } else if (str[i] == ']') {
      // rules: pop values back down to the matching "["
      std::vector<P> vec;
      while (rules.top() != "[") {
        rules.pop();
        vec.push_back(value.top());
        value.pop();
      }
      // pop "["
      rules.pop();
      // make tuple for names
      std::string res = "dummy";
      // make tuple for values (stack popping reversed them; restore order)
      reverse(vec.begin(), vec.end());
      auto vt = std::make_shared<T>(vec);
      if (rules.empty() && value.empty()) {
        return vt;
      }
      rules.push(res);
      value.push(vt);
    } else if (str[i] == ',') {
      continue;
    } else {
      count++;
      // A token ends just before '[', ']' or ','. NOTE(review): str[i + 1] at
      // the final character reads index size(), which operator[] defines as
      // the null terminator in C++11 — presumably intentional; confirm.
      if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
        auto value_name = str.substr(i - count + 1, count);
        value.push(kv.at(value_name));
        rules.push(value_name);
        count = 0;
      }
    }
  }
  return {};
}
|
||||
|
||||
// Parses a "scalar:Tuple[...]"/"scalar:List[...]"-style attr_name into a
// T (ValueTuple/ValueList), resolving each "valueN" token through `kv`.
template <typename T>
std::shared_ptr<T> ParserScalarAttrValue(const std::string &attr_name, const std::unordered_map<string, ValuePtr> &kv) {
  std::string str = attr_name;
  // Removes every occurrence of orgStr from `str` (newStr is always "" here).
  auto replace = [&](const string &orgStr, const string &newStr) {
    std::string::size_type pos(0);
    while ((pos = str.find(orgStr)) != std::string::npos) {
      str.replace(pos, orgStr.length(), newStr);
    }
    return str;
  };
  // remove "scalar:"
  str = replace("scalar:", "");
  // remove "Tuple"
  str = replace("Tuple", "");
  // remove "List"
  str = replace("List", "");
  auto result = ParserAttr<T>(str, kv);
  return result;
}
|
||||
|
||||
// Parses a "shape:Tuple[...]"/"shape:List[...]"-style attr_name into an
// AbstractTuple, resolving each tensor-name token through `kv`.
std::shared_ptr<abstract::AbstractTuple> ParserAttrShape(
  const std::string &attr_name, const std::unordered_map<string, abstract::AbstractBasePtr> &kv) {
  std::string str = attr_name;
  // Removes every occurrence of orgStr from `str` (newStr is always "" here).
  auto replace = [&](const string &orgStr, const string &newStr) {
    std::string::size_type pos(0);
    while ((pos = str.find(orgStr)) != std::string::npos) {
      str.replace(pos, orgStr.length(), newStr);
    }
    return str;
  };
  // remove "shape:"
  str = replace("shape:", "");
  // remove "Tuple"
  str = replace("Tuple", "");
  // remove "List"
  str = replace("List", "");

  auto result = ParserAttr<abstract::AbstractTuple>(str, kv);
  return result;
}
|
||||
|
||||
// Strips a "scope:" prefix from a parameter name: everything up to and
// including the FIRST ':' is removed; a name without ':' is returned as-is.
std::string ParseParameterName(const std::string &name) {
  const auto colon = name.find(':');
  if (colon == std::string::npos) {
    return name;
  }
  return name.substr(colon + 1);
}
|
||||
|
||||
// Extracts the middle segment of a "prefix:op:suffix"-style CNode name: the
// text strictly between the first and last ':'. Names with fewer than two
// distinct ':' positions are returned unchanged.
std::string ParseCNodeName(const std::string &name) {
  const auto first = name.find(':');
  const auto last = name.find_last_of(':');
  if (first == std::string::npos || last == std::string::npos || first == last) {
    return name;
  }
  return name.substr(first + 1, last - first - 1);
}
|
||||
|
||||
// Generates two converters per integer-backed scalar type:
//  - ParseAttrInScalar_<type>_<valuetype>(proto, index): reads ints(index)
//  - ParseAttrInSingleScalar_<type>_<valuetype>(proto):  reads the single i()
#define PARSE_MINDIR_ATTR_IN_INT_FORM(type, valuetype)                                                    \
  ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
    auto value = static_cast<valuetype>(attr_proto.ints(index));                                          \
    return MakeValue<valuetype>(value);                                                                   \
  }                                                                                                       \
  ValuePtr ParseAttrInSingleScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto) {      \
    auto value = static_cast<valuetype>(attr_proto.i());                                                  \
    return MakeValue<valuetype>(value);                                                                   \
  }

// Generates an indexed converter for types with their own repeated proto
// field (doubles/floats/strings): reads attr_proto.<type>s(index).
#define PARSE_MINDIR_ATTR_IN_SCALAR_FORM(type, valuetype)                                                 \
  ValuePtr ParseAttrInScalar_##type##_##valuetype(const mind_ir::AttributeProto &attr_proto, int index) { \
    auto value = static_cast<valuetype>(attr_proto.type##s(index));                                       \
    return MakeValue<valuetype>(value);                                                                   \
  }

// All integer widths share the proto's repeated int64 field `ints`.
PARSE_MINDIR_ATTR_IN_INT_FORM(int8_t, int8_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(int16_t, int16_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, int32_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(int64_t, int64_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(uint8_t, uint8_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(uint16_t, uint16_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(uint32_t, uint32_t)
PARSE_MINDIR_ATTR_IN_INT_FORM(uint64_t, uint64_t)
// bool is serialized through the int field as well.
PARSE_MINDIR_ATTR_IN_INT_FORM(int32_t, bool)

PARSE_MINDIR_ATTR_IN_SCALAR_FORM(double, double)
PARSE_MINDIR_ATTR_IN_SCALAR_FORM(float, float)
PARSE_MINDIR_ATTR_IN_SCALAR_FORM(string, string)
|
||||
|
||||
// Single-string variant: reads the scalar `s` field (not covered by the
// macros above, which use `i`).
ValuePtr ParseAttrInSingleScalar_string_string(const mind_ir::AttributeProto &attr_proto) {
  return MakeValue<string>(static_cast<string>(attr_proto.s()));
}
|
||||
|
||||
// Single-float variant: reads the scalar `f` field.
ValuePtr ParseAttrInSingleScalar_float_float(const mind_ir::AttributeProto &attr_proto) {
  return MakeValue<float>(static_cast<float>(attr_proto.f()));
}
|
||||
|
||||
// Single-double variant: reads the scalar `d` field.
ValuePtr ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto &attr_proto) {
  return MakeValue<double>(static_cast<double>(attr_proto.d()));
}
|
||||
|
||||
// Creates an uninitialized Tensor whose dtype and shape mirror the given
// TensorProto. Returns nullptr when the proto lacks a data_type or uses one
// not present in kDefaultValueSwitchMap.
tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto) {
  if (!tensor_proto.has_data_type()) {
    MS_LOG(ERROR) << "mind_ir TensorProto has no data_type or name!";
    return nullptr;
  }
  if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
    MS_LOG(ERROR) << "mind_ir TensorProto data_type is not support yet!";
    return nullptr;
  }

  ShapeVector shape;
  for (int i = 0; i < tensor_proto.dims_size(); ++i) {
    shape.push_back(tensor_proto.dims(i));
  }

  auto tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
  MS_EXCEPTION_IF_NULL(tensor_info);
  return tensor_info;
}
|
||||
|
||||
bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
|
||||
const mind_ir::TensorProto ¶meter_proto) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
if (!parameter_proto.has_name()) {
|
||||
MS_LOG(ERROR) << "mind_ir TensorProto has no name!";
|
||||
return false;
|
||||
}
|
||||
string debug_info_name = ParseParameterName(parameter_proto.name());
|
||||
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
|
||||
node->set_debug_info(debug_info_ptr);
|
||||
node->set_name(parameter_proto.name());
|
||||
|
||||
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(tensor_abstract);
|
||||
node->set_abstract(tensor_abstract);
|
||||
|
||||
std::string initial_data = parameter_proto.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
MS_EXCEPTION_IF_NULL(tensor_data_buf);
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size());
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error for build parameter, errorno " << ret;
|
||||
}
|
||||
|
||||
node->set_default_param(tensor_info);
|
||||
|
||||
anfnode_build_map_[parameter_proto.name()] = node;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Fills `node` (a freshly added graph parameter) from a ValueInfoProto that
// describes a graph input: name, debug info, and abstract built from the
// first attached TensorProto. Records the node in anfnode_build_map_.
// Returns false on malformed input.
bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto) {
  MS_EXCEPTION_IF_NULL(node);

  if (!value_proto.has_name()) {
    MS_LOG(ERROR) << "mind_ir ValueInfoProto has no name!";
    return false;
  }
  string debug_info_name = ParseParameterName(value_proto.name());
  auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
  node->set_debug_info(debug_info_ptr);
  node->set_name(value_proto.name());

  // Fix: tensor(0) on an empty repeated field is undefined behavior in
  // protobuf; reject instead.
  if (value_proto.tensor_size() < 1) {
    MS_LOG(ERROR) << "mind_ir ValueInfoProto has no tensor!";
    return false;
  }
  const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);

  tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
  // Fix: nullptr signals a missing/unsupported data_type; the original
  // dereferenced it unconditionally.
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Build tensor info failed for input: " << value_proto.name();
    return false;
  }
  auto tensor_abstract = tensor_info->ToAbstract();
  MS_EXCEPTION_IF_NULL(tensor_abstract);
  node->set_abstract(tensor_abstract);

  anfnode_build_map_[value_proto.name()] = node;
  return true;
}
|
||||
|
||||
bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const mind_ir::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
|
||||
for (int i = 0; i < importProto.parameter_size(); ++i) {
|
||||
const mind_ir::TensorProto ¶meter_proto = importProto.parameter(i);
|
||||
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
|
||||
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
|
||||
for (int i = 0; i < importProto.input_size(); ++i) {
|
||||
const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
|
||||
if (!BuildInputForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
|
||||
MS_LOG(ERROR) << "Build input for funcgraph fail at index: " << i;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Decodes a "type:" attribute: the dtype of the first attached tensor names a
// MindSpore Type, which is attached to `prim` under the attribute's name.
bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
  MS_EXCEPTION_IF_NULL(prim);
  const int attr_tensor_type = attr_proto.tensors(0).data_type();
  auto it = kDefaultValueSwitchMap.find(attr_tensor_type);
  if (it == kDefaultValueSwitchMap.end()) {
    MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
    return false;
  }
  prim->AddAttr(attr_proto.name(), TypeIdToType(it->second));
  return true;
}
|
||||
|
||||
// Converts the index-th repeated scalar of attr_proto into a ValuePtr,
// dispatching on the proto's declared attribute type. Returns an empty
// ValuePtr for unsupported types.
ValuePtr MSANFModelParser::ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index) {
  const int attr_type = attr_proto.type();
  switch (attr_type) {
    case mind_ir::AttributeProto_AttributeType_STRING: {
      return ParseAttrInScalar_string_string(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_INT8: {
      return ParseAttrInScalar_int8_t_int8_t(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_INT16: {
      return ParseAttrInScalar_int16_t_int16_t(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_INT32: {
      return ParseAttrInScalar_int32_t_int32_t(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_INT64: {
      return ParseAttrInScalar_int64_t_int64_t(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_UINT8: {
      return ParseAttrInScalar_uint8_t_uint8_t(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_UINT16: {
      return ParseAttrInScalar_uint16_t_uint16_t(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_UINT32: {
      return ParseAttrInScalar_uint32_t_uint32_t(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_UINT64: {
      return ParseAttrInScalar_uint64_t_uint64_t(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_FLOAT: {
      return ParseAttrInScalar_float_float(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_DOUBLE: {
      return ParseAttrInScalar_double_double(attr_proto, index);
    }
    case mind_ir::AttributeProto_AttributeType_BOOL: {
      // bool travels through the repeated int field (see macro instantiation).
      return ParseAttrInScalar_int32_t_bool(attr_proto, index);
    }
    default:
      MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
      return {};
  }
  return {};
}
|
||||
|
||||
// Expands every repeated scalar field of attr_proto into *multi_value_map,
// keyed "value1", "value2", ... in field order.
// NOTE(review): all four loops generate the same key sequence, so if more
// than one repeated field were populated on one proto, later loops would
// overwrite earlier entries — presumably the exporter only ever fills one
// field per attribute; confirm against the MindIR exporter.
void MSANFModelParser::ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
                                                   std::unordered_map<std::string, ValuePtr> *multi_value_map) {
  string name;
  for (int i = 0; i < attr_proto.ints_size(); i++) {
    auto res = ParseAttrInScalarForm(attr_proto, i);
    name = "value" + std::to_string(i + 1);
    multi_value_map->insert(std::pair<string, ValuePtr>(name, res));
  }
  for (int i = 0; i < attr_proto.doubles_size(); i++) {
    auto res = ParseAttrInScalarForm(attr_proto, i);
    name = "value" + std::to_string(i + 1);
    multi_value_map->insert(std::pair<string, ValuePtr>(name, res));
  }
  for (int i = 0; i < attr_proto.floats_size(); i++) {
    auto res = ParseAttrInScalarForm(attr_proto, i);
    name = "value" + std::to_string(i + 1);
    multi_value_map->insert(std::pair<string, ValuePtr>(name, res));
  }
  for (int i = 0; i < attr_proto.strings_size(); i++) {
    auto res = ParseAttrInScalarForm(attr_proto, i);
    name = "value" + std::to_string(i + 1);
    multi_value_map->insert(std::pair<string, ValuePtr>(name, res));
  }
}
|
||||
|
||||
// Converts the single (non-repeated) scalar payload of attr_proto into a
// ValuePtr, dispatching on the proto's declared attribute type. Returns an
// empty ValuePtr for unsupported types.
ValuePtr MSANFModelParser::ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) {
  const int attr_type = attr_proto.type();
  switch (attr_type) {
    case mind_ir::AttributeProto_AttributeType_STRING: {
      return ParseAttrInSingleScalar_string_string(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_INT8: {
      return ParseAttrInSingleScalar_int8_t_int8_t(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_INT16: {
      return ParseAttrInSingleScalar_int16_t_int16_t(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_INT32: {
      return ParseAttrInSingleScalar_int32_t_int32_t(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_INT64: {
      return ParseAttrInSingleScalar_int64_t_int64_t(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_UINT8: {
      return ParseAttrInSingleScalar_uint8_t_uint8_t(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_UINT16: {
      return ParseAttrInSingleScalar_uint16_t_uint16_t(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_UINT32: {
      return ParseAttrInSingleScalar_uint32_t_uint32_t(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_UINT64: {
      return ParseAttrInSingleScalar_uint64_t_uint64_t(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_FLOAT: {
      return ParseAttrInSingleScalar_float_float(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_DOUBLE: {
      return ParseAttrInSingleScalar_double_double(attr_proto);
    }
    case mind_ir::AttributeProto_AttributeType_BOOL: {
      // bool travels through the scalar int field (see macro instantiation).
      return ParseAttrInSingleScalar_int32_t_bool(attr_proto);
    }
    default:
      MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
      return {};
  }
  return {};
}
|
||||
|
||||
bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
|
||||
const mind_ir::AttributeProto &attr_proto) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0);
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
ShapeVector shape;
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
prim->AddAttr(attr_proto.name(), MakeValue(tensor_info));
|
||||
return true;
|
||||
}
|
||||
|
||||
// Decodes one CNode attribute and attaches it to `prim` under
// attr_proto.name(). The ref_attr_name prefix ("scalar:", "type:", "tensor:")
// selects the decode form; repeated scalars are aggregated afterwards into a
// ValueTuple or ValueList depending on whether ref_attr_name mentions "Tuple".
bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
  MS_EXCEPTION_IF_NULL(prim);
  const std::string &attr_name = attr_proto.name();
  if (!attr_proto.has_ref_attr_name()) {
    MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
    return false;
  }
  const std::string &ref_attr_name = attr_proto.ref_attr_name();
  string type = "";
  std::size_t pos(0);
  // Extract the prefix WITHOUT its trailing ':' (e.g. "scalar:" -> "scalar"),
  // the key format used by kParseTypeSwitchMap.
  if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
  } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("type:").length() - 1);
  } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
  }
  std::unordered_map<std::string, ValuePtr> multi_value_map;
  switch (kParseTypeSwitchMap[type]) {
    case FORM_PARSE_TYPE: {
      ObtainCNodeAttrInTypeForm(prim, attr_proto);
      break;
    }
    case FORM_PARSE_SCALAR: {
      // "value0" marks a single scalar; anything else is a sequence collected
      // into multi_value_map and aggregated below.
      std::size_t value_pos(0);
      if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
        auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
        prim->AddAttr(attr_name, res);
        break;
      }
      ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
      break;
    }
    case FORM_PARSE_TENSOR: {
      ObtainCNodeAttrInTensorForm(prim, attr_proto);
      break;
    }
    default:
      MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
      return false;
  }

  // Aggregate a multi-scalar payload into a Tuple or List value.
  if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
    if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) {
      auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
      prim->AddAttr(attr_name, value_tuple_ptr);
    } else {
      auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
      prim->AddAttr(attr_name, value_list_ptr);
    }
  }
  return true;
}
|
||||
|
||||
bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
||||
const mind_ir::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
ShapeVector shape;
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
|
||||
auto new_value_node = NewValueNode(MakeValue(tensor_info));
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(tensor_abstract);
|
||||
new_value_node->set_abstract(tensor_abstract);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
||||
const mind_ir::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
|
||||
new_value_node->set_abstract(abs_type);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_name,
|
||||
const mind_ir::AttributeProto &attr_proto) {
|
||||
auto new_value_node = NewValueNode(kNone);
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Decodes a Constant node's attribute into a ValueNode registered under
// value_node_name. Mirrors GetAttrValueForCNode but produces standalone
// ValueNodes (with abstracts) instead of Primitive attrs, and additionally
// supports the "none" form.
bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
                                                const mind_ir::AttributeProto &attr_proto) {
  if (!attr_proto.has_ref_attr_name()) {
    MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
    return false;
  }
  const std::string &ref_attr_name = attr_proto.ref_attr_name();
  string type = "";
  std::size_t pos(0);
  // Extract the prefix WITHOUT its trailing ':' — the kParseTypeSwitchMap key.
  if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
  } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("type:").length() - 1);
  } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
  } else if (ref_attr_name == "none") {
    type = ref_attr_name;
  }

  ValueNodePtr new_value_node;
  std::unordered_map<std::string, ValuePtr> multi_value_map;
  switch (kParseTypeSwitchMap[type]) {
    case FORM_PARSE_TYPE: {
      ObtainValueNodeInTypeForm(value_node_name, attr_proto.tensors(0));
      break;
    }
    case FORM_PARSE_SCALAR: {
      // "value0" marks a single scalar; sequences are collected into
      // multi_value_map and aggregated below.
      std::size_t value_pos(0);
      if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
        auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
        new_value_node = NewValueNode(res);
        new_value_node->set_abstract(res->ToAbstract());
        anfnode_build_map_[value_node_name] = new_value_node;
        break;
      }
      ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
      break;
    }
    case FORM_PARSE_TENSOR: {
      ObtainValueNodeInTensorForm(value_node_name, attr_proto.tensors(0));
      break;
    }
    case FORM_PARSE_NONE: {
      ObtainValueNodeInNoneForm(value_node_name, attr_proto);
      break;
    }
    default:
      MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
      return false;
  }

  // Aggregate a multi-scalar payload into a Tuple or List ValueNode.
  if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
    if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) {
      auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
      new_value_node = NewValueNode(value_tuple_ptr);
      new_value_node->set_abstract(value_tuple_ptr->ToAbstract());
    } else {
      auto value_list_ptr = ParserScalarAttrValue<ValueList>(ref_attr_name, multi_value_map);
      new_value_node = NewValueNode(value_list_ptr);
      new_value_node->set_abstract(value_list_ptr->ToAbstract());
    }
    anfnode_build_map_[value_node_name] = new_value_node;
  }
  return true;
}
|
||||
|
||||
// Builds a ValueNode from a "Constant" NodeProto: the node's first output
// name keys the map entry, the first attribute carries the payload.
bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto) {
  const mind_ir::AttributeProto &attr_proto = node_proto.attribute(0);
  if (!attr_proto.has_ref_attr_name()) {
    MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
    return false;
  }
  return GetAttrValueForValueNode(node_proto.output(0), attr_proto);
}
|
||||
|
||||
// Builds a tensor-name -> AbstractTensor map from the attribute's tensors;
// used by BuildCNodeForFuncGraph to assign output abstracts from a "shape:"
// attribute.
std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode(
  const mind_ir::AttributeProto &attr_proto) {
  std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
  for (int i = 0; i < attr_proto.tensors_size(); ++i) {
    ShapeVector shape_vec;
    const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i);
    for (int j = 0; j < attr_tensor.dims_size(); ++j) {
      shape_vec.push_back(attr_tensor.dims(j));
    }
    // An empty tensor is built only for its ->ToAbstract() (dtype + shape).
    tensor::TensorPtr tensor_info =
      std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
    MS_EXCEPTION_IF_NULL(tensor_info);
    auto abstract = tensor_info->ToAbstract();
    MS_EXCEPTION_IF_NULL(abstract);
    kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract));
  }
  return kv;
}
|
||||
|
||||
// Builds one CNode from a NodeProto: resolves its Primitive (registered
// PrimitiveC factory if available, otherwise a plain Primitive), decodes its
// attributes, wires its inputs from anfnode_build_map_, and assigns an output
// abstract from the "shape:" attribute. Returns nullptr on any failure.
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
                                                  const mind_ir::NodeProto &node_proto) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  if (!node_proto.has_op_type()) {
    MS_LOG(ERROR) << "Get CNode op_type failed!";
    return nullptr;
  }
  const std::string &node_name = node_proto.output(0);
  const std::string &fullname_with_scope = node_proto.domain();
  const std::string &node_type = node_proto.op_type();

  // Prefer a registered PrimitiveC constructor for this op type; fall back to
  // a plain named Primitive.
  std::shared_ptr<Primitive> prim;
  auto op_primc_fns = OpPrimCRegister::GetInstance().GetPrimCMap();
  if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
    prim = op_primc_fns[node_type]();
  } else {
    prim = std::make_shared<Primitive>(node_type);
    prim->set_instance_name(node_type);
  }
  MS_EXCEPTION_IF_NULL(prim);

  // The "shape:" attribute is set aside for abstract inference; every other
  // attribute is attached to the Primitive.
  std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
  string shape_ref_attr_name;
  for (int i = 0; i < node_proto.attribute_size(); ++i) {
    const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
    if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
      shape_ref_attr_name = attr_proto.ref_attr_name();
      kv = GetAbstractForCNode(attr_proto);
      continue;
    }

    if (!GetAttrValueForCNode(prim, attr_proto)) {
      MS_LOG(ERROR) << "Get CNode attr failed!";
      return nullptr;
    }
  }

  // Every input must already have been built (topological proto order).
  std::vector<AnfNodePtr> inputs;
  inputs.clear();
  for (int i = 0; i < node_proto.input_size(); ++i) {
    const std::string &input_name = node_proto.input(i);
    if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
      MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
      return nullptr;
    }

    inputs.push_back(anfnode_build_map_[input_name]);
  }

  auto cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
  MS_EXCEPTION_IF_NULL(cnode_ptr);

  // Abstract: no shape attr -> tuple of input abstracts; one entry -> that
  // abstract directly; several -> parse the bracketed shape expression.
  if (0 == kv.size()) {
    AbstractBasePtrList elem;
    for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
      elem.push_back(cnode_ptr->input(index)->abstract());
    }
    cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
  } else if (1 == kv.size()) {
    std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
    cnode_ptr->set_abstract(iter->second);
  } else {
    auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
    cnode_ptr->set_abstract(abstract);
  }

  string debug_info_name = ParseCNodeName(node_name);
  auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
  cnode_ptr->set_debug_info(debug_info_ptr);
  cnode_ptr->set_fullname_with_scope(fullname_with_scope);

  anfnode_build_map_[node_name] = cnode_ptr;
  return cnode_ptr;
}
|
||||
|
||||
// Appends the graph's Return node. With multiple declared outputs, wraps them
// in a MakeTuple (with a matching AbstractTuple) and returns that; with a
// single output, returns `cnode_ptr` (the last CNode built) directly.
bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
                                               const mind_ir::GraphProto &importProto, const CNodePtr &cnode_ptr) {
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  MS_EXCEPTION_IF_NULL(cnode_ptr);
  std::vector<AnfNodePtr> inputs;
  if (importProto.output_size() > 1) {
    // MakeTuple over every declared output, looked up by name.
    inputs.clear();
    inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
    AbstractBasePtrList elem;
    for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
      const mind_ir::ValueInfoProto &output_node = importProto.output(out_size);
      const std::string &out_tuple = output_node.name();
      inputs.push_back(anfnode_build_map_[out_tuple]);
      elem.push_back(anfnode_build_map_[out_tuple]->abstract());
    }
    auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
    maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
    inputs.clear();
    inputs.push_back(NewValueNode(prim::kPrimReturn));
    inputs.push_back(maketuple_ptr);
    auto return_node = outputFuncGraph->NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(return_node);
    outputFuncGraph->set_return(return_node);
    MS_LOG(INFO) << "Construct funcgraph finined, all success.";
  } else {
    // Single output: return the last built CNode.
    inputs.clear();
    inputs.push_back(NewValueNode(prim::kPrimReturn));
    inputs.push_back(cnode_ptr);
    auto return_node = outputFuncGraph->NewCNode(inputs);
    MS_EXCEPTION_IF_NULL(return_node);
    outputFuncGraph->set_return(return_node);
    MS_LOG(INFO) << "Construct funcgraph finined, all success!";
  }
  return true;
}
|
||||
|
||||
bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const mind_ir::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
|
||||
CNodePtr cnode_ptr = nullptr;
|
||||
for (int i = 0; i < importProto.node_size(); ++i) {
|
||||
const mind_ir::NodeProto &node_proto = importProto.node(i);
|
||||
const std::string &node_type = node_proto.op_type();
|
||||
if (node_type == kConstantValueNode) {
|
||||
if (!BuildValueNodeForFuncGraph(node_proto)) {
|
||||
MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
|
||||
if (cnode_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) {
  // Populate an empty FuncGraph from a GraphProto: set its debug name,
  // import the parameters, then import the nodes.
  MS_EXCEPTION_IF_NULL(outputFuncGraph);
  auto debug_info = outputFuncGraph->debug_info();
  MS_EXCEPTION_IF_NULL(debug_info);
  if (!importProto.has_name()) {
    // A missing name is only logged; conversion continues regardless.
    MS_LOG(ERROR) << "FuncGraph under converting has not name!";
  } else {
    debug_info->set_name(importProto.name());
  }

  if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
    MS_LOG(ERROR) << "import parameters for graph fail!";
    return false;
  }
  return ImportNodesForGraph(outputFuncGraph, importProto);
}
|
||||
|
||||
bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto) {
  // Extract the producer name, model version and IR version from the model
  // proto into the parser's members. All three fields are mandatory; the
  // first missing field logs an error and aborts parsing.
  if (!model_proto.has_producer_name()) {
    MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
    return false;
  }
  producer_name_ = model_proto.producer_name();
  MS_LOG(INFO) << "producer_name :" << producer_name_;

  if (!model_proto.has_model_version()) {
    MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
    return false;
  }
  model_version_ = model_proto.model_version();
  MS_LOG(INFO) << "producer_version : " << model_version_;

  if (!model_proto.has_ir_version()) {
    MS_LOG(ERROR) << "Parse model version from pb file failed!";
    return false;
  }
  ir_version_ = model_proto.ir_version();
  MS_LOG(INFO) << "ir_version :" << ir_version_;
  return true;
}
|
||||
|
||||
FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
  // Entry point: convert a MindIR ModelProto into a FuncGraph.
  auto dst_graph = std::make_shared<FuncGraph>();
  MS_EXCEPTION_IF_NULL(dst_graph);
  // Configuration-info failure is logged but deliberately non-fatal.
  if (!MSANFParseModelConfigureInfo(model_proto)) {
    MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
  }
  if (!BuildFuncGraph(dst_graph, model_proto.graph())) {
    MS_LOG(ERROR) << "Build funcgraph failed!";
    return nullptr;
  }
  MS_LOG(INFO) << "Parse pb to build FuncGraph Success!";
  return dst_graph;
}
|
||||
} // namespace mindspore
|
|
@ -14,63 +14,62 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H
|
||||
#define MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H
|
||||
#ifndef MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H
|
||||
#define MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "proto/mind_ir.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
using int32 = int32_t;
|
||||
using int64 = int64_t;
|
||||
using uint64 = uint64_t;
|
||||
class MSANFModelParser {
|
||||
public:
|
||||
MSANFModelParser() : producer_name_(""), model_version_(0), ir_version_(0) {}
|
||||
MSANFModelParser() : producer_name_(""), model_version_(""), ir_version_("") {}
|
||||
~MSANFModelParser() = default;
|
||||
|
||||
FuncGraphPtr Parse(const onnx::ModelProto &model_proto);
|
||||
bool MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto);
|
||||
FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto);
|
||||
bool MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto);
|
||||
|
||||
std::string GetProducerName() { return producer_name_; }
|
||||
int GetProducerVersion() { return model_version_; }
|
||||
int GetIrVersion() { return ir_version_; }
|
||||
std::string GetProducerVersion() { return model_version_; }
|
||||
std::string GetIrVersion() { return ir_version_; }
|
||||
|
||||
private:
|
||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
|
||||
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto);
|
||||
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &tensor_proto);
|
||||
bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
|
||||
tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto);
|
||||
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto);
|
||||
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto,
|
||||
const CNodePtr &cnode_ptr);
|
||||
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
|
||||
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
|
||||
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
|
||||
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_tensor);
|
||||
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
|
||||
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
|
||||
void ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
|
||||
std::unordered_map<std::string, ValuePtr> *multi_value_map);
|
||||
ValuePtr ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index);
|
||||
ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
|
||||
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
|
||||
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
|
||||
bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
|
||||
bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor);
|
||||
bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
|
||||
bool ObtainValueNodeInNoneForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> GetAbstractForCNode(
|
||||
const onnx::AttributeProto &attr_proto);
|
||||
const mind_ir::AttributeProto &attr_proto);
|
||||
|
||||
std::string producer_name_;
|
||||
int model_version_;
|
||||
int ir_version_;
|
||||
std::string model_version_;
|
||||
std::string ir_version_;
|
||||
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
|
||||
std::map<std::string, onnx::TensorProto> default_para_map_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_UTILS_LOAD_ONNX_ANF_MODEL_PARSER_H
|
||||
#endif // MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "load_mindir/load_model.h"
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
|
||||
#include "load_mindir/anf_model_parser.h"
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
namespace mindspore {
|
||||
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "file is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
char real_path[PATH_MAX] = {0};
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
if (_fullpath(real_path, file.c_str(), PATH_MAX) == nullptr) {
|
||||
MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
|
||||
return nullptr;
|
||||
}
|
||||
#else
|
||||
if (realpath(file.c_str(), real_path) == nullptr) {
|
||||
MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::ifstream ifs(real_path);
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "file: " << real_path << " is not exist";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "file: " << real_path << "open failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size));
|
||||
if (buf == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
|
||||
ifs.close();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(buf->data(), size);
|
||||
ifs.close();
|
||||
|
||||
return buf;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> RunLoadMindIR(const std::string &file_name) {
|
||||
auto graphBuf = ReadProtoFile(file_name);
|
||||
if (graphBuf == nullptr) {
|
||||
MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
try {
|
||||
auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size());
|
||||
return graph;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size) {
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
std::string str((const char *)buf, buf_size);
|
||||
mind_ir::ModelProto model_;
|
||||
if (!model_.ParseFromString(str)) {
|
||||
MS_LOG(ERROR) << "Parse model from buffer fail!";
|
||||
}
|
||||
MSANFModelParser model_parser;
|
||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
|
||||
return dstgraph_ptr;
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -13,27 +13,19 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CORE_LOAD_MODEL_H
#define MINDSPORE_CORE_LOAD_MODEL_H

// Public MindIR loading API. Reconstructed from diff residue: the removed
// lite::AnfConverter class and the old ANF_CONVERTER include guard / onnx
// include are dropped; only the added declarations remain.
#include <vector>
#include <string>
#include <memory>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "proto/mind_ir.pb.h"
#include "ir/func_graph.h"

namespace mindspore {
// Load a MindIR file from disk into a FuncGraph; nullptr on failure.
std::shared_ptr<FuncGraph> RunLoadMindIR(const std::string &file_name);
// Read a file into a byte buffer; nullptr on failure.
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
// Deserialize an in-memory MindIR protobuf stream into a FuncGraph.
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size);
}  // namespace mindspore
#endif  // MINDSPORE_CORE_LOAD_MODEL_H
|
|
@ -116,7 +116,9 @@ endif ()
|
|||
# Proto sources for the converter. Reconstructed from diff residue: the new
# version appends mind_ir.proto to the glob.
# NOTE(review): file(GLOB) misses newly added .proto files until the next
# re-configure; an explicit list (or CONFIGURE_DEPENDS) would be more robust.
file(GLOB PROTO_FILE ""
        ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto
        ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/proto/*.proto
        ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto
        ${CCSRC_DIR}/utils/mind_ir.proto)

ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
add_library(proto_mid OBJECT ${PROTO_SRCS})
|
||||
set(TFLITE_FBS_FILES
|
||||
|
|
|
@ -23,9 +23,12 @@ from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore import log as logger
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
|
||||
from mindspore.train.anf_ir_pb2 import ModelProto as anf_model
|
||||
|
||||
from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
|
||||
|
||||
|
||||
def _convert_type(types):
|
||||
"""
|
||||
Convert from numpy type to tensor type.
|
||||
|
@ -203,3 +206,32 @@ def check_value_type(arg_name, arg_value, valid_types):
|
|||
if not is_valid:
|
||||
raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
|
||||
f'bug got {type(arg_value).__name__}.')
|
||||
|
||||
|
||||
def read_proto(file_name, proto_format="MINDIR"):
    """
    Read a serialized protobuf file into a proto object.

    Args:
        file_name (str): Path of the protobuf file to read.
        proto_format (str): Proto format, "MINDIR" or "ANF". Default: "MINDIR".

    Returns:
        Object, the parsed proto object (`mindir_model` or `anf_model`).

    Raises:
        ValueError: If `proto_format` is unsupported or the file cannot be
            read/parsed.
    """
    if proto_format == "MINDIR":
        model = mindir_model()
    elif proto_format == "ANF":  # fix: was `model_format`, an undefined name (NameError)
        model = anf_model()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as f:
            pb_content = f.read()
        model.ParseFromString(pb_content)
    except BaseException as e:
        logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name)
        raise ValueError(e.__str__())
    return model
||||
|
|
Loading…
Reference in New Issue