!3687 keep lite import with mindspore export
Merge pull request !3687 from yankai10/merge123
This commit is contained in:
commit
68666bd35e
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/common/anf_exporter/anf_populater/anf_reshape_populater.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
|
||||
std::vector<schema::TensorT *> *outputs) {
|
||||
auto attr = std::make_unique<schema::FlattenT>();
|
||||
node->nodeType = schema::NodeType_CNode;
|
||||
node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
node->primitive->value.type = schema::PrimitiveType_Flatten;
|
||||
node->primitive->value.value = attr.release();
|
||||
return 0;
|
||||
}
|
||||
|
||||
AnfNodePopulaterRegistrar anfReshapeParser("Reshape", new AnfReshapePopulater());
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_ANF_RESHAPE_PARSER_H
|
||||
#define MINDSPORE_ANF_RESHAPE_PARSER_H
|
||||
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
|
||||
#include <vector>
|
||||
namespace mindspore::lite {
|
||||
class AnfReshapePopulater : public AnfNodePopulater {
|
||||
public:
|
||||
AnfReshapePopulater() = default;
|
||||
~AnfReshapePopulater() override = default;
|
||||
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_ANF_RESHAPE_PARSER_H
|
|
@ -15,25 +15,28 @@
|
|||
*/
|
||||
|
||||
#include "src/common/anf_importer/import_from_protobuf.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <stack>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "src/ir/tensor.h"
|
||||
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "src/ir/tensor.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "tools/converter/parser/onnx/onnx.pb.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using string = std::string;
|
||||
using int32 = int32_t;
|
||||
|
@ -54,26 +57,27 @@ enum ParseForm : int {
|
|||
};
|
||||
|
||||
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
|
||||
{"type", FORM_PARSE_TYPE},
|
||||
{"scalar", FORM_PARSE_SCALAR},
|
||||
{"tensor", FORM_PARSE_TENSOR}};
|
||||
{"type", FORM_PARSE_TYPE},
|
||||
{"scalar", FORM_PARSE_SCALAR},
|
||||
{"tensor", FORM_PARSE_TENSOR}};
|
||||
|
||||
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
|
||||
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool},
|
||||
{onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
|
||||
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16},
|
||||
{onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
|
||||
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64},
|
||||
{onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
|
||||
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16},
|
||||
{onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
|
||||
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64},
|
||||
{onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
|
||||
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
|
||||
{onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
|
||||
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
|
||||
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool},
|
||||
{onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
|
||||
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16},
|
||||
{onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
|
||||
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64},
|
||||
{onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
|
||||
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16},
|
||||
{onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
|
||||
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64},
|
||||
{onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
|
||||
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
|
||||
{onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
|
||||
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
|
||||
};
|
||||
|
||||
#if 0
|
||||
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
|
||||
const std::unordered_map<string, ValuePtr> &kv) {
|
||||
std::string str = attr_name;
|
||||
|
@ -190,16 +194,17 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, a
|
|||
return {};
|
||||
}
|
||||
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
|
||||
if (attr_tensor.type##_data_size() == 1) { \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
|
||||
return MakeValue<valuetype>(value); \
|
||||
} else { \
|
||||
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
|
||||
} \
|
||||
return{}; \
|
||||
}
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
ValuePtr ParseAttrInScalar_##type##_##valuetype( \
|
||||
const onnx::TensorProto &attr_tensor) { \
|
||||
if (attr_tensor.type##_data_size() == 1) { \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
|
||||
return MakeValue<valuetype>(value); \
|
||||
} else { \
|
||||
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
|
||||
} \
|
||||
return {}; \
|
||||
}
|
||||
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
|
||||
|
@ -634,8 +639,508 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
|
|||
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
void ParseAttrInScalar_##type##_##valuetype( \
|
||||
const PrimitivePtr &prim, const std::string &attr_name, \
|
||||
const onnx::TensorProto &attr_tensor) { \
|
||||
MS_EXCEPTION_IF_NULL(prim); \
|
||||
std::vector<ValuePtr> attr_value_vec; \
|
||||
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
|
||||
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
|
||||
} \
|
||||
if (attr_value_vec.size() == 1) { \
|
||||
prim->AddAttr(attr_name, attr_value_vec[0]); \
|
||||
} else { \
|
||||
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
|
||||
} \
|
||||
}
|
||||
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(float, float)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(string, string)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(
|
||||
const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!value_proto.has_type() || !value_proto.has_name()) {
|
||||
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
|
||||
return false;
|
||||
}
|
||||
node->set_name(value_proto.name());
|
||||
const auto &type_proto = value_proto.type();
|
||||
if (!type_proto.has_tensor_type()) {
|
||||
MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! ";
|
||||
return false;
|
||||
}
|
||||
const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type();
|
||||
if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) {
|
||||
MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! ";
|
||||
return false;
|
||||
}
|
||||
const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape();
|
||||
std::vector<int> shape;
|
||||
for (int i = 0; i < tensor_shape.dim_size(); ++i) {
|
||||
shape.push_back(tensor_shape.dim(i).dim_value());
|
||||
}
|
||||
|
||||
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) ==
|
||||
kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto type_ptr =
|
||||
TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]);
|
||||
auto abstract_tensor =
|
||||
std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
||||
node->set_abstract(abstract_tensor);
|
||||
|
||||
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) {
|
||||
tensor::Tensor *tensor_info = new tensor::Tensor(
|
||||
kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
tensor_info->MallocData();
|
||||
const onnx::TensorProto initialize_proto =
|
||||
default_para_map_[value_proto.name()];
|
||||
std::string initial_data = initialize_proto.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
|
||||
MS_EXCEPTION_IF_NULL(tensor_data_buf);
|
||||
memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(),
|
||||
initial_data.size());
|
||||
|
||||
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
||||
MS_EXCEPTION_IF_NULL(param_value);
|
||||
param_value->set_tensor_addr(tensor_data_buf);
|
||||
param_value->set_tensor_size(tensor_info->Size());
|
||||
node->set_default_param(param_value);
|
||||
}
|
||||
anfnode_build_map_[value_proto.name()] = node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ImportParametersForGraph(
|
||||
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_LOG(INFO) << "Parameters had default paramerer size is: "
|
||||
<< importProto.initializer_size();
|
||||
|
||||
for (int i = 0; i < importProto.initializer_size(); ++i) {
|
||||
const onnx::TensorProto &initializer_proto = importProto.initializer(i);
|
||||
if (!initializer_proto.has_name()) {
|
||||
MS_LOG(ERROR)
|
||||
<< "initializer vector of onnx GraphProto has no name at index: "
|
||||
<< i;
|
||||
return false;
|
||||
}
|
||||
default_para_map_[initializer_proto.name()] = initializer_proto;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "all parameters size: " << importProto.input_size();
|
||||
for (int i = 0; i < importProto.input_size(); ++i) {
|
||||
const onnx::ValueInfoProto &input_proto = importProto.input(i);
|
||||
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(),
|
||||
input_proto)) {
|
||||
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(
|
||||
const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) ==
|
||||
kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:"
|
||||
<< attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
prim->AddAttr(attr_name,
|
||||
TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(
|
||||
const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
switch (attr_tensor_type) {
|
||||
case onnx::TensorProto_DataType_STRING: {
|
||||
ParseAttrInScalar_string_string(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_INT32: {
|
||||
ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_INT64: {
|
||||
ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_UINT64: {
|
||||
ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_FLOAT: {
|
||||
ParseAttrInScalar_float_float(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_DOUBLE: {
|
||||
ParseAttrInScalar_double_double(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_BOOL: {
|
||||
ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor);
|
||||
auto value = prim->GetAttr(attr_name);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: "
|
||||
<< attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(
|
||||
const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
MS_LOG(ERROR) << "parse attr type don't support attr type is tensor";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForCNode(
|
||||
const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const std::string &attr_name = attr_proto.name();
|
||||
if (!attr_proto.has_ref_attr_name()) {
|
||||
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
||||
switch (kParseTypeSwitchMap[ref_attr_name]) {
|
||||
case FORM_PARSE_TYPE: {
|
||||
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_SCALAR: {
|
||||
return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_TENSOR: {
|
||||
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(
|
||||
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
std::vector<int> shape;
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(
|
||||
kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
tensor_info->MallocData();
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
|
||||
memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(),
|
||||
tensor_buf.size());
|
||||
auto new_value_node = NewValueNode(MakeValue(tensor_info));
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
|
||||
auto abstract_tensor =
|
||||
std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
||||
new_value_node->set_abstract(abstract_tensor);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(
|
||||
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
ValuePtr value_ptr = nullptr;
|
||||
switch (attr_tensor_type) {
|
||||
case onnx::TensorProto_DataType_INT32: {
|
||||
std::vector<int32> add_data;
|
||||
for (int i = 0; i < attr_tensor.int32_data_size(); ++i) {
|
||||
add_data.push_back(attr_tensor.int32_data(i));
|
||||
}
|
||||
if (add_data.size() == 1) {
|
||||
value_ptr = MakeValue(add_data[0]);
|
||||
} else if (!add_data.empty()) {
|
||||
value_ptr = MakeValue<std::vector<int32>>(add_data);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_FLOAT: {
|
||||
std::vector<float> add_data;
|
||||
for (int i = 0; i < attr_tensor.float_data_size(); ++i) {
|
||||
add_data.push_back(attr_tensor.float_data(i));
|
||||
}
|
||||
|
||||
if (add_data.size() == 1) {
|
||||
value_ptr = MakeValue(add_data[0]);
|
||||
} else if (!add_data.empty()) {
|
||||
value_ptr = MakeValue<std::vector<float>>(add_data);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_UNDEFINED: {
|
||||
std::vector<ValuePtr> elems;
|
||||
value_ptr = std::make_shared<ValueTuple>(elems);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: "
|
||||
<< attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
auto new_value_node = NewValueNode(value_ptr);
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
new_value_node->set_abstract(value_ptr->ToAbstract());
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(
|
||||
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) ==
|
||||
kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR)
|
||||
<< "Obtain ValueNode attr in type-form has not support input type: "
|
||||
<< attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
auto new_value_node =
|
||||
NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
abstract::AbstractTypePtr abs_type =
|
||||
std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
|
||||
new_value_node->set_abstract(abs_type);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(
|
||||
const std::string &ref_attr_name, const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
switch (kParseTypeSwitchMap[ref_attr_name]) {
|
||||
case FORM_PARSE_SCALAR: {
|
||||
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_TENSOR: {
|
||||
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_TYPE: {
|
||||
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR)
|
||||
<< "parse ValueNode value don't support input of ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(
|
||||
const onnx::NodeProto &node_proto) {
|
||||
const std::string &value_node_name = node_proto.output(0);
|
||||
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
|
||||
if (!attr_proto.has_ref_attr_name()) {
|
||||
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
||||
|
||||
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
|
||||
}
|
||||
|
||||
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(
|
||||
const onnx::AttributeProto &attr_proto) {
|
||||
std::vector<int> shape_vec;
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape_vec.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
|
||||
auto abstract_tensor =
|
||||
std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tensor);
|
||||
return abstract_tensor;
|
||||
}
|
||||
|
||||
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(
|
||||
const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
if (!node_proto.has_op_type()) {
|
||||
MS_LOG(ERROR) << "Get CNode op_type failed!";
|
||||
return nullptr;
|
||||
}
|
||||
const std::string &node_name = node_proto.output(0);
|
||||
const std::string &fullname_with_scope = node_proto.domain();
|
||||
const std::string &node_type = node_proto.op_type();
|
||||
PrimitivePtr prim = std::make_shared<Primitive>(node_type);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
prim->set_instance_name(node_type);
|
||||
|
||||
abstract::AbstractTensorPtr abstract = nullptr;
|
||||
abstract::AbstractTensorPtr abstract_first = nullptr;
|
||||
abstract::AbstractTensorPtr abstract_second = nullptr;
|
||||
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
||||
const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
|
||||
if (attr_proto.name() == kCNodeShapeAttr) {
|
||||
abstract = GetAbstractForCNode(attr_proto);
|
||||
continue;
|
||||
}
|
||||
if (attr_proto.name() == kCNodeShape1Attr) {
|
||||
abstract_first = GetAbstractForCNode(attr_proto);
|
||||
continue;
|
||||
}
|
||||
if (attr_proto.name() == kCNodeShape2Attr) {
|
||||
abstract_second = GetAbstractForCNode(attr_proto);
|
||||
continue;
|
||||
}
|
||||
if (!GetAttrValueForCNode(prim, attr_proto)) {
|
||||
MS_LOG(ERROR) << "Get CNode attr failed!";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(prim));
|
||||
for (int i = 0; i < node_proto.input_size(); ++i) {
|
||||
const std::string &input_name = node_proto.input(i);
|
||||
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
|
||||
MS_LOG(ERROR) << node_name << " input " << i << input_name
|
||||
<< "can't find in nodes have parsed";
|
||||
return nullptr;
|
||||
}
|
||||
inputs.push_back(anfnode_build_map_[input_name]);
|
||||
}
|
||||
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
if (node_type == "LayerNorm") {
|
||||
AbstractBasePtrList elem;
|
||||
elem.push_back(abstract);
|
||||
elem.push_back(abstract_first);
|
||||
elem.push_back(abstract_second);
|
||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
} else if (node_type == "ArgMaxWithValue") {
|
||||
AbstractBasePtrList elem;
|
||||
elem.push_back(abstract);
|
||||
elem.push_back(abstract_first);
|
||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
} else if (nullptr == abstract) {
|
||||
AbstractBasePtrList elem;
|
||||
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
||||
elem.push_back(cnode_ptr->input(index)->abstract());
|
||||
}
|
||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
} else {
|
||||
cnode_ptr->set_abstract(abstract);
|
||||
}
|
||||
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
|
||||
anfnode_build_map_[node_name] = cnode_ptr;
|
||||
return cnode_ptr;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(
|
||||
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
if (importProto.output_size() > 1) {
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
AbstractBasePtrList elem;
|
||||
for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
|
||||
const onnx::ValueInfoProto &output_node = importProto.output(out_size);
|
||||
const std::string &out_tuple = output_node.name();
|
||||
inputs.push_back(anfnode_build_map_[out_tuple]);
|
||||
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
|
||||
}
|
||||
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
maketuple_ptr->set_abstract(
|
||||
std::make_shared<abstract::AbstractTuple>(elem));
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
||||
inputs.push_back(maketuple_ptr);
|
||||
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
|
||||
} else {
|
||||
const onnx::ValueInfoProto &output_node = importProto.output(0);
|
||||
const onnx::TypeProto &output_typeproto = output_node.type();
|
||||
int output_type = output_typeproto.tensor_type().elem_type();
|
||||
std::vector<int> output_shape;
|
||||
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size();
|
||||
++i) {
|
||||
output_shape.push_back(
|
||||
output_typeproto.tensor_type().shape().dim(i).dim_value());
|
||||
}
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]);
|
||||
auto abstract_tensor =
|
||||
std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);
|
||||
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
||||
inputs.push_back(cnode_ptr);
|
||||
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
return_node->set_abstract(abstract_tensor);
|
||||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ImportNodesForGraph(
|
||||
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
|
||||
CNodePtr cnode_ptr = nullptr;
|
||||
for (int i = 0; i < importProto.node_size(); ++i) {
|
||||
const onnx::NodeProto &node_proto = importProto.node(i);
|
||||
const std::string &node_type = node_proto.op_type();
|
||||
if (node_type == kConstantValueNode) {
|
||||
if (!BuildValueNodeForFuncGraph(node_proto)) {
|
||||
MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
|
||||
if (cnode_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildFuncGraph(
|
||||
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
|
||||
MS_EXCEPTION_IF_NULL(debug_info_ptr);
|
||||
|
@ -651,7 +1156,8 @@ bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph
|
|||
return ImportNodesForGraph(outputFuncGraph, importProto);
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
|
||||
bool AnfImporterFromProtobuf::ParseModelConfigureInfo(
|
||||
const onnx::ModelProto &model_proto) {
|
||||
if (!model_proto.has_producer_name()) {
|
||||
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
|
||||
return false;
|
||||
|
@ -672,7 +1178,6 @@ bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mo
|
|||
return true;
|
||||
}
|
||||
|
||||
|
||||
int AnfImporterFromProtobuf::Import() {
|
||||
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
|
||||
MS_EXCEPTION_IF_NULL(dstGraph);
|
||||
|
@ -689,9 +1194,9 @@ int AnfImporterFromProtobuf::Import() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
|
||||
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
|
||||
std::unique_ptr<char> onnx_file(new(std::nothrow) char[PATH_MAX]{0});
|
||||
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(
|
||||
const std::string &model_path) {
|
||||
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
|
||||
if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) {
|
||||
MS_LOG(ERROR) << "open file failed.";
|
||||
return nullptr;
|
||||
|
@ -707,11 +1212,10 @@ onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string
|
|||
delete onnx_model;
|
||||
return nullptr;
|
||||
}
|
||||
(void) close(fd);
|
||||
(void)close(fd);
|
||||
MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl;
|
||||
return onnx_model;
|
||||
}
|
||||
|
||||
FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; }
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -47,6 +47,7 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
|||
bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
|
||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto);
|
||||
#if 0
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
|
@ -76,6 +77,30 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
|||
const onnx::TensorProto &attr_tensor);
|
||||
std::unordered_map<std::string, abstract::AbstractTensorPtr>
|
||||
GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||
#endif
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
|
||||
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto);
|
||||
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const CNodePtr &cnode_ptr);
|
||||
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
|
||||
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
|
||||
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
|
||||
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||
|
||||
|
||||
|
||||
private:
|
||||
std::string producer_name_;
|
||||
|
|
Loading…
Reference in New Issue