forked from mindspore-Ecosystem/mindspore
Fix mindir export and import problem
This commit is contained in:
parent
5211733add
commit
2e1319005b
|
@ -118,6 +118,7 @@ class IrExportBuilder {
|
|||
bool SetScalarToAttributeProtoForInt_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetScalarToAttributeProtoForInt_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetTypeToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string);
|
||||
|
@ -883,9 +884,9 @@ bool IrExportBuilder::SetScalarToAttributeProtoForInt_ir(const ValuePtr &value,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
|
||||
bool IrExportBuilder::SetTypeToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
if (attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "AttributeProto is null!";
|
||||
}
|
||||
if (value->isa<Int>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
||||
|
@ -905,7 +906,31 @@ bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
|
|||
return false;
|
||||
}
|
||||
tensor_proto->set_data_type(data_type);
|
||||
} else if (value->isa<StringImm>()) {
|
||||
} else if (value->isa<UInt>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
||||
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
auto uint_value = value->cast<FloatPtr>();
|
||||
auto data_type = GetMindirDataBitsFloatType(uint_value->nbits());
|
||||
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
||||
return false;
|
||||
}
|
||||
tensor_proto->set_data_type(data_type);
|
||||
} else if (value->isa<Bool>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
||||
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
if (attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "AttributeProto is null!";
|
||||
}
|
||||
if (value->isa<StringImm>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
|
||||
attr_proto->add_strings(GetValue<std::string>(value));
|
||||
} else if (value->isa<BoolImm>()) {
|
||||
|
@ -963,11 +988,18 @@ bool IrExportBuilder::SetScalarToAttributeProtoForInt_irs(const ValuePtr &value,
|
|||
|
||||
bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string) {
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Value is nullptr";
|
||||
return false;
|
||||
}
|
||||
string value_name = "value" + std::to_string(GetTupleIndex());
|
||||
if (seq_string != nullptr) {
|
||||
*seq_string += value_name + ",";
|
||||
}
|
||||
return SetScalarToAttributeProto_irs(value, attr_proto);
|
||||
if (value->isa<StringImm>() || value->isa<Scalar>()) {
|
||||
return SetScalarToAttributeProto_irs(value, attr_proto);
|
||||
}
|
||||
return SetTypeToAttributeProto_irs(value, attr_proto);
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
|
||||
|
|
|
@ -235,6 +235,8 @@ string GetTypeString(const std::string &ref_attr_name, size_t *pos) {
|
|||
return ref_attr_name.substr(*pos, string("type:").length() - 1);
|
||||
} else if ((*pos = ref_attr_name.find("tensor:")) != std::string::npos) {
|
||||
return ref_attr_name.substr(*pos, string("tensor:").length() - 1);
|
||||
} else if (ref_attr_name == "none") {
|
||||
return ref_attr_name;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
@ -655,6 +657,10 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
|
|||
ObtainCNodeAttrInTensorForm(prim, attr_proto);
|
||||
break;
|
||||
}
|
||||
case FORM_PARSE_NONE: {
|
||||
prim->AddAttr(attr_name, kNone);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name;
|
||||
return false;
|
||||
|
|
Loading…
Reference in New Issue