!31385 [Lite] Converter lite supports loading ONNX model with external data.
Merge pull request !31385 from moran/master
This commit is contained in:
commit
3b6d8c6091
|
@ -186,7 +186,8 @@ STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodeP
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor) {
|
||||
STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor,
|
||||
const std::string &model_file) {
|
||||
MS_ASSERT(parameter_node != nullptr);
|
||||
auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(tensor.data_type()));
|
||||
if (data_type == kTypeUnknown) {
|
||||
|
@ -196,7 +197,7 @@ STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::Tensor
|
|||
std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
|
||||
auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
|
||||
if (abstract_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||
MS_LOG(ERROR) << "Create tensor abstract failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
parameter_node->set_abstract(abstract_tensor);
|
||||
|
@ -207,10 +208,18 @@ STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::Tensor
|
|||
std::vector<int> shape;
|
||||
std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
|
||||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
auto status = OnnxNodeParser::CopyOnnxTensorData(tensor, tensor_info);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "copy data failed.";
|
||||
return status;
|
||||
if (tensor.data_location() != onnx::TensorProto::EXTERNAL) {
|
||||
auto status = OnnxNodeParser::CopyOnnxTensorData(tensor, tensor_info);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "copy data failed.";
|
||||
return status;
|
||||
}
|
||||
} else {
|
||||
auto status = OnnxNodeParser::LoadOnnxExternalTensorData(tensor, tensor_info, model_file);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "load external data failed.";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
parameter_node->set_default_param(tensor_info);
|
||||
return RET_OK;
|
||||
|
@ -273,12 +282,12 @@ STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_
|
|||
}
|
||||
|
||||
STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const std::string &model_file) {
|
||||
MS_ASSERT(func_graph_ptr != nullptr && anf_nodes_map != nullptr);
|
||||
for (const auto &onnx_const_value : onnx_graph.initializer()) {
|
||||
auto parameter = func_graph_ptr->add_parameter();
|
||||
MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
|
||||
auto status = BuildParameterNode(parameter, onnx_const_value);
|
||||
auto status = BuildParameterNode(parameter, onnx_const_value, model_file);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "parameter node build failed.";
|
||||
return status;
|
||||
|
@ -620,7 +629,7 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
|
|||
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx";
|
||||
return status;
|
||||
}
|
||||
|
||||
model_file_ = model_file;
|
||||
status = ReadProtoFromBinaryFile(model_file, &onnx_model_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file;
|
||||
|
@ -648,7 +657,7 @@ STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, con
|
|||
MS_LOG(ERROR) << "input onnx model error: " << status;
|
||||
return status;
|
||||
}
|
||||
status = ConvertConstTensors(onnx_graph, anf_graph, anf_nodes_map);
|
||||
status = ConvertConstTensors(onnx_graph, anf_graph, anf_nodes_map, model_file_);
|
||||
if (RET_OK != status) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
MS_LOG(ERROR) << "convert const nodes failed.";
|
||||
|
|
|
@ -86,6 +86,7 @@ class OnnxModelParser : public converter::ModelParser {
|
|||
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_{};
|
||||
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_{};
|
||||
std::unordered_map<std::string, std::string> child_root_map_{}; // for nest control flow node
|
||||
std::string model_file_{};
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,10 +21,12 @@
|
|||
#include <unordered_map>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "src/common/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr int kMaxValidCharacters = 10;
|
||||
static std::unordered_map<int, mindspore::TypeId> kOnnxTypeTransferMap = {
|
||||
{onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8},
|
||||
{onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8},
|
||||
|
@ -39,6 +41,45 @@ static std::unordered_map<int, mindspore::TypeId> kOnnxTypeTransferMap = {
|
|||
} // namespace
|
||||
|
||||
int64_t OnnxNodeParser::opset_version_ = 0;
|
||||
void *OnnxNodeParser::buffer_ = nullptr;
|
||||
|
||||
STATUS ExternalDataInfo::Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> &externalData,
|
||||
ExternalDataInfo *externalDataInfo) {
|
||||
const int data_size = externalData.size();
|
||||
for (int i = 0; i != data_size; ++i) {
|
||||
onnx::StringStringEntryProto stringMap = externalData[i];
|
||||
if (!stringMap.has_key()) {
|
||||
MS_LOG(ERROR) << "No key is in external data.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!stringMap.has_value()) {
|
||||
MS_LOG(ERROR) << "No value is in external data.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (stringMap.key() == "location" && !stringMap.value().empty()) {
|
||||
externalDataInfo->rel_path_ = stringMap.value();
|
||||
} else if (stringMap.key() == "offset" && !stringMap.value().empty()) {
|
||||
externalDataInfo->offset_ = strtol(stringMap.value().c_str(), nullptr, kMaxValidCharacters);
|
||||
if (std::to_string(externalDataInfo->offset_).length() != stringMap.value().length()) {
|
||||
MS_LOG(ERROR) << "Failed to parse offset.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (stringMap.key() == "length" && !stringMap.value().empty()) {
|
||||
externalDataInfo->length_ = static_cast<size_t>(strtol(stringMap.value().c_str(), nullptr, kMaxValidCharacters));
|
||||
if (std::to_string(externalDataInfo->length_).length() != stringMap.value().length()) {
|
||||
MS_LOG(ERROR) << "Failed to parse length.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (stringMap.key() == "checksum" && !stringMap.value().empty()) {
|
||||
externalDataInfo->checksum_ = stringMap.value();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid model format";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
mindspore::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) {
|
||||
if (onnx_node_attr.s() == "NOTSET") {
|
||||
|
@ -173,6 +214,42 @@ size_t OnnxNodeParser::GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, b
|
|||
return data_count;
|
||||
}
|
||||
|
||||
STATUS OnnxNodeParser::LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_const_tensor,
|
||||
const tensor::TensorPtr &tensor_info, const std::string &model_file) {
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_info is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
size_t data_size = 0;
|
||||
const void *onnx_data = LoadOnnxRawData(onnx_const_tensor, &data_size, model_file);
|
||||
if (data_size == 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (onnx_data == nullptr) {
|
||||
MS_LOG(ERROR) << "origin data from external data is nullptr.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
if (memcpy_s(tensor_data, tensor_info->data().nbytes(), onnx_data, data_size) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxNodeParser::SetExternalTensorFile(const std::string &model_file, std::string *external_tensor_dir) {
|
||||
auto iEndIndex = model_file.find_last_of('/');
|
||||
if (iEndIndex == std::string::npos) {
|
||||
iEndIndex = model_file.find_last_of('\\');
|
||||
}
|
||||
if (iEndIndex == std::string::npos) {
|
||||
*external_tensor_dir = ".";
|
||||
} else {
|
||||
*external_tensor_dir = model_file.substr(0, iEndIndex);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
const void *OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, TypeId data_type,
|
||||
size_t data_count, size_t *data_size) {
|
||||
MS_ASSERT(data_size != nullptr);
|
||||
|
@ -242,5 +319,29 @@ const void *OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_t
|
|||
}
|
||||
return onnx_data;
|
||||
}
|
||||
|
||||
const void *OnnxNodeParser::LoadOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t *data_size,
|
||||
const std::string &model_file) {
|
||||
MS_ASSERT(data_size != nullptr);
|
||||
const void *onnx_data = nullptr;
|
||||
ExternalDataInfo externalDataInfo;
|
||||
if (ExternalDataInfo::Create(onnx_const_tensor.external_data(), &externalDataInfo) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Create ExternalDataInfo failed.";
|
||||
return nullptr;
|
||||
}
|
||||
std::string externalTensorDir;
|
||||
if (SetExternalTensorFile(model_file, &externalTensorDir) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Failed to set external tensor file.";
|
||||
return nullptr;
|
||||
}
|
||||
#ifdef _WIN32
|
||||
std::string externalDataFile = externalTensorDir + "\\" + externalDataInfo.GetRelPath();
|
||||
#else
|
||||
std::string externalDataFile = externalTensorDir + "/" + externalDataInfo.GetRelPath();
|
||||
#endif
|
||||
buffer_ = ReadFile(externalDataFile.c_str(), data_size);
|
||||
onnx_data = std::move(buffer_);
|
||||
return onnx_data;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,6 +37,19 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ExternalDataInfo {
|
||||
public:
|
||||
const std::string GetRelPath() const { return rel_path_; }
|
||||
static STATUS Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> &externalData,
|
||||
ExternalDataInfo *externalDataInfo);
|
||||
|
||||
private:
|
||||
std::string rel_path_;
|
||||
off_t offset_ = 0;
|
||||
size_t length_ = 0;
|
||||
std::string checksum_;
|
||||
};
|
||||
|
||||
class OnnxNodeParser {
|
||||
public:
|
||||
explicit OnnxNodeParser(std::string node_name) : name_(std::move(node_name)) {}
|
||||
|
@ -58,6 +71,9 @@ class OnnxNodeParser {
|
|||
|
||||
static size_t GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, bool *overflowed);
|
||||
|
||||
static STATUS LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_const_tensor,
|
||||
const tensor::TensorPtr &tensor_info, const std::string &model_file);
|
||||
|
||||
protected:
|
||||
static mindspore::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);
|
||||
|
||||
|
@ -66,10 +82,16 @@ class OnnxNodeParser {
|
|||
static const void *GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, TypeId data_type, size_t data_count,
|
||||
size_t *data_size);
|
||||
|
||||
static STATUS SetExternalTensorFile(const std::string &model_file, std::string *external_tensor_dir);
|
||||
|
||||
static const void *LoadOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t *data_size,
|
||||
const std::string &model_file);
|
||||
|
||||
const std::string name_{};
|
||||
|
||||
private:
|
||||
static int64_t opset_version_;
|
||||
static void *buffer_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue