!3687 keep lite import with mindspore export

Merge pull request !3687 from yankai10/merge123
mindspore-ci-bot 2020-07-30 15:53:12 +08:00 committed by Gitee
commit 68666bd35e
4 changed files with 637 additions and 43 deletions

View File

@@ -0,0 +1,35 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_reshape_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
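// Populates the exported flatbuffer node for a Reshape anf node (assumes the
// lite schema defines ReshapeT / PrimitiveType_Reshape, mirroring the
// FlattenT / PrimitiveType_Flatten pair used by the Flatten populater).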
int AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto attr = std::make_unique<schema::ReshapeT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Reshape;
node->primitive->value.value = attr.release();
return 0;
}
AnfNodePopulaterRegistrar anfReshapePopulater("Reshape", new AnfReshapePopulater());
} // namespace mindspore::lite

View File

@@ -0,0 +1,30 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ANF_RESHAPE_PARSER_H
#define MINDSPORE_ANF_RESHAPE_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfReshapePopulater : public AnfNodePopulater {
public:
AnfReshapePopulater() = default;
~AnfReshapePopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_ANF_RESHAPE_PARSER_H

View File

@@ -15,25 +15,28 @@
*/
#include "src/common/anf_importer/import_from_protobuf.h"
#include <fcntl.h>
#include <unistd.h>
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <vector>
#include "frontend/operator/ops.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "src/ir/tensor.h"
#include "src/param_value_lite.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "utils/log_adapter.h"
using string = std::string;
using int32 = int32_t;
@@ -54,26 +57,27 @@ enum ParseForm : int {
};
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
{"type", FORM_PARSE_TYPE},
{"scalar", FORM_PARSE_SCALAR},
{"tensor", FORM_PARSE_TENSOR}};
{"type", FORM_PARSE_TYPE},
{"scalar", FORM_PARSE_SCALAR},
{"tensor", FORM_PARSE_TENSOR}};
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool},
{onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16},
{onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64},
{onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16},
{onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64},
{onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
{onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
};
#if 0
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
const std::unordered_map<string, ValuePtr> &kv) {
std::string str = attr_name;
@@ -190,16 +194,17 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, a
return {};
}
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
ValuePtr ParseAttrInScalar_##type##_##valuetype( \
const onnx::TensorProto &attr_tensor) { \
if (attr_tensor.type##_data_size() == 1) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
return MakeValue<valuetype>(value); \
} else { \
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
} \
return {}; \
}
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
@@ -634,8 +639,508 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
}
#endif
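// Each expansion below defines a ParseAttrInScalar_<type>_<valuetype> helper that
// reads the matching repeated field of an onnx::TensorProto and attaches the result
// to the primitive: a single element as a scalar value, several as a ValueList.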
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
void ParseAttrInScalar_##type##_##valuetype( \
const PrimitivePtr &prim, const std::string &attr_name, \
const onnx::TensorProto &attr_tensor) { \
MS_EXCEPTION_IF_NULL(prim); \
std::vector<ValuePtr> attr_value_vec; \
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
} \
if (attr_value_vec.size() == 1) { \
prim->AddAttr(attr_name, attr_value_vec[0]); \
} else { \
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
} \
}
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
PARSE_ONNXATTR_IN_SCALAR_FORM(string, string)
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32)
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
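// Converts one onnx ValueInfoProto into an ANF Parameter: element type and shape
// become an AbstractTensor, and if an initializer with the same name was recorded,
// its raw data is copied into a ParamValueLite used as the parameter's default value.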
bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(
const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) {
MS_EXCEPTION_IF_NULL(node);
if (!value_proto.has_type() || !value_proto.has_name()) {
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
return false;
}
node->set_name(value_proto.name());
const auto &type_proto = value_proto.type();
if (!type_proto.has_tensor_type()) {
MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! ";
return false;
}
const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type();
if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) {
MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! ";
return false;
}
const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape();
std::vector<int> shape;
for (int i = 0; i < tensor_shape.dim_size(); ++i) {
shape.push_back(tensor_shape.dim(i).dim_value());
}
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) ==
kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!";
return false;
}
auto type_ptr =
TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]);
auto abstract_tensor =
std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
node->set_abstract(abstract_tensor);
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) {
tensor::Tensor *tensor_info = new tensor::Tensor(
kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
tensor_info->MallocData();
const onnx::TensorProto initialize_proto =
default_para_map_[value_proto.name()];
std::string initial_data = initialize_proto.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
MS_EXCEPTION_IF_NULL(tensor_data_buf);
memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(),
initial_data.size());
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
MS_EXCEPTION_IF_NULL(param_value);
param_value->set_tensor_addr(tensor_data_buf);
param_value->set_tensor_size(tensor_info->Size());
node->set_default_param(param_value);
}
anfnode_build_map_[value_proto.name()] = node;
return true;
}
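// Registers every graph initializer by name, then builds a Parameter for each graph
// input; inputs whose name matches an initializer receive its data as default value.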
bool AnfImporterFromProtobuf::ImportParametersForGraph(
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "Parameters had default paramerer size is: "
<< importProto.initializer_size();
for (int i = 0; i < importProto.initializer_size(); ++i) {
const onnx::TensorProto &initializer_proto = importProto.initializer(i);
if (!initializer_proto.has_name()) {
MS_LOG(ERROR)
<< "initializer vector of onnx GraphProto has no name at index: "
<< i;
return false;
}
default_para_map_[initializer_proto.name()] = initializer_proto;
}
MS_LOG(INFO) << "all parameters size: " << importProto.input_size();
for (int i = 0; i < importProto.input_size(); ++i) {
const onnx::ValueInfoProto &input_proto = importProto.input(i);
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(),
input_proto)) {
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
return false;
}
}
return true;
}
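// The Obtain* helpers below decode a single onnx attribute tensor into a primitive
// attribute, dispatched on how the exporter serialized it (type, scalar or tensor form).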
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(
const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
const int attr_tensor_type = attr_tensor.data_type();
if (kDefaultValueSwitchMap.find(attr_tensor_type) ==
kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:"
<< attr_tensor_type;
return false;
}
prim->AddAttr(attr_name,
TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
return true;
}
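// Scalar-form attributes carry their payload in the typed repeated fields of the
// TensorProto; dispatch on data_type to the matching generated parser above.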
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(
const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
const int attr_tensor_type = attr_tensor.data_type();
switch (attr_tensor_type) {
case onnx::TensorProto_DataType_STRING: {
ParseAttrInScalar_string_string(prim, attr_name, attr_tensor);
break;
}
case onnx::TensorProto_DataType_INT32: {
ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor);
break;
}
case onnx::TensorProto_DataType_INT64: {
ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor);
break;
}
case onnx::TensorProto_DataType_UINT64: {
ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor);
break;
}
case onnx::TensorProto_DataType_FLOAT: {
ParseAttrInScalar_float_float(prim, attr_name, attr_tensor);
break;
}
case onnx::TensorProto_DataType_DOUBLE: {
ParseAttrInScalar_double_double(prim, attr_name, attr_tensor);
break;
}
case onnx::TensorProto_DataType_BOOL: {
ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor);
break;
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: "
<< attr_tensor_type;
return false;
}
return true;
}
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(
const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
MS_LOG(ERROR) << "parse attr type don't support attr type is tensor";
return false;
}
bool AnfImporterFromProtobuf::GetAttrValueForCNode(
const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
MS_EXCEPTION_IF_NULL(prim);
const std::string &attr_name = attr_proto.name();
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
const onnx::TensorProto &attr_tensor = attr_proto.t();
switch (kParseTypeSwitchMap[ref_attr_name]) {
case FORM_PARSE_TYPE: {
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
}
case FORM_PARSE_SCALAR: {
return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor);
}
case FORM_PARSE_TENSOR: {
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
}
default:
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return false;
}
}
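// The ObtainValueNode* helpers build standalone ValueNodes from a Constant node's
// attribute tensor and register them in anfnode_build_map_ under the output name.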
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
std::vector<int> shape;
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(
kDefaultValueSwitchMap[attr_tensor_type], shape);
tensor_info->MallocData();
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(),
tensor_buf.size());
auto new_value_node = NewValueNode(MakeValue(tensor_info));
MS_EXCEPTION_IF_NULL(new_value_node);
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
auto abstract_tensor =
std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
new_value_node->set_abstract(abstract_tensor);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
ValuePtr value_ptr = nullptr;
switch (attr_tensor_type) {
case onnx::TensorProto_DataType_INT32: {
std::vector<int32> add_data;
for (int i = 0; i < attr_tensor.int32_data_size(); ++i) {
add_data.push_back(attr_tensor.int32_data(i));
}
if (add_data.size() == 1) {
value_ptr = MakeValue(add_data[0]);
} else if (!add_data.empty()) {
value_ptr = MakeValue<std::vector<int32>>(add_data);
}
break;
}
case onnx::TensorProto_DataType_FLOAT: {
std::vector<float> add_data;
for (int i = 0; i < attr_tensor.float_data_size(); ++i) {
add_data.push_back(attr_tensor.float_data(i));
}
if (add_data.size() == 1) {
value_ptr = MakeValue(add_data[0]);
} else if (!add_data.empty()) {
value_ptr = MakeValue<std::vector<float>>(add_data);
}
break;
}
case onnx::TensorProto_DataType_UNDEFINED: {
std::vector<ValuePtr> elems;
value_ptr = std::make_shared<ValueTuple>(elems);
break;
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: "
<< attr_tensor_type;
return false;
}
MS_EXCEPTION_IF_NULL(value_ptr);  // the empty int32/float paths above can leave value_ptr unset
auto new_value_node = NewValueNode(value_ptr);
MS_EXCEPTION_IF_NULL(new_value_node);
new_value_node->set_abstract(value_ptr->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
if (kDefaultValueSwitchMap.find(attr_tensor_type) ==
kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form does not support input type: "
<< attr_tensor_type;
return false;
}
auto new_value_node =
NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
abstract::AbstractTypePtr abs_type =
std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
new_value_node->set_abstract(abs_type);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
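// Dispatches a Constant node's attribute to the scalar/tensor/type builder above,
// according to its ref_attr_name tag.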
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(
const std::string &ref_attr_name, const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
switch (kParseTypeSwitchMap[ref_attr_name]) {
case FORM_PARSE_SCALAR: {
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor);
}
case FORM_PARSE_TENSOR: {
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
}
case FORM_PARSE_TYPE: {
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
}
default:
MS_LOG(ERROR) << "parse ValueNode value does not support this ref_attr_name";
return false;
}
}
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(
const onnx::NodeProto &node_proto) {
const std::string &value_node_name = node_proto.output(0);
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
const onnx::TensorProto &attr_tensor = attr_proto.t();
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
}
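// Reads the dims and element type of the shape attribute's tensor into an
// AbstractTensor describing one CNode output.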
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(
const onnx::AttributeProto &attr_proto) {
std::vector<int> shape_vec;
const onnx::TensorProto &attr_tensor = attr_proto.t();
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape_vec.push_back(attr_tensor.dims(i));
}
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
auto abstract_tensor =
std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
MS_EXCEPTION_IF_NULL(abstract_tensor);
return abstract_tensor;
}
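// Builds a CNode from an onnx NodeProto: op_type becomes the Primitive, the shape
// attributes (up to three, for multi-output ops) become abstracts, and all remaining
// attributes are attached to the primitive itself.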
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(
const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
if (!node_proto.has_op_type()) {
MS_LOG(ERROR) << "Get CNode op_type failed!";
return nullptr;
}
const std::string &node_name = node_proto.output(0);
const std::string &fullname_with_scope = node_proto.domain();
const std::string &node_type = node_proto.op_type();
PrimitivePtr prim = std::make_shared<Primitive>(node_type);
MS_EXCEPTION_IF_NULL(prim);
prim->set_instance_name(node_type);
abstract::AbstractTensorPtr abstract = nullptr;
abstract::AbstractTensorPtr abstract_first = nullptr;
abstract::AbstractTensorPtr abstract_second = nullptr;
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
if (attr_proto.name() == kCNodeShapeAttr) {
abstract = GetAbstractForCNode(attr_proto);
continue;
}
if (attr_proto.name() == kCNodeShape1Attr) {
abstract_first = GetAbstractForCNode(attr_proto);
continue;
}
if (attr_proto.name() == kCNodeShape2Attr) {
abstract_second = GetAbstractForCNode(attr_proto);
continue;
}
if (!GetAttrValueForCNode(prim, attr_proto)) {
MS_LOG(ERROR) << "Get CNode attr failed!";
return nullptr;
}
}
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim));
for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << node_name << " input " << i << " " << input_name
<< " can't be found among the nodes parsed so far";
return nullptr;
}
inputs.push_back(anfnode_build_map_[input_name]);
}
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
if (node_type == "LayerNorm") {
AbstractBasePtrList elem;
elem.push_back(abstract);
elem.push_back(abstract_first);
elem.push_back(abstract_second);
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (node_type == "ArgMaxWithValue") {
AbstractBasePtrList elem;
elem.push_back(abstract);
elem.push_back(abstract_first);
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (nullptr == abstract) {
AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
elem.push_back(cnode_ptr->input(index)->abstract());
}
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else {
cnode_ptr->set_abstract(abstract);
}
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
anfnode_build_map_[node_name] = cnode_ptr;
return cnode_ptr;
}
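// Appends the graph's Return node: multiple outputs are first gathered into a
// MakeTuple CNode, a single output returns the last parsed CNode directly.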
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_EXCEPTION_IF_NULL(cnode_ptr);
std::vector<AnfNodePtr> inputs;
if (importProto.output_size() > 1) {
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
AbstractBasePtrList elem;
for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
const onnx::ValueInfoProto &output_node = importProto.output(out_size);
const std::string &out_tuple = output_node.name();
inputs.push_back(anfnode_build_map_[out_tuple]);
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
}
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
maketuple_ptr->set_abstract(
std::make_shared<abstract::AbstractTuple>(elem));
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn));
inputs.push_back(maketuple_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
} else {
const onnx::ValueInfoProto &output_node = importProto.output(0);
const onnx::TypeProto &output_typeproto = output_node.type();
int output_type = output_typeproto.tensor_type().elem_type();
std::vector<int> output_shape;
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size();
++i) {
output_shape.push_back(
output_typeproto.tensor_type().shape().dim(i).dim_value());
}
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]);
auto abstract_tensor =
std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);
inputs.push_back(NewValueNode(prim::kPrimReturn));
inputs.push_back(cnode_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node);
return_node->set_abstract(abstract_tensor);
outputFuncGraph->set_return(return_node);
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
}
return true;
}
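// Walks the graph's NodeProtos in order: Constant nodes become ValueNodes, every
// other node becomes a CNode, and the last CNode feeds the Return node.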
bool AnfImporterFromProtobuf::ImportNodesForGraph(
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
CNodePtr cnode_ptr = nullptr;
for (int i = 0; i < importProto.node_size(); ++i) {
const onnx::NodeProto &node_proto = importProto.node(i);
const std::string &node_type = node_proto.op_type();
if (node_type == kConstantValueNode) {
if (!BuildValueNodeForFuncGraph(node_proto)) {
MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
return false;
}
continue;
}
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
if (cnode_ptr == nullptr) {
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
return false;
}
}
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
}
bool AnfImporterFromProtobuf::BuildFuncGraph(
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
MS_EXCEPTION_IF_NULL(debug_info_ptr);
@@ -651,7 +1156,8 @@ bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph
return ImportNodesForGraph(outputFuncGraph, importProto);
}
bool AnfImporterFromProtobuf::ParseModelConfigureInfo(
const onnx::ModelProto &model_proto) {
if (!model_proto.has_producer_name()) {
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
return false;
@@ -672,7 +1178,6 @@ bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mo
return true;
}
int AnfImporterFromProtobuf::Import() {
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
MS_EXCEPTION_IF_NULL(dstGraph);
@@ -689,9 +1194,9 @@ int AnfImporterFromProtobuf::Import() {
return RET_OK;
}
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(
const std::string &model_path) {
std::unique_ptr<char[]> onnx_file(new (std::nothrow) char[PATH_MAX]{0});  // char[] so delete[] is used
if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) {
MS_LOG(ERROR) << "open file failed.";
return nullptr;
@@ -707,11 +1212,10 @@ onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string
delete onnx_model;
return nullptr;
}
(void)close(fd);
MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl;
return onnx_model;
}
FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; }
} // namespace mindspore::lite

View File

@@ -47,6 +47,7 @@ class AnfImporterFromProtobuf : public AnfImporter {
bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto);
#if 0
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
@@ -76,6 +77,30 @@ class AnfImporterFromProtobuf : public AnfImporter {
const onnx::TensorProto &attr_tensor);
std::unordered_map<std::string, abstract::AbstractTensorPtr>
GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
#endif
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto);
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const CNodePtr &cnode_ptr);
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
bool ObtainValueNodeInTensorForm(const std::string &value_node_name, const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInScalarForm(const std::string &value_node_name, const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const std::string &ref_attr_name, const std::string &value_node_name,
const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInTypeForm(const std::string &value_node_name, const onnx::TensorProto &attr_tensor);
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
private:
std::string producer_name_;