Fix mindir export and import problem

This commit is contained in:
l00591931 2021-11-12 09:39:50 +08:00
parent 5211733add
commit 2e1319005b
2 changed files with 43 additions and 5 deletions

View File

@ -118,6 +118,7 @@ class IrExportBuilder {
bool SetScalarToAttributeProtoForInt_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetScalarToAttributeProtoForInt_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetTypeToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string);
@ -883,9 +884,9 @@ bool IrExportBuilder::SetScalarToAttributeProtoForInt_ir(const ValuePtr &value,
return true;
}
bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
bool IrExportBuilder::SetTypeToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "AttributeProto is null!";
}
if (value->isa<Int>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
@ -905,7 +906,31 @@ bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
return false;
}
tensor_proto->set_data_type(data_type);
} else if (value->isa<StringImm>()) {
} else if (value->isa<UInt>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
auto uint_value = value->cast<FloatPtr>();
auto data_type = GetMindirDataBitsFloatType(uint_value->nbits());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
} else if (value->isa<Bool>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
} else {
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
return false;
}
return true;
}
bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
if (attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "AttributeProto is null!";
}
if (value->isa<StringImm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
attr_proto->add_strings(GetValue<std::string>(value));
} else if (value->isa<BoolImm>()) {
@ -963,11 +988,18 @@ bool IrExportBuilder::SetScalarToAttributeProtoForInt_irs(const ValuePtr &value,
bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string) {
if (value == nullptr) {
MS_LOG(ERROR) << "Value is nullptr";
return false;
}
string value_name = "value" + std::to_string(GetTupleIndex());
if (seq_string != nullptr) {
*seq_string += value_name + ",";
}
return SetScalarToAttributeProto_irs(value, attr_proto);
if (value->isa<StringImm>() || value->isa<Scalar>()) {
return SetScalarToAttributeProto_irs(value, attr_proto);
}
return SetTypeToAttributeProto_irs(value, attr_proto);
}
bool IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,

View File

@ -235,6 +235,8 @@ string GetTypeString(const std::string &ref_attr_name, size_t *pos) {
return ref_attr_name.substr(*pos, string("type:").length() - 1);
} else if ((*pos = ref_attr_name.find("tensor:")) != std::string::npos) {
return ref_attr_name.substr(*pos, string("tensor:").length() - 1);
} else if (ref_attr_name == "none") {
return ref_attr_name;
}
return "";
}
@ -655,6 +657,10 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
ObtainCNodeAttrInTensorForm(prim, attr_proto);
break;
}
case FORM_PARSE_NONE: {
prim->AddAttr(attr_name, kNone);
break;
}
default:
MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
return false;