!5719 [MSLITE]mindspore export inferface finetune lite

Merge pull request !5719 from zhengjun10/master
This commit is contained in:
mindspore-ci-bot 2020-09-05 16:11:13 +08:00 committed by Gitee
commit 25a878d674
3 changed files with 268 additions and 178 deletions

View File

@ -1,2 +1,2 @@
ssd.pb #ssd.pb
mobilenet_v2.pb # mobilenet_v2.pb

View File

@ -46,9 +46,6 @@ using uint64 = uint64_t;
namespace mindspore::lite { namespace mindspore::lite {
static constexpr char kConstantValueNode[] = "Constant"; static constexpr char kConstantValueNode[] = "Constant";
static constexpr char kCNodeShapeAttr[] = "shape";
static constexpr char kCNodeShape1Attr[] = "shape1";
static constexpr char kCNodeShape2Attr[] = "shape2";
enum ParseForm : int { enum ParseForm : int {
FORM_PARSE_TYPE = 0, FORM_PARSE_TYPE = 0,
@ -69,20 +66,131 @@ static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
{onnx::TensorProto_DataType_STRING, kObjectTypeString}, {onnx::TensorProto_DataType_STRING, kObjectTypeString},
}; };
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
const std::unordered_map<string, ValuePtr> &kv) {
std::string str = attr_name;
auto replace = [&](const string &orgStr, const string &newStr) {
std::string::size_type pos(0);
while ((pos = str.find(orgStr)) != std::string::npos) {
str.replace(pos, orgStr.length(), newStr);
}
return str;
};
// remove "scalar:"
str = replace("scalar:", "");
// remove "Tuple"
str = replace("Tuple", "");
// remove "List"
str = replace("List", "");
std::stack<std::string> rules;
std::stack<ValuePtr> value;
int num = 0, count = 0;
for (size_t i = 0; i < str.length(); i++) {
if (str[i] == '[') {
rules.push("[");
} else if (str[i] == ']') {
// rules
std::vector<ValuePtr> vec;
while (rules.top() != "[") {
rules.pop();
vec.push_back(value.top());
value.pop();
}
// pop "["
rules.pop();
// make tuple for names
std::string res = "dummy";
// make tuple for values
reverse(vec.begin(), vec.end());
auto vt = std::make_shared<ValueTuple>(vec);
if (rules.empty() && value.empty()) {
return vt;
}
rules.push(res);
value.push(vt);
} else if (str[i] == ',') {
continue;
} else {
count++;
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
auto value_name = str.substr(i - count + 1, count);
value.push(kv.at(value_name));
rules.push(value_name);
count = 0;
num++;
}
}
}
return {};
}
std::shared_ptr<abstract::AbstractTuple>
ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, abstract::AbstractTensorPtr> &kv) {
std::string str = attr_name;
auto replace = [&](const string &orgStr, const string &newStr) {
std::string::size_type pos(0);
while ((pos = str.find(orgStr)) != std::string::npos) {
str.replace(pos, orgStr.length(), newStr);
}
return str;
};
// remove "scalar:"
str = replace("shape:", "");
// remove "Tuple"
str = replace("Tuple", "");
// remove "List"
str = replace("List", "");
std::stack<std::string> rules;
std::stack<abstract::AbstractBasePtr> value;
int num = 0, count = 0;
for (size_t i = 0; i < str.length(); i++) {
if (str[i] == '[') {
rules.push("[");
} else if (str[i] == ']') {
// rules
std::vector<abstract::AbstractBasePtr> vec;
while (rules.top() != "[") {
rules.pop();
vec.push_back(value.top());
value.pop();
}
// pop "["
rules.pop();
// make tuple for names
std::string res = "dummy";
// make tuple for values
reverse(vec.begin(), vec.end());
auto vt = std::make_shared<abstract::AbstractTuple>(vec);
if (rules.empty() && value.empty()) {
return vt;
}
rules.push(res);
value.push(vt);
} else if (str[i] == ',') {
continue;
} else {
count++;
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
auto value_name = str.substr(i - count + 1, count);
value.push(kv.at(value_name));
rules.push(value_name);
count = 0;
num++;
}
}
}
return {};
}
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
const onnx::TensorProto &attr_tensor) { \ if (attr_tensor.type##_data_size() == 1) { \
MS_EXCEPTION_IF_NULL(prim); \ auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
std::vector<ValuePtr> attr_value_vec; \ return MakeValue<valuetype>(value); \
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
} \
if (attr_value_vec.size() == 1) { \
prim->AddAttr(attr_name, attr_value_vec[0]); \
} else { \ } else { \
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \ MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
} \ } \
return {}; \
} }
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
@ -193,45 +301,34 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim
return true; return true;
} }
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
const int attr_tensor_type = attr_tensor.data_type(); const int attr_tensor_type = attr_tensor.data_type();
switch (attr_tensor_type) { switch (attr_tensor_type) {
case onnx::TensorProto_DataType_STRING: { case onnx::TensorProto_DataType_STRING: {
ParseAttrInScalar_string_string(prim, attr_name, attr_tensor); return ParseAttrInScalar_string_string(attr_tensor);
break;
} }
case onnx::TensorProto_DataType_INT32: { case onnx::TensorProto_DataType_INT32: {
ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor); return ParseAttrInScalar_int32_int32(attr_tensor);
break;
} }
case onnx::TensorProto_DataType_INT64: { case onnx::TensorProto_DataType_INT64: {
ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor); return ParseAttrInScalar_int64_int64(attr_tensor);
break;
} }
case onnx::TensorProto_DataType_UINT64: { case onnx::TensorProto_DataType_UINT64: {
ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor); return ParseAttrInScalar_uint64_uint64(attr_tensor);
break;
} }
case onnx::TensorProto_DataType_FLOAT: { case onnx::TensorProto_DataType_FLOAT: {
ParseAttrInScalar_float_float(prim, attr_name, attr_tensor); return ParseAttrInScalar_float_float(attr_tensor);
break;
} }
case onnx::TensorProto_DataType_DOUBLE: { case onnx::TensorProto_DataType_DOUBLE: {
ParseAttrInScalar_double_double(prim, attr_name, attr_tensor); return ParseAttrInScalar_double_double(attr_tensor);
break;
} }
case onnx::TensorProto_DataType_BOOL: { case onnx::TensorProto_DataType_BOOL: {
ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor); return ParseAttrInScalar_int32_bool(attr_tensor);
auto value = prim->GetAttr(attr_name);
break;
} }
default: default:MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; return {};
return false;
} }
return true; return {};
} }
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
@ -268,7 +365,6 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
prim->set_attr(attr_name, MakeValue<bool>(attr_value)); prim->set_attr(attr_name, MakeValue<bool>(attr_value));
} }
} }
return ret == EOK; return ret == EOK;
} }
@ -280,22 +376,46 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
return false; return false;
} }
const std::string &ref_attr_name = attr_proto.ref_attr_name(); const std::string &ref_attr_name = attr_proto.ref_attr_name();
const onnx::TensorProto &attr_tensor = attr_proto.t(); string type;
switch (kParseTypeSwitchMap[ref_attr_name]) { std::size_t pos(0);
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("type:").length() - 1);
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
}
std::unordered_map<std::string, ValuePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
switch (kParseTypeSwitchMap[type]) {
case FORM_PARSE_TYPE: { case FORM_PARSE_TYPE: {
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
} }
case FORM_PARSE_SCALAR: { case FORM_PARSE_SCALAR: {
return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor); auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
break;
} }
case FORM_PARSE_TENSOR: { case FORM_PARSE_TENSOR: {
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
} }
default: default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return false; return false;
} }
} }
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
if (kv.size() == 1) {
std::unordered_map<std::string, ValuePtr>::iterator iter = kv.begin();
prim->AddAttr(attr_name, iter->second);
} else {
auto res = ParserScalarAttrValue(ref_attr_name, kv);
prim->AddAttr(attr_name, res);
}
}
return true;
}
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) { const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type(); const int attr_tensor_type = attr_tensor.data_type();
@ -321,53 +441,6 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
return true; return true;
} }
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
ValuePtr value_ptr = nullptr;
switch (attr_tensor_type) {
case onnx::TensorProto_DataType_INT32: {
std::vector<int32> add_data;
for (int i = 0; i < attr_tensor.int32_data_size(); ++i) {
add_data.push_back(attr_tensor.int32_data(i));
}
if (add_data.size() == 1) {
value_ptr = MakeValue(add_data[0]);
} else if (!add_data.empty()) {
value_ptr = MakeValue<std::vector<int32> >(add_data);
}
break;
}
case onnx::TensorProto_DataType_FLOAT: {
std::vector<float> add_data;
for (int i = 0; i < attr_tensor.float_data_size(); ++i) {
add_data.push_back(attr_tensor.float_data(i));
}
if (add_data.size() == 1) {
value_ptr = MakeValue(add_data[0]);
} else if (!add_data.empty()) {
value_ptr = MakeValue<std::vector<float> >(add_data);
}
break;
}
case onnx::TensorProto_DataType_UNDEFINED: {
std::vector<ValuePtr> elems;
value_ptr = std::make_shared<ValueTuple>(elems);
break;
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
return false;
}
auto new_value_node = NewValueNode(value_ptr);
MS_EXCEPTION_IF_NULL(new_value_node);
new_value_node->set_abstract(value_ptr->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) { const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type(); const int attr_tensor_type = attr_tensor.data_type();
@ -382,25 +455,58 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value
return true; return true;
} }
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name, bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name,
const std::string &value_node_name, const onnx::AttributeProto &attr_proto) {
const onnx::TensorProto &attr_tensor) { if (!attr_proto.has_ref_attr_name()) {
switch (kParseTypeSwitchMap[ref_attr_name]) { MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
string type;
std::size_t pos(0);
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("type:").length() - 1);
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
}
std::unordered_map<std::string, ValuePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
switch (kParseTypeSwitchMap[type]) {
case FORM_PARSE_TYPE: {
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
}
case FORM_PARSE_SCALAR: { case FORM_PARSE_SCALAR: {
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
break;
} }
case FORM_PARSE_TENSOR: { case FORM_PARSE_TENSOR: {
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
} }
case FORM_PARSE_TYPE: { default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
}
default:
MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name";
return false; return false;
} }
} }
ValueNodePtr new_value_node;
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
if (kv.size() == 1) {
auto iter = kv.begin();
new_value_node = NewValueNode(iter->second);
new_value_node->set_abstract(iter->second->ToAbstract());
} else {
auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv);
new_value_node = NewValueNode(value_ptr);
new_value_node->set_abstract(value_ptr->ToAbstract());
}
anfnode_build_map_[value_node_name] = new_value_node;
}
return true;
}
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
const std::string &value_node_name = node_proto.output(0); const std::string &value_node_name = node_proto.output(0);
const onnx::AttributeProto &attr_proto = node_proto.attribute(0); const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
@ -408,22 +514,23 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
return false; return false;
} }
const std::string &ref_attr_name = attr_proto.ref_attr_name(); return GetAttrValueForValueNode(value_node_name, attr_proto);
const onnx::TensorProto &attr_tensor = attr_proto.t();
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
} }
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { std::unordered_map<std::string, abstract::AbstractTensorPtr>
AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
std::vector<int> shape_vec; std::vector<int> shape_vec;
const onnx::TensorProto &attr_tensor = attr_proto.t(); const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
for (int i = 0; i < attr_tensor.dims_size(); ++i) { for (int j = 0; j < attr_tensor.dims_size(); ++j) {
shape_vec.push_back(attr_tensor.dims(i)); shape_vec.push_back(attr_tensor.dims(j));
} }
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec); auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
MS_EXCEPTION_IF_NULL(abstract_tensor); kv.insert(std::pair<string, abstract::AbstractTensorPtr>(attr_tensor.name(), abstract_tensor));
return abstract_tensor; }
return kv;
} }
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
@ -437,25 +544,16 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
const std::string &node_name = node_proto.output(0); const std::string &node_name = node_proto.output(0);
const std::string &fullname_with_scope = node_proto.domain(); const std::string &fullname_with_scope = node_proto.domain();
const std::string &node_type = node_proto.op_type(); const std::string &node_type = node_proto.op_type();
PrimitivePtr prim = std::make_shared<Primitive>(node_type); PrimitivePtr prim = std::make_shared<mindspore::Primitive>(node_type);
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
prim->set_instance_name(node_type); prim->set_instance_name(node_type);
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
abstract::AbstractTensorPtr abstract = nullptr; string shape_ref_attr_name;
abstract::AbstractTensorPtr abstract_first = nullptr;
abstract::AbstractTensorPtr abstract_second = nullptr;
for (int i = 0; i < node_proto.attribute_size(); ++i) { for (int i = 0; i < node_proto.attribute_size(); ++i) {
const onnx::AttributeProto &attr_proto = node_proto.attribute(i); const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
if (attr_proto.name() == kCNodeShapeAttr) { if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
abstract = GetAbstractForCNode(attr_proto); shape_ref_attr_name = attr_proto.ref_attr_name();
continue; kv = GetAbstractForCNode(attr_proto);
}
if (attr_proto.name() == kCNodeShape1Attr) {
abstract_first = GetAbstractForCNode(attr_proto);
continue;
}
if (attr_proto.name() == kCNodeShape2Attr) {
abstract_second = GetAbstractForCNode(attr_proto);
continue; continue;
} }
if (!GetAttrValueForCNode(prim, attr_proto)) { if (!GetAttrValueForCNode(prim, attr_proto)) {
@ -463,6 +561,7 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
return nullptr; return nullptr;
} }
} }
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;
inputs.clear(); inputs.clear();
for (int i = 0; i < node_proto.input_size(); ++i) { for (int i = 0; i < node_proto.input_size(); ++i) {
@ -481,26 +580,20 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr)); inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr));
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr); MS_EXCEPTION_IF_NULL(cnode_ptr);
if (node_type == "LayerNorm") { if (0 == kv.size()) {
AbstractBasePtrList elem;
elem.push_back(abstract);
elem.push_back(abstract_first);
elem.push_back(abstract_second);
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (node_type == "ArgMaxWithValue") {
AbstractBasePtrList elem;
elem.push_back(abstract);
elem.push_back(abstract_first);
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (nullptr == abstract) {
AbstractBasePtrList elem; AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
elem.push_back(cnode_ptr->input(index)->abstract()); elem.push_back(cnode_ptr->input(index)->abstract());
} }
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (1 == kv.size()) {
std::unordered_map<std::string, abstract::AbstractTensorPtr>::iterator iter = kv.begin();
cnode_ptr->set_abstract(iter->second);
} else { } else {
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
cnode_ptr->set_abstract(abstract); cnode_ptr->set_abstract(abstract);
} }
cnode_ptr->set_fullname_with_scope(fullname_with_scope); cnode_ptr->set_fullname_with_scope(fullname_with_scope);
anfnode_build_map_[node_name] = cnode_ptr; anfnode_build_map_[node_name] = cnode_ptr;
return cnode_ptr; return cnode_ptr;

View File

@ -59,18 +59,15 @@ class AnfImporterFromProtobuf : public AnfImporter {
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor); const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor);
const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor); const onnx::TensorProto &attr_tensor);
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_proto);
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); std::unordered_map<std::string,
abstract::AbstractTensorPtr> GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
private: private:
std::string producer_name_; std::string producer_name_;