!5719 [MSLITE] mindspore export interface finetune lite

Merge pull request !5719 from zhengjun10/master
Commit 25a878d674 by mindspore-ci-bot, 2020-09-05 16:11:13 +08:00, committed by Gitee
3 changed files with 268 additions and 178 deletions

==== File 1 of 3 ====

@@ -1,2 +1,2 @@
ssd.pb
mobilenet_v2.pb
#ssd.pb
# mobilenet_v2.pb

==== File 2 of 3 ====

@@ -46,9 +46,6 @@ using uint64 = uint64_t;
namespace mindspore::lite {
static constexpr char kConstantValueNode[] = "Constant";
static constexpr char kCNodeShapeAttr[] = "shape";
static constexpr char kCNodeShape1Attr[] = "shape1";
static constexpr char kCNodeShape2Attr[] = "shape2";
enum ParseForm : int {
FORM_PARSE_TYPE = 0,
@@ -57,32 +54,143 @@ enum ParseForm : int {
};
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}};
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}};
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
};
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
const onnx::TensorProto &attr_tensor) { \
MS_EXCEPTION_IF_NULL(prim); \
std::vector<ValuePtr> attr_value_vec; \
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
} \
if (attr_value_vec.size() == 1) { \
prim->AddAttr(attr_name, attr_value_vec[0]); \
} else { \
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
} \
}
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
const std::unordered_map<string, ValuePtr> &kv) {
std::string str = attr_name;
auto replace = [&](const string &orgStr, const string &newStr) {
std::string::size_type pos(0);
while ((pos = str.find(orgStr)) != std::string::npos) {
str.replace(pos, orgStr.length(), newStr);
}
return str;
};
// remove "scalar:"
str = replace("scalar:", "");
// remove "Tuple"
str = replace("Tuple", "");
// remove "List"
str = replace("List", "");
std::stack<std::string> rules;
std::stack<ValuePtr> value;
int num = 0, count = 0;
for (size_t i = 0; i < str.length(); i++) {
if (str[i] == '[') {
rules.push("[");
} else if (str[i] == ']') {
// rules
std::vector<ValuePtr> vec;
while (rules.top() != "[") {
rules.pop();
vec.push_back(value.top());
value.pop();
}
// pop "["
rules.pop();
// make tuple for names
std::string res = "dummy";
// make tuple for values
reverse(vec.begin(), vec.end());
auto vt = std::make_shared<ValueTuple>(vec);
if (rules.empty() && value.empty()) {
return vt;
}
rules.push(res);
value.push(vt);
} else if (str[i] == ',') {
continue;
} else {
count++;
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
auto value_name = str.substr(i - count + 1, count);
value.push(kv.at(value_name));
rules.push(value_name);
count = 0;
num++;
}
}
}
return {};
}
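
Note: the two-stack walk above turns a bracketed attribute name into a nested ValueTuple: names are looked up in kv, and each ']' folds everything back to the matching '[' into one tuple. A minimal standalone analogue (illustrative only: plain ints stand in for ValuePtr, and the input "scalar:Tuple[s1,s2]" with names s1/s2 is assumed) behaves the same way:

#include <algorithm>
#include <iostream>
#include <map>
#include <stack>
#include <string>
#include <vector>

int main() {
  std::map<std::string, int> kv{{"s1", 1}, {"s2", 2}};
  std::string str = "[s1,s2]";  // "scalar:Tuple[s1,s2]" after the prefix removal above
  std::stack<std::string> rules;
  std::stack<int> value;
  int count = 0;
  for (size_t i = 0; i < str.length(); i++) {
    if (str[i] == '[') {
      rules.push("[");
    } else if (str[i] == ']') {
      // fold everything since the matching '[' into one tuple
      std::vector<int> vec;
      while (rules.top() != "[") {
        rules.pop();
        vec.push_back(value.top());
        value.pop();
      }
      rules.pop();  // pop the "["
      std::reverse(vec.begin(), vec.end());  // stack order -> declaration order
      for (int v : vec) std::cout << v << ' ';  // prints: 1 2
    } else if (str[i] == ',') {
      continue;
    } else {
      // accumulate a name until the next delimiter, then look it up in kv
      count++;
      if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
        auto value_name = str.substr(i - count + 1, count);
        value.push(kv.at(value_name));
        rules.push(value_name);
        count = 0;
      }
    }
  }
  return 0;
}

The real function additionally returns as soon as both stacks drain, which yields the outermost tuple; ParserAttrShape below applies the same walk to "shape:" names and AbstractTuple values.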
std::shared_ptr<abstract::AbstractTuple>
ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, abstract::AbstractTensorPtr> &kv) {
std::string str = attr_name;
auto replace = [&](const string &orgStr, const string &newStr) {
std::string::size_type pos(0);
while ((pos = str.find(orgStr)) != std::string::npos) {
str.replace(pos, orgStr.length(), newStr);
}
return str;
};
// remove "scalar:"
str = replace("shape:", "");
// remove "Tuple"
str = replace("Tuple", "");
// remove "List"
str = replace("List", "");
std::stack<std::string> rules;
std::stack<abstract::AbstractBasePtr> value;
int num = 0, count = 0;
for (size_t i = 0; i < str.length(); i++) {
if (str[i] == '[') {
rules.push("[");
} else if (str[i] == ']') {
// rules
std::vector<abstract::AbstractBasePtr> vec;
while (rules.top() != "[") {
rules.pop();
vec.push_back(value.top());
value.pop();
}
// pop "["
rules.pop();
// make tuple for names
std::string res = "dummy";
// make tuple for values
reverse(vec.begin(), vec.end());
auto vt = std::make_shared<abstract::AbstractTuple>(vec);
if (rules.empty() && value.empty()) {
return vt;
}
rules.push(res);
value.push(vt);
} else if (str[i] == ',') {
continue;
} else {
count++;
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
auto value_name = str.substr(i - count + 1, count);
value.push(kv.at(value_name));
rules.push(value_name);
count = 0;
num++;
}
}
}
return {};
}
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
if (attr_tensor.type##_data_size() == 1) { \
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
return MakeValue<valuetype>(value); \
} else { \
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
} \
return {}; \
}
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
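
Note: the reworked macro produces value-returning parsers instead of ones that mutate the Primitive in place. Substitution is mechanical; the (double, double) instantiation above, for example, expands to:

ValuePtr ParseAttrInScalar_double_double(const onnx::TensorProto &attr_tensor) {
  if (attr_tensor.double_data_size() == 1) {
    auto value = static_cast<double>(attr_tensor.double_data(0));
    return MakeValue<double>(value);
  } else {
    MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!";
  }
  return {};
}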
@@ -193,45 +301,34 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim
return true;
}
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor) {
MS_EXCEPTION_IF_NULL(prim);
ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
switch (attr_tensor_type) {
case onnx::TensorProto_DataType_STRING: {
ParseAttrInScalar_string_string(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_string_string(attr_tensor);
}
case onnx::TensorProto_DataType_INT32: {
ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_int32_int32(attr_tensor);
}
case onnx::TensorProto_DataType_INT64: {
ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_int64_int64(attr_tensor);
}
case onnx::TensorProto_DataType_UINT64: {
ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_uint64_uint64(attr_tensor);
}
case onnx::TensorProto_DataType_FLOAT: {
ParseAttrInScalar_float_float(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_float_float(attr_tensor);
}
case onnx::TensorProto_DataType_DOUBLE: {
ParseAttrInScalar_double_double(prim, attr_name, attr_tensor);
break;
return ParseAttrInScalar_double_double(attr_tensor);
}
case onnx::TensorProto_DataType_BOOL: {
ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor);
auto value = prim->GetAttr(attr_name);
break;
return ParseAttrInScalar_int32_bool(attr_tensor);
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
return false;
default:MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
return {};
}
return true;
return {};
}
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
@@ -268,7 +365,6 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
prim->set_attr(attr_name, MakeValue<bool>(attr_value));
}
}
return ret == EOK;
}
@@ -280,22 +376,46 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
const onnx::TensorProto &attr_tensor = attr_proto.t();
switch (kParseTypeSwitchMap[ref_attr_name]) {
case FORM_PARSE_TYPE: {
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
}
case FORM_PARSE_SCALAR: {
return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor);
}
case FORM_PARSE_TENSOR: {
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
}
default:
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return false;
string type;
std::size_t pos(0);
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("type:").length() - 1);
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
}
std::unordered_map<std::string, ValuePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
switch (kParseTypeSwitchMap[type]) {
case FORM_PARSE_TYPE: {
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
}
case FORM_PARSE_SCALAR: {
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
break;
}
case FORM_PARSE_TENSOR: {
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
}
default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return false;
}
}
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
if (kv.size() == 1) {
std::unordered_map<std::string, ValuePtr>::iterator iter = kv.begin();
prim->AddAttr(attr_name, iter->second);
} else {
auto res = ParserScalarAttrValue(ref_attr_name, kv);
prim->AddAttr(attr_name, res);
}
}
return true;
}
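
Note: the substr(pos, length - 1) arithmetic keeps the matched prefix without its trailing colon, so the extracted key is exactly "scalar", "type", or "tensor" and lands on the right kParseTypeSwitchMap entry. A tiny standalone check (the ref_attr_name string is hypothetical):

#include <cassert>
#include <string>

int main() {
  // hypothetical composite attribute name; only its prefix matters here
  std::string ref_attr_name = "scalar:Tuple[s1,s2]";
  std::size_t pos = ref_attr_name.find("scalar:");
  // length("scalar:") - 1 == 6, so substr keeps "scalar" and drops the ':'
  std::string type = ref_attr_name.substr(pos, std::string("scalar:").length() - 1);
  assert(type == "scalar");  // matches the kParseTypeSwitchMap key
  return 0;
}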
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
@@ -321,53 +441,6 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
return true;
}
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
ValuePtr value_ptr = nullptr;
switch (attr_tensor_type) {
case onnx::TensorProto_DataType_INT32: {
std::vector<int32> add_data;
for (int i = 0; i < attr_tensor.int32_data_size(); ++i) {
add_data.push_back(attr_tensor.int32_data(i));
}
if (add_data.size() == 1) {
value_ptr = MakeValue(add_data[0]);
} else if (!add_data.empty()) {
value_ptr = MakeValue<std::vector<int32> >(add_data);
}
break;
}
case onnx::TensorProto_DataType_FLOAT: {
std::vector<float> add_data;
for (int i = 0; i < attr_tensor.float_data_size(); ++i) {
add_data.push_back(attr_tensor.float_data(i));
}
if (add_data.size() == 1) {
value_ptr = MakeValue(add_data[0]);
} else if (!add_data.empty()) {
value_ptr = MakeValue<std::vector<float> >(add_data);
}
break;
}
case onnx::TensorProto_DataType_UNDEFINED: {
std::vector<ValuePtr> elems;
value_ptr = std::make_shared<ValueTuple>(elems);
break;
}
default:
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
return false;
}
auto new_value_node = NewValueNode(value_ptr);
MS_EXCEPTION_IF_NULL(new_value_node);
new_value_node->set_abstract(value_ptr->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
@@ -382,23 +455,56 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value
return true;
}
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name,
const std::string &value_node_name,
const onnx::TensorProto &attr_tensor) {
switch (kParseTypeSwitchMap[ref_attr_name]) {
case FORM_PARSE_SCALAR: {
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor);
}
case FORM_PARSE_TENSOR: {
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
}
case FORM_PARSE_TYPE: {
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
}
default:
MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name";
return false;
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name,
const onnx::AttributeProto &attr_proto) {
if (!attr_proto.has_ref_attr_name()) {
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
string type;
std::size_t pos(0);
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("type:").length() - 1);
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
}
std::unordered_map<std::string, ValuePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
switch (kParseTypeSwitchMap[type]) {
case FORM_PARSE_TYPE: {
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
}
case FORM_PARSE_SCALAR: {
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
break;
}
case FORM_PARSE_TENSOR: {
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
}
default:
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
return false;
}
}
ValueNodePtr new_value_node;
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
if (kv.size() == 1) {
auto iter = kv.begin();
new_value_node = NewValueNode(iter->second);
new_value_node->set_abstract(iter->second->ToAbstract());
} else {
auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv);
new_value_node = NewValueNode(value_ptr);
new_value_node->set_abstract(value_ptr->ToAbstract());
}
anfnode_build_map_[value_node_name] = new_value_node;
}
return true;
}
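
Note: the exporter side is not part of this diff. As an assumed illustration only, a composite scalar attribute that the tensors() loops above could consume might be laid out like this (the attribute name, the tensor names s1/s2, and the values are all invented):

// Assumed producer-side sketch, not the project's exporter.
#include "onnx/onnx.pb.h"  // assumed include path for the generated protos

onnx::AttributeProto MakeCompositeScalarAttr() {
  onnx::AttributeProto attr_proto;
  attr_proto.set_ref_attr_name("scalar:Tuple[s1,s2]");
  for (const char *name : {"s1", "s2"}) {
    onnx::TensorProto *t = attr_proto.add_tensors();
    t->set_name(name);  // becomes the kv key in the importer
    t->set_data_type(onnx::TensorProto_DataType_INT32);
    t->add_int32_data(42);  // single element, as ParseAttrInScalar_int32_int32 expects
  }
  return attr_proto;
}

Each tensor's name() becomes the kv key, and ObtainCNodeAttrInScalarForm turns its single-element payload into the ValuePtr that ParserScalarAttrValue later folds into the tuple.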
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
@@ -408,22 +514,23 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
return false;
}
const std::string &ref_attr_name = attr_proto.ref_attr_name();
const onnx::TensorProto &attr_tensor = attr_proto.t();
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
return GetAttrValueForValueNode(value_node_name, attr_proto);
}
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
std::vector<int> shape_vec;
const onnx::TensorProto &attr_tensor = attr_proto.t();
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape_vec.push_back(attr_tensor.dims(i));
std::unordered_map<std::string, abstract::AbstractTensorPtr>
AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); i++) {
std::vector<int> shape_vec;
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
for (int j = 0; j < attr_tensor.dims_size(); ++j) {
shape_vec.push_back(attr_tensor.dims(j));
}
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
kv.insert(std::pair<string, abstract::AbstractTensorPtr>(attr_tensor.name(), abstract_tensor));
}
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
MS_EXCEPTION_IF_NULL(abstract_tensor);
return abstract_tensor;
return kv;
}
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
@@ -437,25 +544,16 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
const std::string &node_name = node_proto.output(0);
const std::string &fullname_with_scope = node_proto.domain();
const std::string &node_type = node_proto.op_type();
PrimitivePtr prim = std::make_shared<Primitive>(node_type);
PrimitivePtr prim = std::make_shared<mindspore::Primitive>(node_type);
MS_EXCEPTION_IF_NULL(prim);
prim->set_instance_name(node_type);
abstract::AbstractTensorPtr abstract = nullptr;
abstract::AbstractTensorPtr abstract_first = nullptr;
abstract::AbstractTensorPtr abstract_second = nullptr;
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
string shape_ref_attr_name;
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
if (attr_proto.name() == kCNodeShapeAttr) {
abstract = GetAbstractForCNode(attr_proto);
continue;
}
if (attr_proto.name() == kCNodeShape1Attr) {
abstract_first = GetAbstractForCNode(attr_proto);
continue;
}
if (attr_proto.name() == kCNodeShape2Attr) {
abstract_second = GetAbstractForCNode(attr_proto);
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
shape_ref_attr_name = attr_proto.ref_attr_name();
kv = GetAbstractForCNode(attr_proto);
continue;
}
if (!GetAttrValueForCNode(prim, attr_proto)) {
@@ -463,6 +561,7 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
return nullptr;
}
}
std::vector<AnfNodePtr> inputs;
inputs.clear();
for (int i = 0; i < node_proto.input_size(); ++i) {
@@ -481,26 +580,20 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr));
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
if (node_type == "LayerNorm") {
AbstractBasePtrList elem;
elem.push_back(abstract);
elem.push_back(abstract_first);
elem.push_back(abstract_second);
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (node_type == "ArgMaxWithValue") {
AbstractBasePtrList elem;
elem.push_back(abstract);
elem.push_back(abstract_first);
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (nullptr == abstract) {
if (0 == kv.size()) {
AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
elem.push_back(cnode_ptr->input(index)->abstract());
}
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
} else if (1 == kv.size()) {
std::unordered_map<std::string, abstract::AbstractTensorPtr>::iterator iter = kv.begin();
cnode_ptr->set_abstract(iter->second);
} else {
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
cnode_ptr->set_abstract(abstract);
}
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
anfnode_build_map_[node_name] = cnode_ptr;
return cnode_ptr;
@@ -652,7 +745,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
auto onnx_model = new onnx::ModelProto;
if (ReadProtoFromBinaryFile((const char *) model_path.c_str(), onnx_model) != RET_OK) {
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path;
return nullptr;
}
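
Note: ReadProtoFromBinaryFile is a project helper whose definition is not shown in this diff. As a rough standalone analogue (assuming the file holds a serialized onnx::ModelProto; the path and include are hypothetical), raw protobuf can do the same parse:

#include <fstream>
#include <iostream>
#include "onnx/onnx.pb.h"  // assumed include path

int main() {
  onnx::ModelProto model;
  std::ifstream in("lenet.mindir", std::ios::binary);  // hypothetical model path
  if (!in || !model.ParseFromIstream(&in)) {
    std::cerr << "Read onnx model file failed" << std::endl;
    return 1;
  }
  std::cout << "ir_version: " << model.ir_version() << std::endl;
  return 0;
}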

==== File 3 of 3 ====

@@ -42,9 +42,9 @@ class AnfImporterFromProtobuf : public AnfImporter {
int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override;
private:
int ConverterConstTensor() override { return RET_ERROR; };
int ConverterCNode() override { return RET_ERROR; };
int AddReturnCNode() override { return RET_ERROR; };
bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const schema::QuantType &quantType);
@@ -59,18 +59,15 @@ class AnfImporterFromProtobuf : public AnfImporter {
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_proto);
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
std::unordered_map<std::string,
abstract::AbstractTensorPtr> GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
private:
std::string producer_name_;