forked from mindspore-Ecosystem/mindspore
!5719 [MSLITE]mindspore export inferface finetune lite
Merge pull request !5719 from zhengjun10/master
This commit is contained in:
commit
25a878d674
|
@ -1,2 +1,2 @@
|
||||||
ssd.pb
|
#ssd.pb
|
||||||
mobilenet_v2.pb
|
# mobilenet_v2.pb
|
||||||
|
|
|
@ -46,9 +46,6 @@ using uint64 = uint64_t;
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
|
||||||
static constexpr char kConstantValueNode[] = "Constant";
|
static constexpr char kConstantValueNode[] = "Constant";
|
||||||
static constexpr char kCNodeShapeAttr[] = "shape";
|
|
||||||
static constexpr char kCNodeShape1Attr[] = "shape1";
|
|
||||||
static constexpr char kCNodeShape2Attr[] = "shape2";
|
|
||||||
|
|
||||||
enum ParseForm : int {
|
enum ParseForm : int {
|
||||||
FORM_PARSE_TYPE = 0,
|
FORM_PARSE_TYPE = 0,
|
||||||
|
@ -69,20 +66,131 @@ static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{
|
||||||
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
|
{onnx::TensorProto_DataType_STRING, kObjectTypeString},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::shared_ptr<ValueTuple> ParserScalarAttrValue(const std::string &attr_name,
|
||||||
|
const std::unordered_map<string, ValuePtr> &kv) {
|
||||||
|
std::string str = attr_name;
|
||||||
|
auto replace = [&](const string &orgStr, const string &newStr) {
|
||||||
|
std::string::size_type pos(0);
|
||||||
|
while ((pos = str.find(orgStr)) != std::string::npos) {
|
||||||
|
str.replace(pos, orgStr.length(), newStr);
|
||||||
|
}
|
||||||
|
return str;
|
||||||
|
};
|
||||||
|
// remove "scalar:"
|
||||||
|
str = replace("scalar:", "");
|
||||||
|
// remove "Tuple"
|
||||||
|
str = replace("Tuple", "");
|
||||||
|
// remove "List"
|
||||||
|
str = replace("List", "");
|
||||||
|
std::stack<std::string> rules;
|
||||||
|
std::stack<ValuePtr> value;
|
||||||
|
int num = 0, count = 0;
|
||||||
|
for (size_t i = 0; i < str.length(); i++) {
|
||||||
|
if (str[i] == '[') {
|
||||||
|
rules.push("[");
|
||||||
|
} else if (str[i] == ']') {
|
||||||
|
// rules
|
||||||
|
std::vector<ValuePtr> vec;
|
||||||
|
while (rules.top() != "[") {
|
||||||
|
rules.pop();
|
||||||
|
vec.push_back(value.top());
|
||||||
|
value.pop();
|
||||||
|
}
|
||||||
|
// pop "["
|
||||||
|
rules.pop();
|
||||||
|
// make tuple for names
|
||||||
|
std::string res = "dummy";
|
||||||
|
// make tuple for values
|
||||||
|
reverse(vec.begin(), vec.end());
|
||||||
|
auto vt = std::make_shared<ValueTuple>(vec);
|
||||||
|
if (rules.empty() && value.empty()) {
|
||||||
|
return vt;
|
||||||
|
}
|
||||||
|
rules.push(res);
|
||||||
|
value.push(vt);
|
||||||
|
} else if (str[i] == ',') {
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
count++;
|
||||||
|
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
|
||||||
|
auto value_name = str.substr(i - count + 1, count);
|
||||||
|
value.push(kv.at(value_name));
|
||||||
|
rules.push(value_name);
|
||||||
|
count = 0;
|
||||||
|
num++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<abstract::AbstractTuple>
|
||||||
|
ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, abstract::AbstractTensorPtr> &kv) {
|
||||||
|
std::string str = attr_name;
|
||||||
|
auto replace = [&](const string &orgStr, const string &newStr) {
|
||||||
|
std::string::size_type pos(0);
|
||||||
|
while ((pos = str.find(orgStr)) != std::string::npos) {
|
||||||
|
str.replace(pos, orgStr.length(), newStr);
|
||||||
|
}
|
||||||
|
return str;
|
||||||
|
};
|
||||||
|
// remove "scalar:"
|
||||||
|
str = replace("shape:", "");
|
||||||
|
// remove "Tuple"
|
||||||
|
str = replace("Tuple", "");
|
||||||
|
// remove "List"
|
||||||
|
str = replace("List", "");
|
||||||
|
std::stack<std::string> rules;
|
||||||
|
std::stack<abstract::AbstractBasePtr> value;
|
||||||
|
int num = 0, count = 0;
|
||||||
|
for (size_t i = 0; i < str.length(); i++) {
|
||||||
|
if (str[i] == '[') {
|
||||||
|
rules.push("[");
|
||||||
|
} else if (str[i] == ']') {
|
||||||
|
// rules
|
||||||
|
std::vector<abstract::AbstractBasePtr> vec;
|
||||||
|
while (rules.top() != "[") {
|
||||||
|
rules.pop();
|
||||||
|
vec.push_back(value.top());
|
||||||
|
value.pop();
|
||||||
|
}
|
||||||
|
// pop "["
|
||||||
|
rules.pop();
|
||||||
|
// make tuple for names
|
||||||
|
std::string res = "dummy";
|
||||||
|
// make tuple for values
|
||||||
|
reverse(vec.begin(), vec.end());
|
||||||
|
auto vt = std::make_shared<abstract::AbstractTuple>(vec);
|
||||||
|
if (rules.empty() && value.empty()) {
|
||||||
|
return vt;
|
||||||
|
}
|
||||||
|
rules.push(res);
|
||||||
|
value.push(vt);
|
||||||
|
} else if (str[i] == ',') {
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
count++;
|
||||||
|
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
|
||||||
|
auto value_name = str.substr(i - count + 1, count);
|
||||||
|
value.push(kv.at(value_name));
|
||||||
|
rules.push(value_name);
|
||||||
|
count = 0;
|
||||||
|
num++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||||
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
|
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \
|
||||||
const onnx::TensorProto &attr_tensor) { \
|
if (attr_tensor.type##_data_size() == 1) { \
|
||||||
MS_EXCEPTION_IF_NULL(prim); \
|
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \
|
||||||
std::vector<ValuePtr> attr_value_vec; \
|
return MakeValue<valuetype>(value); \
|
||||||
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \
|
|
||||||
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \
|
|
||||||
attr_value_vec.push_back(MakeValue<valuetype>(value)); \
|
|
||||||
} \
|
|
||||||
if (attr_value_vec.size() == 1) { \
|
|
||||||
prim->AddAttr(attr_name, attr_value_vec[0]); \
|
|
||||||
} else { \
|
} else { \
|
||||||
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \
|
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \
|
||||||
} \
|
} \
|
||||||
|
return {}; \
|
||||||
}
|
}
|
||||||
|
|
||||||
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
|
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double)
|
||||||
|
@ -193,45 +301,34 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
|
ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) {
|
||||||
const onnx::TensorProto &attr_tensor) {
|
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
|
||||||
const int attr_tensor_type = attr_tensor.data_type();
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
switch (attr_tensor_type) {
|
switch (attr_tensor_type) {
|
||||||
case onnx::TensorProto_DataType_STRING: {
|
case onnx::TensorProto_DataType_STRING: {
|
||||||
ParseAttrInScalar_string_string(prim, attr_name, attr_tensor);
|
return ParseAttrInScalar_string_string(attr_tensor);
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
case onnx::TensorProto_DataType_INT32: {
|
case onnx::TensorProto_DataType_INT32: {
|
||||||
ParseAttrInScalar_int32_int32(prim, attr_name, attr_tensor);
|
return ParseAttrInScalar_int32_int32(attr_tensor);
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
case onnx::TensorProto_DataType_INT64: {
|
case onnx::TensorProto_DataType_INT64: {
|
||||||
ParseAttrInScalar_int64_int64(prim, attr_name, attr_tensor);
|
return ParseAttrInScalar_int64_int64(attr_tensor);
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
case onnx::TensorProto_DataType_UINT64: {
|
case onnx::TensorProto_DataType_UINT64: {
|
||||||
ParseAttrInScalar_uint64_uint64(prim, attr_name, attr_tensor);
|
return ParseAttrInScalar_uint64_uint64(attr_tensor);
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
case onnx::TensorProto_DataType_FLOAT: {
|
case onnx::TensorProto_DataType_FLOAT: {
|
||||||
ParseAttrInScalar_float_float(prim, attr_name, attr_tensor);
|
return ParseAttrInScalar_float_float(attr_tensor);
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
case onnx::TensorProto_DataType_DOUBLE: {
|
case onnx::TensorProto_DataType_DOUBLE: {
|
||||||
ParseAttrInScalar_double_double(prim, attr_name, attr_tensor);
|
return ParseAttrInScalar_double_double(attr_tensor);
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
case onnx::TensorProto_DataType_BOOL: {
|
case onnx::TensorProto_DataType_BOOL: {
|
||||||
ParseAttrInScalar_int32_bool(prim, attr_name, attr_tensor);
|
return ParseAttrInScalar_int32_bool(attr_tensor);
|
||||||
auto value = prim->GetAttr(attr_name);
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
default:
|
default:MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
||||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
return {};
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
return true;
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
|
@ -268,7 +365,6 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
|
||||||
prim->set_attr(attr_name, MakeValue<bool>(attr_value));
|
prim->set_attr(attr_name, MakeValue<bool>(attr_value));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret == EOK;
|
return ret == EOK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -280,22 +376,46 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||||
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
string type;
|
||||||
switch (kParseTypeSwitchMap[ref_attr_name]) {
|
std::size_t pos(0);
|
||||||
|
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
|
||||||
|
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
|
||||||
|
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
|
||||||
|
type = ref_attr_name.substr(pos, string("type:").length() - 1);
|
||||||
|
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
|
||||||
|
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
|
||||||
|
}
|
||||||
|
std::unordered_map<std::string, ValuePtr> kv;
|
||||||
|
for (int i = 0; i < attr_proto.tensors_size(); i++) {
|
||||||
|
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
|
||||||
|
switch (kParseTypeSwitchMap[type]) {
|
||||||
case FORM_PARSE_TYPE: {
|
case FORM_PARSE_TYPE: {
|
||||||
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
|
return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor);
|
||||||
}
|
}
|
||||||
case FORM_PARSE_SCALAR: {
|
case FORM_PARSE_SCALAR: {
|
||||||
return ObtainCNodeAttrInScalarForm(prim, attr_name, attr_tensor);
|
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
|
||||||
|
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
case FORM_PARSE_TENSOR: {
|
case FORM_PARSE_TENSOR: {
|
||||||
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
|
return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor);
|
||||||
}
|
}
|
||||||
default:
|
default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
|
||||||
MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
|
||||||
|
if (kv.size() == 1) {
|
||||||
|
std::unordered_map<std::string, ValuePtr>::iterator iter = kv.begin();
|
||||||
|
prim->AddAttr(attr_name, iter->second);
|
||||||
|
} else {
|
||||||
|
auto res = ParserScalarAttrValue(ref_attr_name, kv);
|
||||||
|
prim->AddAttr(attr_name, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
||||||
const onnx::TensorProto &attr_tensor) {
|
const onnx::TensorProto &attr_tensor) {
|
||||||
const int attr_tensor_type = attr_tensor.data_type();
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
|
@ -321,53 +441,6 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name,
|
|
||||||
const onnx::TensorProto &attr_tensor) {
|
|
||||||
const int attr_tensor_type = attr_tensor.data_type();
|
|
||||||
ValuePtr value_ptr = nullptr;
|
|
||||||
switch (attr_tensor_type) {
|
|
||||||
case onnx::TensorProto_DataType_INT32: {
|
|
||||||
std::vector<int32> add_data;
|
|
||||||
for (int i = 0; i < attr_tensor.int32_data_size(); ++i) {
|
|
||||||
add_data.push_back(attr_tensor.int32_data(i));
|
|
||||||
}
|
|
||||||
if (add_data.size() == 1) {
|
|
||||||
value_ptr = MakeValue(add_data[0]);
|
|
||||||
} else if (!add_data.empty()) {
|
|
||||||
value_ptr = MakeValue<std::vector<int32> >(add_data);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case onnx::TensorProto_DataType_FLOAT: {
|
|
||||||
std::vector<float> add_data;
|
|
||||||
for (int i = 0; i < attr_tensor.float_data_size(); ++i) {
|
|
||||||
add_data.push_back(attr_tensor.float_data(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (add_data.size() == 1) {
|
|
||||||
value_ptr = MakeValue(add_data[0]);
|
|
||||||
} else if (!add_data.empty()) {
|
|
||||||
value_ptr = MakeValue<std::vector<float> >(add_data);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case onnx::TensorProto_DataType_UNDEFINED: {
|
|
||||||
std::vector<ValuePtr> elems;
|
|
||||||
value_ptr = std::make_shared<ValueTuple>(elems);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto new_value_node = NewValueNode(value_ptr);
|
|
||||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
|
||||||
new_value_node->set_abstract(value_ptr->ToAbstract());
|
|
||||||
anfnode_build_map_[value_node_name] = new_value_node;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
||||||
const onnx::TensorProto &attr_tensor) {
|
const onnx::TensorProto &attr_tensor) {
|
||||||
const int attr_tensor_type = attr_tensor.data_type();
|
const int attr_tensor_type = attr_tensor.data_type();
|
||||||
|
@ -382,25 +455,58 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name,
|
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name,
|
||||||
const std::string &value_node_name,
|
const onnx::AttributeProto &attr_proto) {
|
||||||
const onnx::TensorProto &attr_tensor) {
|
if (!attr_proto.has_ref_attr_name()) {
|
||||||
switch (kParseTypeSwitchMap[ref_attr_name]) {
|
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||||
|
string type;
|
||||||
|
std::size_t pos(0);
|
||||||
|
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
|
||||||
|
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
|
||||||
|
} else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
|
||||||
|
type = ref_attr_name.substr(pos, string("type:").length() - 1);
|
||||||
|
} else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
|
||||||
|
type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
|
||||||
|
}
|
||||||
|
std::unordered_map<std::string, ValuePtr> kv;
|
||||||
|
for (int i = 0; i < attr_proto.tensors_size(); i++) {
|
||||||
|
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
|
||||||
|
switch (kParseTypeSwitchMap[type]) {
|
||||||
|
case FORM_PARSE_TYPE: {
|
||||||
|
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
|
||||||
|
}
|
||||||
case FORM_PARSE_SCALAR: {
|
case FORM_PARSE_SCALAR: {
|
||||||
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor);
|
auto res = ObtainCNodeAttrInScalarForm(attr_tensor);
|
||||||
|
kv.insert(std::pair<string, ValuePtr>(attr_tensor.name(), res));
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
case FORM_PARSE_TENSOR: {
|
case FORM_PARSE_TENSOR: {
|
||||||
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
|
return ObtainValueNodeInTensorForm(value_node_name, attr_tensor);
|
||||||
}
|
}
|
||||||
case FORM_PARSE_TYPE: {
|
default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name";
|
||||||
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor);
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name";
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ValueNodePtr new_value_node;
|
||||||
|
if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) {
|
||||||
|
if (kv.size() == 1) {
|
||||||
|
auto iter = kv.begin();
|
||||||
|
new_value_node = NewValueNode(iter->second);
|
||||||
|
new_value_node->set_abstract(iter->second->ToAbstract());
|
||||||
|
} else {
|
||||||
|
auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv);
|
||||||
|
new_value_node = NewValueNode(value_ptr);
|
||||||
|
new_value_node->set_abstract(value_ptr->ToAbstract());
|
||||||
|
}
|
||||||
|
anfnode_build_map_[value_node_name] = new_value_node;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
|
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) {
|
||||||
const std::string &value_node_name = node_proto.output(0);
|
const std::string &value_node_name = node_proto.output(0);
|
||||||
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
|
const onnx::AttributeProto &attr_proto = node_proto.attribute(0);
|
||||||
|
@ -408,22 +514,23 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &
|
||||||
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
|
MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
return GetAttrValueForValueNode(value_node_name, attr_proto);
|
||||||
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
|
||||||
|
|
||||||
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
|
std::unordered_map<std::string, abstract::AbstractTensorPtr>
|
||||||
|
AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) {
|
||||||
|
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
|
||||||
|
for (int i = 0; i < attr_proto.tensors_size(); i++) {
|
||||||
std::vector<int> shape_vec;
|
std::vector<int> shape_vec;
|
||||||
const onnx::TensorProto &attr_tensor = attr_proto.t();
|
const onnx::TensorProto &attr_tensor = attr_proto.tensors(i);
|
||||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
for (int j = 0; j < attr_tensor.dims_size(); ++j) {
|
||||||
shape_vec.push_back(attr_tensor.dims(i));
|
shape_vec.push_back(attr_tensor.dims(j));
|
||||||
}
|
}
|
||||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]);
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec);
|
||||||
MS_EXCEPTION_IF_NULL(abstract_tensor);
|
kv.insert(std::pair<string, abstract::AbstractTensorPtr>(attr_tensor.name(), abstract_tensor));
|
||||||
return abstract_tensor;
|
}
|
||||||
|
return kv;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||||
|
@ -437,25 +544,16 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
||||||
const std::string &node_name = node_proto.output(0);
|
const std::string &node_name = node_proto.output(0);
|
||||||
const std::string &fullname_with_scope = node_proto.domain();
|
const std::string &fullname_with_scope = node_proto.domain();
|
||||||
const std::string &node_type = node_proto.op_type();
|
const std::string &node_type = node_proto.op_type();
|
||||||
PrimitivePtr prim = std::make_shared<Primitive>(node_type);
|
PrimitivePtr prim = std::make_shared<mindspore::Primitive>(node_type);
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
prim->set_instance_name(node_type);
|
prim->set_instance_name(node_type);
|
||||||
|
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
|
||||||
abstract::AbstractTensorPtr abstract = nullptr;
|
string shape_ref_attr_name;
|
||||||
abstract::AbstractTensorPtr abstract_first = nullptr;
|
|
||||||
abstract::AbstractTensorPtr abstract_second = nullptr;
|
|
||||||
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
||||||
const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
|
const onnx::AttributeProto &attr_proto = node_proto.attribute(i);
|
||||||
if (attr_proto.name() == kCNodeShapeAttr) {
|
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
|
||||||
abstract = GetAbstractForCNode(attr_proto);
|
shape_ref_attr_name = attr_proto.ref_attr_name();
|
||||||
continue;
|
kv = GetAbstractForCNode(attr_proto);
|
||||||
}
|
|
||||||
if (attr_proto.name() == kCNodeShape1Attr) {
|
|
||||||
abstract_first = GetAbstractForCNode(attr_proto);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (attr_proto.name() == kCNodeShape2Attr) {
|
|
||||||
abstract_second = GetAbstractForCNode(attr_proto);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (!GetAttrValueForCNode(prim, attr_proto)) {
|
if (!GetAttrValueForCNode(prim, attr_proto)) {
|
||||||
|
@ -463,6 +561,7 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> inputs;
|
std::vector<AnfNodePtr> inputs;
|
||||||
inputs.clear();
|
inputs.clear();
|
||||||
for (int i = 0; i < node_proto.input_size(); ++i) {
|
for (int i = 0; i < node_proto.input_size(); ++i) {
|
||||||
|
@ -481,26 +580,20 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
||||||
inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr));
|
inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr));
|
||||||
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
||||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||||
if (node_type == "LayerNorm") {
|
if (0 == kv.size()) {
|
||||||
AbstractBasePtrList elem;
|
|
||||||
elem.push_back(abstract);
|
|
||||||
elem.push_back(abstract_first);
|
|
||||||
elem.push_back(abstract_second);
|
|
||||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
|
||||||
} else if (node_type == "ArgMaxWithValue") {
|
|
||||||
AbstractBasePtrList elem;
|
|
||||||
elem.push_back(abstract);
|
|
||||||
elem.push_back(abstract_first);
|
|
||||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
|
||||||
} else if (nullptr == abstract) {
|
|
||||||
AbstractBasePtrList elem;
|
AbstractBasePtrList elem;
|
||||||
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
||||||
elem.push_back(cnode_ptr->input(index)->abstract());
|
elem.push_back(cnode_ptr->input(index)->abstract());
|
||||||
}
|
}
|
||||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||||
|
} else if (1 == kv.size()) {
|
||||||
|
std::unordered_map<std::string, abstract::AbstractTensorPtr>::iterator iter = kv.begin();
|
||||||
|
cnode_ptr->set_abstract(iter->second);
|
||||||
} else {
|
} else {
|
||||||
|
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
|
||||||
cnode_ptr->set_abstract(abstract);
|
cnode_ptr->set_abstract(abstract);
|
||||||
}
|
}
|
||||||
|
|
||||||
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
|
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
|
||||||
anfnode_build_map_[node_name] = cnode_ptr;
|
anfnode_build_map_[node_name] = cnode_ptr;
|
||||||
return cnode_ptr;
|
return cnode_ptr;
|
||||||
|
|
|
@ -59,18 +59,15 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
||||||
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
|
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
|
||||||
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
const onnx::TensorProto &attr_tensor);
|
const onnx::TensorProto &attr_tensor);
|
||||||
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
|
ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor);
|
||||||
const onnx::TensorProto &attr_tensor);
|
|
||||||
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||||
const onnx::TensorProto &attr_tensor);
|
const onnx::TensorProto &attr_tensor);
|
||||||
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
|
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
|
||||||
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||||
|
bool GetAttrValueForValueNode(const std::string &value_node_name, const onnx::AttributeProto &attr_proto);
|
||||||
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
|
||||||
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
|
|
||||||
const onnx::TensorProto &attr_tensor);
|
|
||||||
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||||
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
std::unordered_map<std::string,
|
||||||
|
abstract::AbstractTensorPtr> GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string producer_name_;
|
std::string producer_name_;
|
||||||
|
|
Loading…
Reference in New Issue