From 222d2d25a346e68ae2527e0147364fc3ba592053 Mon Sep 17 00:00:00 2001
From: moran
Date: Wed, 16 Mar 2022 16:36:23 +0800
Subject: [PATCH] lite supports converting ONNX models (>2GB) to ms files.

---
 .../parser/onnx/onnx_model_parser.cc          |  29 +++--
 .../converter/parser/onnx/onnx_model_parser.h |   1 +
 .../converter/parser/onnx/onnx_node_parser.cc | 101 ++++++++++++++++++
 .../converter/parser/onnx/onnx_node_parser.h  |  22 ++++
 4 files changed, 143 insertions(+), 10 deletions(-)

diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
index 894c039f396..537c33aaeb9 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
@@ -184,7 +185,8 @@ STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs
-STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::TensorProto &tensor) {
+STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::TensorProto &tensor,
+                          const std::string &model_file) {
   auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(tensor.data_type()));
   if (data_type == kTypeUnknown) {
@@ -194,7 +195,7 @@ STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::Tensor
   std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
   auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
   if (abstract_tensor == nullptr) {
-    MS_LOG(ERROR) << "Create tensor abstarct failed";
+    MS_LOG(ERROR) << "Create tensor abstract failed";
     return RET_ERROR;
   }
   parameter_node->set_abstract(abstract_tensor);
@@ -205,10 +206,18 @@ STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::Tensor
   std::vector<int> shape;
   std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
                  [](const int64_t &value) { return static_cast<int>(value); });
-  auto status = OnnxNodeParser::CopyOnnxTensorData(tensor, tensor_info);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "copy data failed.";
-    return status;
+  if (tensor.data_location() != onnx::TensorProto::EXTERNAL) {
+    auto status = OnnxNodeParser::CopyOnnxTensorData(tensor, tensor_info);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "copy data failed.";
+      return status;
+    }
+  } else {
+    auto status = OnnxNodeParser::LoadOnnxExternalTensorData(tensor, tensor_info, model_file);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "load external data failed.";
+      return status;
+    }
   }
   parameter_node->set_default_param(tensor_info);
   return RET_OK;
@@ -269,12 +278,12 @@ STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_
 }
 
 STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
-                           std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
+                           std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const std::string &model_file) {
   MS_ASSERT(func_graph_ptr != nullptr && anf_nodes_map != nullptr);
   for (const auto &onnx_const_value : onnx_graph.initializer()) {
     auto parameter = func_graph_ptr->add_parameter();
     MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
-    auto status = BuildParameterNode(parameter, onnx_const_value);
+    auto status = BuildParameterNode(parameter, onnx_const_value, model_file);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "parameter node build failed.";
       return status;
     }
 }
@@ -609,7 +618,7 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
     MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx";
     return status;
   }
-
+  model_file_ = model_file;
   status = ReadProtoFromBinaryFile(model_file, &onnx_model_);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file;
@@ -637,7 +646,7 @@ STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, con
     MS_LOG(ERROR) << "input onnx model error: " << status;
     return status;
   }
-  status = ConvertConstTensors(onnx_graph, anf_graph, anf_nodes_map);
+  status = ConvertConstTensors(onnx_graph, anf_graph, anf_nodes_map, model_file_);
   if (RET_OK != status) {
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
     MS_LOG(ERROR) << "convert const nodes failed.";
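The BuildParameterNode change above branches on onnx::TensorProto::data_location(): initializers saved with ONNX external data keep their bytes in a side file instead of raw_data(). A minimal stand-alone sketch of that distinction, assuming the generated ONNX protobuf headers are available under "onnx/onnx.pb.h"; DescribeInitializer is an illustrative helper, not part of the patch:

#include <iostream>

#include "onnx/onnx.pb.h"  // assumed include path for the generated ONNX protos

// Illustrative helper: report where an initializer's bytes live. Tensors saved
// with ONNX external data carry data_location == EXTERNAL and reference a side
// file through external_data entries instead of filling raw_data().
void DescribeInitializer(const onnx::TensorProto &tensor) {
  if (tensor.data_location() == onnx::TensorProto::EXTERNAL) {
    std::cout << tensor.name() << ": stored in an external file" << std::endl;
  } else {
    std::cout << tensor.name() << ": embedded, raw_data holds " << tensor.raw_data().size() << " bytes" << std::endl;
  }
}
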
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
index f92cdd2adcf..4eb613bc7d9 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
@@ -82,6 +82,7 @@ class OnnxModelParser : public converter::ModelParser {
   std::unordered_map anf_nodes_map_{};
   std::unordered_map *> control_nodes_map_{};
   std::unordered_map child_root_map_{};  // for nest control flow node
+  std::string model_file_{};
 };
 }  // namespace lite
 }  // namespace mindspore
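The onnx_node_parser.cc changes below consume the per-tensor external_data key/value entries that the ONNX format defines (keys "location", "offset", "length", "checksum"). As a rough sketch of what such an initializer looks like on the producer side, assuming the generated ONNX protobuf headers; the file name, offset and length are placeholders:

#include "onnx/onnx.pb.h"  // assumed include path for the generated ONNX protos

// Illustrative producer-side view of an external initializer. Only the keys
// are fixed by the ONNX format; every value below is a placeholder.
onnx::TensorProto MakeExternalInitializer() {
  onnx::TensorProto tensor;
  tensor.set_name("conv1.weight");
  tensor.set_data_type(onnx::TensorProto_DataType_FLOAT);
  tensor.set_data_location(onnx::TensorProto::EXTERNAL);
  auto *location = tensor.add_external_data();
  location->set_key("location");
  location->set_value("weights.bin");  // path relative to the .onnx file
  auto *offset = tensor.add_external_data();
  offset->set_key("offset");
  offset->set_value("0");  // byte offset inside weights.bin, as a decimal string
  auto *length = tensor.add_external_data();
  length->set_key("length");
  length->set_value("4096");  // number of bytes to read, as a decimal string
  return tensor;
}
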
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
index 446a488c955..864a17c2ef0 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
@@ -21,10 +21,12 @@
 #include
 #include "tools/converter/parser/onnx/onnx_model_parser.h"
 #include "nnacl/op_base.h"
+#include "src/common/file_utils.h"
 
 namespace mindspore {
 namespace lite {
 namespace {
+constexpr int kMaxValidCharacters = 10;
 static std::unordered_map kOnnxTypeTransferMap = {
   {onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8},
   {onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8},
@@ -39,6 +41,45 @@ static std::unordered_map kOnnxTypeTransferMap = {
 }  // namespace
 
 int64_t OnnxNodeParser::opset_version_ = 0;
+void *OnnxNodeParser::buffer_ = nullptr;
+
+STATUS ExternalDataInfo::Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> &externalData,
+                                ExternalDataInfo *externalDataInfo) {
+  const int data_size = externalData.size();
+  for (int i = 0; i != data_size; ++i) {
+    onnx::StringStringEntryProto stringMap = externalData[i];
+    if (!stringMap.has_key()) {
+      MS_LOG(ERROR) << "No key is in external data.";
+      return RET_ERROR;
+    }
+    if (!stringMap.has_value()) {
+      MS_LOG(ERROR) << "No value is in external data.";
+      return RET_ERROR;
+    }
+
+    if (stringMap.key() == "location" && !stringMap.value().empty()) {
+      externalDataInfo->rel_path_ = stringMap.value();
+    } else if (stringMap.key() == "offset" && !stringMap.value().empty()) {
+      externalDataInfo->offset_ = strtol(stringMap.value().c_str(), nullptr, kMaxValidCharacters);
+      if (std::to_string(externalDataInfo->offset_).length() != stringMap.value().length()) {
+        MS_LOG(ERROR) << "Failed to parse offset.";
+        return RET_ERROR;
+      }
+    } else if (stringMap.key() == "length" && !stringMap.value().empty()) {
+      externalDataInfo->length_ = static_cast<size_t>(strtol(stringMap.value().c_str(), nullptr, kMaxValidCharacters));
+      if (std::to_string(externalDataInfo->length_).length() != stringMap.value().length()) {
+        MS_LOG(ERROR) << "Failed to parse length.";
+        return RET_ERROR;
+      }
+    } else if (stringMap.key() == "checksum" && !stringMap.value().empty()) {
+      externalDataInfo->checksum_ = stringMap.value();
+    } else {
+      MS_LOG(ERROR) << "Invalid model format";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
 
 mindspore::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) {
   if (onnx_node_attr.s() == "NOTSET") {
@@ -173,6 +214,42 @@ size_t OnnxNodeParser::GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, b
   return data_count;
 }
 
+STATUS OnnxNodeParser::LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_const_tensor,
+                                                  const tensor::TensorPtr &tensor_info, const std::string &model_file) {
+  if (tensor_info == nullptr) {
+    MS_LOG(ERROR) << "tensor_info is nullptr.";
+    return RET_NULL_PTR;
+  }
+  size_t data_size = 0;
+  const void *onnx_data = LoadOnnxRawData(onnx_const_tensor, &data_size, model_file);
+  if (data_size == 0) {
+    return RET_OK;
+  }
+  if (onnx_data == nullptr) {
+    MS_LOG(ERROR) << "origin data from external data is nullptr.";
+    return RET_MEMORY_FAILED;
+  }
+  auto tensor_data = reinterpret_cast(tensor_info->data_c());
+  if (memcpy_s(tensor_data, tensor_info->data().nbytes(), onnx_data, data_size) != EOK) {
+    MS_LOG(ERROR) << "memcpy_s failed.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+STATUS OnnxNodeParser::SetExternalTensorFile(const std::string &model_file, std::string *external_tensor_dir) {
+  auto iEndIndex = model_file.find_last_of('/');
+  if (iEndIndex == std::string::npos) {
+    iEndIndex = model_file.find_last_of('\\');
+  }
+  if (iEndIndex == std::string::npos) {
+    *external_tensor_dir = ".";
+  } else {
+    *external_tensor_dir = model_file.substr(0, iEndIndex);
+  }
+  return RET_OK;
+}
+
 const void *OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, TypeId data_type,
                                            size_t data_count, size_t *data_size) {
   MS_ASSERT(data_size != nullptr);
@@ -242,5 +319,29 @@ const void *OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_t
   }
   return onnx_data;
 }
+
+const void *OnnxNodeParser::LoadOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t *data_size,
+                                            const std::string &model_file) {
+  MS_ASSERT(data_size != nullptr);
+  const void *onnx_data = nullptr;
+  ExternalDataInfo externalDataInfo;
+  if (ExternalDataInfo::Create(onnx_const_tensor.external_data(), &externalDataInfo) != RET_OK) {
+    MS_LOG(ERROR) << "Create ExternalDataInfo failed.";
+    return nullptr;
+  }
+  std::string externalTensorDir;
+  if (SetExternalTensorFile(model_file, &externalTensorDir) != RET_OK) {
+    MS_LOG(ERROR) << "Failed to set external tensor file.";
+    return nullptr;
+  }
+#ifdef _WIN32
+  std::string externalDataFile = externalTensorDir + "\\" + externalDataInfo.GetRelPath();
+#else
+  std::string externalDataFile = externalTensorDir + "/" + externalDataInfo.GetRelPath();
+#endif
+  buffer_ = ReadFile(externalDataFile.c_str(), data_size);
+  onnx_data = std::move(buffer_);
+  return onnx_data;
+}
 }  // namespace lite
 }  // namespace mindspore
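LoadOnnxRawData above rebuilds the external file path next to the .onnx model and reads the whole side file with ReadFile. For comparison, a stand-alone sketch that resolves the same path and reads only the offset/length window described by ExternalDataInfo, using plain C++ streams; ReadExternalBytes is an illustrative name, not the converter's API:

#include <fstream>
#include <string>
#include <vector>

// Illustrative reader: resolve rel_path against the directory holding the
// .onnx file and read `length` bytes starting at `offset`. Returns an empty
// vector on any failure.
std::vector<char> ReadExternalBytes(const std::string &model_file, const std::string &rel_path,
                                    std::streamoff offset, size_t length) {
  auto slash = model_file.find_last_of("/\\");
  std::string dir = (slash == std::string::npos) ? "." : model_file.substr(0, slash);
  std::ifstream in(dir + "/" + rel_path, std::ios::binary);
  if (!in) {
    return {};
  }
  in.seekg(offset, std::ios::beg);
  std::vector<char> bytes(length);
  in.read(bytes.data(), static_cast<std::streamsize>(length));
  return in.gcount() == static_cast<std::streamsize>(length) ? bytes : std::vector<char>{};
}
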
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
index 43e8ba13aca..06963b2d36c 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
@@ -31,6 +31,19 @@
 namespace mindspore {
 namespace lite {
+class ExternalDataInfo {
+ public:
+  const std::string GetRelPath() const { return rel_path_; }
+  static STATUS Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> &externalData,
+                       ExternalDataInfo *externalDataInfo);
+
+ private:
+  std::string rel_path_;
+  off_t offset_ = 0;
+  size_t length_ = 0;
+  std::string checksum_;
+};
+
 class OnnxNodeParser {
  public:
   explicit OnnxNodeParser(std::string node_name) : name_(std::move(node_name)) {}
@@ -54,6 +67,9 @@ class OnnxNodeParser {
   static size_t GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, bool *overflowed);
 
+  static STATUS LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_const_tensor,
+                                           const tensor::TensorPtr &tensor_info, const std::string &model_file);
+
 protected:
   static mindspore::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);
@@ -62,10 +78,16 @@ class OnnxNodeParser {
   static const void *GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, TypeId data_type, size_t data_count,
                                     size_t *data_size);
 
+  static STATUS SetExternalTensorFile(const std::string &model_file, std::string *external_tensor_dir);
+
+  static const void *LoadOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t *data_size,
+                                     const std::string &model_file);
+
   const std::string name_{};
 
  private:
   static int64_t opset_version_;
+  static void *buffer_;
 };
 }  // namespace lite
 }  // namespace mindspore
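With the declarations above in place, the converter can pull initializer data from side files, which is what lets .onnx models beyond protobuf's 2 GB message limit be converted. A hedged sketch of how a caller might list such initializers before conversion, assuming the generated ONNX protobuf headers; DumpExternalInitializers is illustrative and not part of MindSpore Lite:

#include <iostream>

#include "onnx/onnx.pb.h"  // assumed include path for the generated ONNX protos

// Illustrative debugging aid: list every initializer whose payload lives
// outside the .onnx file, together with the raw key/value pairs that the
// ExternalDataInfo parser consumes.
void DumpExternalInitializers(const onnx::GraphProto &graph) {
  for (const auto &tensor : graph.initializer()) {
    if (tensor.data_location() != onnx::TensorProto::EXTERNAL) {
      continue;
    }
    std::cout << "external initializer: " << tensor.name() << std::endl;
    for (const auto &entry : tensor.external_data()) {
      std::cout << "  " << entry.key() << " = " << entry.value() << std::endl;
    }
  }
}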