!31385 [Lite] Converter lite supports loading ONNX model with external data.

Merge pull request !31385 from moran/master
This commit is contained in:
i-robot 2022-04-12 02:42:20 +00:00 committed by Gitee
commit 3b6d8c6091
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 143 additions and 10 deletions

View File

@ -186,7 +186,8 @@ STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodeP
return RET_OK;
}
STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::TensorProto &tensor) {
STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::TensorProto &tensor,
const std::string &model_file) {
MS_ASSERT(parameter_node != nullptr);
auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(tensor.data_type()));
if (data_type == kTypeUnknown) {
@ -196,7 +197,7 @@ STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::Tensor
std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
MS_LOG(ERROR) << "Create tensor abstract failed";
return RET_ERROR;
}
parameter_node->set_abstract(abstract_tensor);
@ -207,10 +208,18 @@ STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::Tensor
std::vector<int> shape;
std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
[](const int64_t &value) { return static_cast<int>(value); });
auto status = OnnxNodeParser::CopyOnnxTensorData(tensor, tensor_info);
if (status != RET_OK) {
MS_LOG(ERROR) << "copy data failed.";
return status;
if (tensor.data_location() != onnx::TensorProto::EXTERNAL) {
auto status = OnnxNodeParser::CopyOnnxTensorData(tensor, tensor_info);
if (status != RET_OK) {
MS_LOG(ERROR) << "copy data failed.";
return status;
}
} else {
auto status = OnnxNodeParser::LoadOnnxExternalTensorData(tensor, tensor_info, model_file);
if (status != RET_OK) {
MS_LOG(ERROR) << "load external data failed.";
return status;
}
}
parameter_node->set_default_param(tensor_info);
return RET_OK;
@ -273,12 +282,12 @@ STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_
}
STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const std::string &model_file) {
MS_ASSERT(func_graph_ptr != nullptr && anf_nodes_map != nullptr);
for (const auto &onnx_const_value : onnx_graph.initializer()) {
auto parameter = func_graph_ptr->add_parameter();
MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
auto status = BuildParameterNode(parameter, onnx_const_value);
auto status = BuildParameterNode(parameter, onnx_const_value, model_file);
if (status != RET_OK) {
MS_LOG(ERROR) << "parameter node build failed.";
return status;
@ -620,7 +629,7 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx";
return status;
}
model_file_ = model_file;
status = ReadProtoFromBinaryFile(model_file, &onnx_model_);
if (status != RET_OK) {
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file;
@ -648,7 +657,7 @@ STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, con
MS_LOG(ERROR) << "input onnx model error: " << status;
return status;
}
status = ConvertConstTensors(onnx_graph, anf_graph, anf_nodes_map);
status = ConvertConstTensors(onnx_graph, anf_graph, anf_nodes_map, model_file_);
if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "convert const nodes failed.";

View File

@ -86,6 +86,7 @@ class OnnxModelParser : public converter::ModelParser {
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_{};
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_{};
std::unordered_map<std::string, std::string> child_root_map_{}; // for nest control flow node
std::string model_file_{};
};
} // namespace lite
} // namespace mindspore

View File

