forked from mindspore-Ecosystem/mindspore
!5719 [MSLITE]mindspore export inferface finetune lite
Merge pull request !5719 from zhengjun10/master
This commit is contained in:
commit
25a878d674
|
@ -1,2 +1,2 @@
|
|||
ssd.pb
|
||||
mobilenet_v2.pb
|
||||
#ssd.pb
|
||||
# mobilenet_v2.pb
|
||||
|
|
|
@ -46,9 +46,6 @@ using uint64 = uint64_t;
|
|||
namespace mindspore::lite {
|
||||
|
||||
static constexpr char kConstantValueNode[] = "Constant";
|
||||
static constexpr char kCNodeShapeAttr[] = "shape";
|
||||
static constexpr char kCNodeShape1Attr[] = "shape1";
|
||||
static constexpr char kCNodeShape2Attr[] = "shape2";
|
||||
|
||||
enum ParseForm : int {
|
||||
FORM_PARSE_TYPE = 0,
|
||||
|
@ -57,32 +54,143 @@ enum ParseForm : int {
|
|||
};
|
||||
|
||||
static std::map<std::string, ParseForm> kParseTypeSwitchMap{
|
||||
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}};
|
||||
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}};
|
||||
|
||||
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
|
||||
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
|
||||
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
|
||||
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
|
||||
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
|
||||
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
|
||||
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
|
||||
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
|
||||
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8},
|
||||
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32},
|
||||
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8},
|
||||
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32},
|
||||
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
|
||||
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
|
||||
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
|
||||
};
|
||||
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
|
||||
const onnx::TensorProto &attr_tensor) { \
|
||||
MS_EXCEPTION_IF_NULL(prim); \
|
||||
std::vector<ValuePtr> attr_value_vec; \
|
||||
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
|
||||
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
|
||||
} \
|
||||
if (attr_value_vec.size() == 1) { \
|
||||
prim->AddAttr(attr_name, attr_value_vec[0]); \
|
||||
} else { \
|
||||
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
|
||||
} \
|
||||
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
|
||||
const std::unordered_map<string, ValuePtr> &kv) {
|
||||
std::string str = attr_name;
|
||||
auto replace = [&](const string &orgStr, const string &newStr) {
|
||||
std::string::size_type pos(0);
|
||||
while ((pos = str.find(orgStr)) != std::string::npos) {
|
||||
str.replace(pos, orgStr.length(), newStr);
|
||||
}
|
||||
return str;
|
||||
};
|
||||
// remove "scalar:"
|
||||
str = replace("scalar:", "");
|
||||
// remove "Tuple"
|
||||
str = replace("Tuple", "");
|
||||
// remove "List"
|
||||
str = replace("List", "");
|
||||
std::stack<std::string> rules;
|
||||
std::stack<ValuePtr> value;
|
||||
int num = 0, count = 0;
|
||||
for (size_t i = 0; i < str.length(); i++) {
|
||||
if (str[i] == '[') {
|
||||
rules.push("[");
|
||||
} else if (str[i] == ']') {
|
||||
// rules
|
||||
std::vector<ValuePtr> vec;
|
||||
while (rules.top() != "[") {
|
||||
rules.pop();
|
||||
vec.push_back(value.top());
|
||||
value.pop();
|
||||
}
|
||||
// pop "["
|
||||
rules.pop();
|
||||
// make tuple for names
|
||||
std::string res = "dummy";
|
||||
// make tuple for values
|
||||
reverse(vec.begin(), vec.end());
|
||||
auto vt = std::make_shared<ValueTuple>(vec);
|
||||
if (rules.empty() && value.empty()) {
|
||||
return vt;
|
||||
}
|
||||
rules.push(res);
|
||||
value.push(vt);
|
||||
} else if (str[i] == ',') {
|
||||
continue;
|
||||
} else {
|
||||
count++;
|
||||
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
|
||||
auto value_name = str.substr(i - count + 1, count);
|
||||
value.push(kv.at(value_name));
|
||||
rules.push(value_name);
|
||||
count = 0;
|
||||
num++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::shared_ptr<abstract::AbstractTuple>
|
||||
ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, abstract::AbstractTensorPtr> &kv) {
|
||||
std::string str = attr_name;
|
||||
auto replace = [&](const string &orgStr, const string &newStr) {
|
||||
std::string::size_type pos(0);
|
||||
while ((pos = str.find(orgStr)) != std::string::npos) {
|
||||
str.replace(pos, orgStr.length(), newStr);
|
||||
}
|
||||
return str;
|
||||
};
|
||||
// remove "scalar:"
|
||||
str = replace("shape:", "");
|
||||
// remove "Tuple"
|
||||
str = replace("Tuple", "");
|
||||
// remove "List"
|
||||
str = replace("List", "");
|
||||
std::stack<std::string> rules;
|
||||
std::stack<abstract::AbstractBasePtr> value;
|
||||
int num = 0, count = 0;
|
||||
for (size_t i = 0; i < str.length(); i++) {
|
||||
if (str[i] == '[') {
|
||||
rules.push("[");
|
||||
} else if (str[i] == ']') {
|
||||
// rules
|
||||
std::vector<abstract::AbstractBasePtr> vec;
|
||||
while (rules.top() != "[") {
|
||||
rules.pop();
|
||||
vec.push_back(value.top());
|
||||
value.pop();
|
||||
}
|
||||
// pop "["
|
||||
rules.pop();
|
||||
// make tuple for names
|
||||
std::string res = "dummy";
|
||||
// make tuple for values
|
||||
reverse(vec.begin(), vec.end());
|
||||
auto vt = std::make_shared<abstract::AbstractTuple>(vec);
|
||||
if (rules.empty() && value.empty()) {
|
||||
return vt;
|
||||
}
|
||||
rules.push(res);
|
||||
value.push(vt);
|
||||
} else if (str[i] == ',') {
|
||||
continue;
|
||||
} else {
|
||||
count++;
|
||||
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
|
||||
auto value_name = str.substr(i - count + 1, count);
|
||||
value.push(kv.at(value_name));
|
||||
rules.push(value_name);
|
||||
count = 0;
|
||||
num++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
|
||||
if (attr_tensor.type##_data_size() == 1) { \
|
||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
|
||||
return MakeValue<valuetype>(value); \
|
||||
} else { \
|
||||
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
|
||||
} \
|
||||
return {}; \
|
||||
}
|
||||
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
|
||||
|
@ -193,45 +301,34 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
switch (attr_tensor_type) {
|
||||
case onnx::TensorProto_DataType_STRING: {
|
||||
ParseAttrInScalar_string_string(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
return ParseAttrInScalar_string_string(attr_tensor);
|
||||
}
|
||||
case onnx::TensorProto_DataType_INT32: {
|
||||
ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
return ParseAttrInScalar_int32_int32(attr_tensor);
|
||||
}
|
||||
case onnx::TensorProto_DataType_INT64: {
|
||||
ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
return ParseAttrInScalar_int64_int64(attr_tensor);
|
||||
}
|
||||
case onnx::TensorProto_DataType_UINT64: {
|
||||
ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
return ParseAttrInScalar_uint64_uint64(attr_tensor);
|
||||
}
|
||||
case onnx::TensorProto_DataType_FLOAT: {
|
||||
ParseAttrInScalar_float_float(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
return ParseAttrInScalar_float_float(attr_tensor);
|
||||
}
|
||||
case onnx::TensorProto_DataType_DOUBLE: {
|
||||
ParseAttrInScalar_double_double(prim, attr_name, attr_tensor);
|
||||
break;
|
||||
return ParseAttrInScalar_double_double(attr_tensor);
|
||||
}
|
||||
case onnx::TensorProto_DataType_BOOL: {
|
||||
ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor);
|
||||
auto value = prim->GetAttr(attr_name);
|
||||
break;
|
||||
return ParseAttrInScalar_int32_bool(attr_tensor);
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
||||
return false;
|
||||
default:MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
||||
return {};
|
||||
}
|
||||
return true;
|
||||
return {};
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
|
@ -268,7 +365,6 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
|
|||
prim->set_attr(attr_name, MakeValue<bool>(attr_value));
|
||||
}
|
||||
}
|
||||
|
||||
return ret == EOK;
|
||||
}
|
||||
|
||||
|
@ -280,22 +376,46 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
|
|||
return false;
|
||||
}
|
||||
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
||||
switch (kParseTypeSwitchMap[ref_attr_name]) {
|
||||
case FORM_PARSE_TYPE: {
|
||||
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_SCALAR: {
|
||||
return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_TENSOR: {
|
||||
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
|
||||
return false;
|
||||
string type;
|
||||
std::size_t pos(0);
|
||||
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
|
||||
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
|
||||
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
|
||||
type = ref_attr_name.substr(pos, string("type:").length() - 1);
|
||||
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
|
||||
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
|
||||
}
|
||||
std::unordered_map<std::string, ValuePtr> kv;
|
||||
for (int i = 0; i < attr_proto.tensors_size(); i++) {
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
|
||||
switch (kParseTypeSwitchMap[type]) {
|
||||
case FORM_PARSE_TYPE: {
|
||||
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_SCALAR: {
|
||||
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
|
||||
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
|
||||
break;
|
||||
}
|
||||
case FORM_PARSE_TENSOR: {
|
||||
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
|
||||
}
|
||||
default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
|
||||
if (kv.size() == 1) {
|
||||
std::unordered_map<std::string, ValuePtr>::iterator iter = kv.begin();
|
||||
prim->AddAttr(attr_name, iter->second);
|
||||
} else {
|
||||
auto res = ParserScalarAttrValue(ref_attr_name, kv);
|
||||
prim->AddAttr(attr_name, res);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
|
@ -321,53 +441,6 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
ValuePtr value_ptr = nullptr;
|
||||
switch (attr_tensor_type) {
|
||||
case onnx::TensorProto_DataType_INT32: {
|
||||
std::vector<int32> add_data;
|
||||
for (int i = 0; i < attr_tensor.int32_data_size(); ++i) {
|
||||
add_data.push_back(attr_tensor.int32_data(i));
|
||||
}
|
||||
if (add_data.size() == 1) {
|
||||
value_ptr = MakeValue(add_data[0]);
|
||||
} else if (!add_data.empty()) {
|
||||
value_ptr = MakeValue<std::vector<int32> >(add_data);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_FLOAT: {
|
||||
std::vector<float> add_data;
|
||||
for (int i = 0; i < attr_tensor.float_data_size(); ++i) {
|
||||
add_data.push_back(attr_tensor.float_data(i));
|
||||
}
|
||||
|
||||
if (add_data.size() == 1) {
|
||||
value_ptr = MakeValue(add_data[0]);
|
||||
} else if (!add_data.empty()) {
|
||||
value_ptr = MakeValue<std::vector<float> >(add_data);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case onnx::TensorProto_DataType_UNDEFINED: {
|
||||
std::vector<ValuePtr> elems;
|
||||
value_ptr = std::make_shared<ValueTuple>(elems);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
auto new_value_node = NewValueNode(value_ptr);
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
new_value_node->set_abstract(value_ptr->ToAbstract());
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
|
@ -382,23 +455,56 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name,
|
||||
const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
switch (kParseTypeSwitchMap[ref_attr_name]) {
|
||||
case FORM_PARSE_SCALAR: {
|
||||
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_TENSOR: {
|
||||
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_TYPE: {
|
||||
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name";
|
||||
return false;
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name,
|
||||
const onnx::AttributeProto &attr_proto) {
|
||||
if (!attr_proto.has_ref_attr_name()) {
|
||||
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||
string type;
|
||||
std::size_t pos(0);
|
||||
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
|
||||
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
|
||||
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
|
||||
type = ref_attr_name.substr(pos, string("type:").length() - 1);
|
||||
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
|
||||
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
|
||||
}
|
||||
std::unordered_map<std::string, ValuePtr> kv;
|
||||
for (int i = 0; i < attr_proto.tensors_size(); i++) {
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
|
||||
switch (kParseTypeSwitchMap[type]) {
|
||||
case FORM_PARSE_TYPE: {
|
||||
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
|
||||
}
|
||||
case FORM_PARSE_SCALAR: {
|
||||
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
|
||||
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
|
||||
break;
|
||||
}
|
||||
case FORM_PARSE_TENSOR: {
|
||||
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
|
||||
}
|
||||
default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
ValueNodePtr new_value_node;
|
||||
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
|
||||
if (kv.size() == 1) {
|
||||
auto iter = kv.begin();
|
||||
new_value_node = NewValueNode(iter->second);
|
||||
new_value_node->set_abstract(iter->second->ToAbstract());
|
||||
} else {
|
||||
auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv);
|
||||
new_value_node = NewValueNode(value_ptr);
|
||||
new_value_node->set_abstract(value_ptr->ToAbstract());
|
||||
}
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
|
||||
|
@ -408,22 +514,23 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &
|
|||
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
|
||||
return false;
|
||||
}
|
||||
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
||||
|
||||
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
|
||||
return GetAttrValueForValueNode(value_node_name, attr_proto);
|
||||
}
|
||||
|
||||
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
|
||||
std::vector<int> shape_vec;
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape_vec.push_back(attr_tensor.dims(i));
|
||||
std::unordered_map<std::string, abstract::AbstractTensorPtr>
|
||||
AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
|
||||
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
|
||||
for (int i = 0; i < attr_proto.tensors_size(); i++) {
|
||||
std::vector<int> shape_vec;
|
||||
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
|
||||
for (int j = 0; j < attr_tensor.dims_size(); ++j) {
|
||||
shape_vec.push_back(attr_tensor.dims(j));
|
||||
}
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
|
||||
kv.insert(std::pair<string, abstract::AbstractTensorPtr>(attr_tensor.name(), abstract_tensor));
|
||||
}
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
|
||||
MS_EXCEPTION_IF_NULL(abstract_tensor);
|
||||
return abstract_tensor;
|
||||
return kv;
|
||||
}
|
||||
|
||||
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
|
@ -437,25 +544,16 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|||
const std::string &node_name = node_proto.output(0);
|
||||
const std::string &fullname_with_scope = node_proto.domain();
|
||||
const std::string &node_type = node_proto.op_type();
|
||||
PrimitivePtr prim = std::make_shared<Primitive>(node_type);
|
||||
PrimitivePtr prim = std::make_shared<mindspore::Primitive>(node_type);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
prim->set_instance_name(node_type);
|
||||
|
||||
abstract::AbstractTensorPtr abstract = nullptr;
|
||||
abstract::AbstractTensorPtr abstract_first = nullptr;
|
||||
abstract::AbstractTensorPtr abstract_second = nullptr;
|
||||
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
|
||||
string shape_ref_attr_name;
|
||||
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
||||
const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
|
||||
if (attr_proto.name() == kCNodeShapeAttr) {
|
||||
abstract = GetAbstractForCNode(attr_proto);
|
||||
continue;
|
||||
}
|
||||
if (attr_proto.name() == kCNodeShape1Attr) {
|
||||
abstract_first = GetAbstractForCNode(attr_proto);
|
||||
continue;
|
||||
}
|
||||
if (attr_proto.name() == kCNodeShape2Attr) {
|
||||
abstract_second = GetAbstractForCNode(attr_proto);
|
||||
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
|
||||
shape_ref_attr_name = attr_proto.ref_attr_name();
|
||||
kv = GetAbstractForCNode(attr_proto);
|
||||
continue;
|
||||
}
|
||||
if (!GetAttrValueForCNode(prim, attr_proto)) {
|
||||
|
@ -463,6 +561,7 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.clear();
|
||||
for (int i = 0; i < node_proto.input_size(); ++i) {
|
||||
|
@ -481,26 +580,20 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|||
inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr));
|
||||
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
if (node_type == "LayerNorm") {
|
||||
AbstractBasePtrList elem;
|
||||
elem.push_back(abstract);
|
||||
elem.push_back(abstract_first);
|
||||
elem.push_back(abstract_second);
|
||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
} else if (node_type == "ArgMaxWithValue") {
|
||||
AbstractBasePtrList elem;
|
||||
elem.push_back(abstract);
|
||||
elem.push_back(abstract_first);
|
||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
} else if (nullptr == abstract) {
|
||||
if (0 == kv.size()) {
|
||||
AbstractBasePtrList elem;
|
||||
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
||||
elem.push_back(cnode_ptr->input(index)->abstract());
|
||||
}
|
||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
} else if (1 == kv.size()) {
|
||||
std::unordered_map<std::string, abstract::AbstractTensorPtr>::iterator iter = kv.begin();
|
||||
cnode_ptr->set_abstract(iter->second);
|
||||
} else {
|
||||
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
|
||||
cnode_ptr->set_abstract(abstract);
|
||||
}
|
||||
|
||||
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
|
||||
anfnode_build_map_[node_name] = cnode_ptr;
|
||||
return cnode_ptr;
|
||||
|
@ -652,7 +745,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
|||
|
||||
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
|
||||
auto onnx_model = new onnx::ModelProto;
|
||||
if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) {
|
||||
if (ReadProtoFromBinaryFile((const char *) model_path.c_str(), onnx_model) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path;
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -42,9 +42,9 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
|||
int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override;
|
||||
|
||||
private:
|
||||
int ConverterConstTensor() override{ return RET_ERROR; };
|
||||
int ConverterCNode() override{ return RET_ERROR; };
|
||||
int AddReturnCNode() override{ return RET_ERROR; };
|
||||
int ConverterConstTensor() override { return RET_ERROR; };
|
||||
int ConverterCNode() override { return RET_ERROR; };
|
||||
int AddReturnCNode() override { return RET_ERROR; };
|
||||
bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
|
||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType);
|
||||
|
@ -59,18 +59,15 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
|||
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
|
||||
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor);
|
||||
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
|
||||
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
|
||||
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
|
||||
const onnx::TensorProto &attr_tensor);
|
||||
bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_proto);
|
||||
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||
std::unordered_map<std::string,
|
||||
abstract::AbstractTensorPtr> GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||
|
||||
private:
|
||||
std::string producer_name_;
|
||||
|
|
Loading…
Reference in New Issue