@ -21,10 +21,12 @@
#include <unordered_map>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include "nnacl/op_base.h"
#include "src/common/file_utils.h"
namespace mindspore {
namespace lite {
namespace {
constexpr int kMaxValidCharacters = 10;
static std::unordered_map<int, mindspore::TypeId> kOnnxTypeTransferMap = {
{onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8},
{onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8},
@ -39,6 +41,45 @@ static std::unordered_map<int, mindspore::TypeId> kOnnxTypeTransferMap = {
} // namespace
int64_t OnnxNodeParser::opset_version_ = 0;
void *OnnxNodeParser::buffer_ = nullptr;
STATUS ExternalDataInfo::Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> &externalData,
ExternalDataInfo *externalDataInfo) {
const int data_size = externalData.size();
for (int i = 0; i != data_size; ++i) {
onnx::StringStringEntryProto stringMap = externalData[i];
if (!stringMap.has_key()) {
MS_LOG(ERROR) << "No key is in external data.";
return RET_ERROR;
}
if (!stringMap.has_value()) {
MS_LOG(ERROR) << "No value is in external data.";
return RET_ERROR;
}
if (stringMap.key() == "location" && !stringMap.value().empty()) {
externalDataInfo->rel_path_ = stringMap.value();
} else if (stringMap.key() == "offset" && !stringMap.value().empty()) {
externalDataInfo->offset_ = strtol(stringMap.value().c_str(), nullptr, kMaxValidCharacters);
if (std::to_string(externalDataInfo->offset_).length() != stringMap.value().length()) {
MS_LOG(ERROR) << "Failed to parse offset.";
return RET_ERROR;
}
} else if (stringMap.key() == "length" && !stringMap.value().empty()) {
externalDataInfo->length_ = static_cast<size_t>(strtol(stringMap.value().c_str(), nullptr, kMaxValidCharacters));
if (std::to_string(externalDataInfo->length_).length() != stringMap.value().length()) {
MS_LOG(ERROR) << "Failed to parse length.";
return RET_ERROR;
}
} else if (stringMap.key() == "checksum" && !stringMap.value().empty()) {
externalDataInfo->checksum_ = stringMap.value();
} else {
MS_LOG(ERROR) << "Invalid model format";
return RET_ERROR;
}
}
return RET_OK;
}
mindspore::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) {
if (onnx_node_attr.s() == "NOTSET") {
@ -173,6 +214,42 @@ size_t OnnxNodeParser::GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, b
return data_count;
}
STATUS OnnxNodeParser::LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_const_tensor,
const tensor::TensorPtr &tensor_info, const std::string &model_file) {
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "tensor_info is nullptr.";
return RET_NULL_PTR;
}
size_t data_size = 0;
const void *onnx_data = LoadOnnxRawData(onnx_const_tensor, &data_size, model_file);
if (data_size == 0) {
return RET_OK;
}
if (onnx_data == nullptr) {
MS_LOG(ERROR) << "origin data from external data is nullptr.";
return RET_MEMORY_FAILED;
}
auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
if (memcpy_s(tensor_data, tensor_info->data().nbytes(), onnx_data, data_size) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_ERROR;
}
return RET_OK;
}
STATUS OnnxNodeParser::SetExternalTensorFile(const std::string &model_file, std::string *external_tensor_dir) {
auto iEndIndex = model_file.find_last_of('/');
if (iEndIndex == std::string::npos) {
iEndIndex = model_file.find_last_of('\\');
}
if (iEndIndex == std::string::npos) {
*external_tensor_dir = ".";
} else {
*external_tensor_dir = model_file.substr(0, iEndIndex);
}
return RET_OK;
}
const void *OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, TypeId data_type,
size_t data_count, size_t *data_size) {
MS_ASSERT(data_size != nullptr);
@ -242,5 +319,29 @@ const void *OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_t
}
return onnx_data;
}
const void *OnnxNodeParser::LoadOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t *data_size,
const std::string &model_file) {
MS_ASSERT(data_size != nullptr);
const void *onnx_data = nullptr;
ExternalDataInfo externalDataInfo;
if (ExternalDataInfo::Create(onnx_const_tensor.external_data(), &externalDataInfo) != RET_OK) {
MS_LOG(ERROR) << "Create ExternalDataInfo failed.";
return nullptr;
}
std::string externalTensorDir;
if (SetExternalTensorFile(model_file, &externalTensorDir) != RET_OK) {
MS_LOG(ERROR) << "Failed to set external tensor file.";
return nullptr;
}
#ifdef _WIN32
std::string externalDataFile = externalTensorDir + "\\" + externalDataInfo.GetRelPath();
#else
std::string externalDataFile = externalTensorDir + "/" + externalDataInfo.GetRelPath();
#endif
buffer_ = ReadFile(externalDataFile.c_str(), data_size);
onnx_data = std::move(buffer_);
return onnx_data;
}
} // namespace lite
} // namespace mindspore

View File

@ -37,6 +37,19 @@
namespace mindspore {
namespace lite {
class ExternalDataInfo {
public:
const std::string GetRelPath() const { return rel_path_; }
static STATUS Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> &externalData,
ExternalDataInfo *externalDataInfo);
private:
std::string rel_path_;
off_t offset_ = 0;
size_t length_ = 0;
std::string checksum_;
};
class OnnxNodeParser {
public:
explicit OnnxNodeParser(std::string node_name) : name_(std::move(node_name)) {}
@ -58,6 +71,9 @@ class OnnxNodeParser {
static size_t GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, bool *overflowed);
static STATUS LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_const_tensor,
const tensor::TensorPtr &tensor_info, const std::string &model_file);
protected:
static mindspore::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);
@ -66,10 +82,16 @@ class OnnxNodeParser {
static const void *GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, TypeId data_type, size_t data_count,
size_t *data_size);
static STATUS SetExternalTensorFile(const std::string &model_file, std::string *external_tensor_dir);
static const void *LoadOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t *data_size,
const std::string &model_file);
const std::string name_{};
private:
static int64_t opset_version_;
static void *buffer_;
};
} // namespace lite
} // namespace mindspore