!25734 Enable ref type in mindir export and loading

Merge pull request !25734 from LiangZhibo/mindir_ref
i-robot 2021-11-08 06:03:52 +00:00 committed by Gitee
commit 2cd87ddf79
4 changed files with 72 additions and 11 deletions
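
In outline, the commit threads a parameter's ref key through MindIR in both directions: on export, IrExportBuilder writes the ref key name of a RefType tensor into the new TensorProto.ref_key field, and on load, MSANFModelParser rebuilds an abstract::AbstractRef from that string. The sketch below condenses that round trip; the free-standing helper names are illustrative only, and it is assumed to sit inside namespace mindspore with the relevant MindSpore and generated mind_ir headers available, as in the real sources.

// Condensed sketch of the export/load round trip enabled by this commit.
// Helper names are illustrative; the actual logic is in IrExportBuilder and
// MSANFModelParser in the diffs below.

// Export side: if the tensor's type is a RefType, store its ref key name.
bool WriteRefKey(const TypePtr &type, const AbstractBasePtr &abs,
                 mind_ir::TensorProto *tensor_proto) {
  if (!type->isa<RefType>()) {
    return true;  // Plain tensors are exported as before.
  }
  auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
  if (abs_ref == nullptr || abs_ref->ref_key_value() == nullptr) {
    return false;  // A RefType tensor without a usable AbstractRef is an error.
  }
  tensor_proto->set_ref_key(abs_ref->ref_key_value()->name());
  return true;
}

// Load side: rebuild an AbstractRef when the proto carries a ref key.
abstract::AbstractBasePtr ReadRefAbstract(const mind_ir::TensorProto &tensor_proto,
                                          const tensor::TensorPtr &tensor_info) {
  if (!tensor_proto.has_ref_key()) {
    return tensor_info->ToAbstract();  // Unchanged path for ordinary tensors.
  }
  auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
  auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
  return std::make_shared<abstract::AbstractRef>(ref_key->ToAbstract(), abs_value);
}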

@@ -107,10 +107,11 @@ class IrExportBuilder {
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
bool SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs, mind_ir::TensorProto *const tensor_proto);
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string);
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const abstract::AbstractBasePtr &abstract,
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string);
bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
@@ -349,6 +350,9 @@ bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
tensor_proto->add_dims(dim);
}
}
if (!SetTensorProtoForRef(type, node->abstract(), tensor_proto)) {
return false;
}
} else if (type->isa<Tuple>()) {
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
@@ -402,6 +406,25 @@ bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &sh
return true;
}
bool IrExportBuilder::SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs,
mind_ir::TensorProto *const tensor_proto) {
if (!type->isa<RefType>()) {
return true;
}
auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
if (abs_ref == nullptr) {
MS_LOG(ERROR) << "The abstract " << abs->ToString() << " should be AbstractRef.";
return false;
}
auto ref_key_value = abs_ref->ref_key_value();
if (ref_key_value == nullptr) {
MS_LOG(ERROR) << "The ref_key_value of abstract ref " << abs->ToString() << " is nullptr";
return false;
}
tensor_proto->set_ref_key(ref_key_value->name());
return true;
}
bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto) {
if (param == nullptr || tensor_proto == nullptr) {
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
@@ -479,7 +502,7 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
return type_name;
}
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const AbstractBasePtr &abs,
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(shape);
@@ -488,8 +511,9 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
*seq_string += "Tuple[";
auto elements = type->cast<TuplePtr>()->elements();
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>()->elements();
for (size_t i = 0; i < elements.size(); i++) {
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string)) {
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], tuple_abs[i], attr_proto, seq_string)) {
return false;
}
}
@@ -499,7 +523,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
*seq_string += shape_name + ",";
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name(shape_name);
return SetTensorProto(type, shape, tensor_proto);
return SetTensorProto(type, shape, tensor_proto) && SetTensorProtoForRef(type, abs, tensor_proto);
} else if (type->isa<Number>()) {
if (type->isa<Bool>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
@@ -531,6 +555,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
MS_EXCEPTION_IF_NULL(node);
auto type = node->Type();
auto shape = node->Shape();
auto abs = node->abstract();
// For the bprop fg which has not been renormalized.
if (type == nullptr || shape == nullptr) {
return true;
@@ -538,7 +563,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
ResetTupleIndex();
std::string seq_string = "shape:";
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
if (!SetShapeToNodeProto(type, shape, attr_proto, &seq_string)) {
if (!SetShapeToNodeProto(type, shape, abs, attr_proto, &seq_string)) {
MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed.";
return false;
}

@@ -394,11 +394,36 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Get tensor_info fail.";
return false;
}
auto tensor_abstract = tensor_info->ToAbstract();
node->set_abstract(tensor_abstract);
if (tensor_proto.has_ref_key() && top_graph_ != nullptr) {
auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
auto abs_ref_key = ref_key->ToAbstract();
auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
node->set_abstract(abs_ref);
auto parameters = top_graph_->parameters();
for (auto parameter : parameters) {
auto parameter_abs = parameter->abstract();
if (parameter_abs->isa<abstract::AbstractRef>()) {
auto parameter_abs_value = parameter_abs->cast<abstract::AbstractRefPtr>()->ref_key_value();
if (parameter_abs_value->name() == tensor_proto.ref_key()) {
node->set_default_param(parameter->cast<ParameterPtr>()->default_param());
break;
}
}
}
} else {
auto tensor_abstract = tensor_info->ToAbstract();
node->set_abstract(tensor_abstract);
}
} else if (value_proto.has_denotation()) {
if (value_proto.denotation() == "UMonadType") {
node->set_abstract(kUMonad->ToAbstract());
} else if (value_proto.denotation() == "IOMonadType") {
node->set_abstract(kIOMonad->ToAbstract());
}
MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation();
}
anfnode_build_map_[value_proto.name()] = node;
@@ -849,9 +874,17 @@ std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::Get
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
MS_EXCEPTION_IF_NULL(tensor_info);
auto abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(abstract);
kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract));
if (attr_tensor.has_ref_key()) {
auto ref_key = std::make_shared<RefKey>(attr_tensor.ref_key());
auto abs_ref_key = ref_key->ToAbstract();
auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abs_ref));
} else {
auto abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(abstract);
kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract));
}
}
return kv;
}
@@ -1198,6 +1231,7 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
return nullptr;
}
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name();
top_graph_ = dstGraph;
for (int i = 0; i < model_proto.functions_size(); ++i) {
const auto &graph_proto = model_proto.functions(i);
FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[graph_proto.name()]);
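
The parameter lookup in BuildInputForFuncGraph is what ties a reloaded ref tensor back to its weight: once top_graph_ is remembered in Parse, the parser scans the top graph's parameters for an AbstractRef whose ref key name equals TensorProto.ref_key and reuses that parameter's default_param. A rough stand-alone rendering of that lookup, with an assumed helper name and an extra null check:

// Illustrative helper (assumed name) mirroring the lookup in BuildInputForFuncGraph:
// find the top-graph parameter whose AbstractRef carries the given ref key.
AnfNodePtr FindParameterByRefKey(const FuncGraphPtr &top_graph, const std::string &ref_key) {
  for (const auto &parameter : top_graph->parameters()) {
    auto parameter_abs = parameter->abstract();
    if (parameter_abs == nullptr || !parameter_abs->isa<abstract::AbstractRef>()) {
      continue;
    }
    auto key_value = parameter_abs->cast<abstract::AbstractRefPtr>()->ref_key_value();
    if (key_value != nullptr && key_value->name() == ref_key) {
      return parameter;  // Caller copies this parameter's default_param to the new node.
    }
  }
  return nullptr;  // No parameter in the top graph carries this ref key.
}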

@@ -81,6 +81,7 @@ class MSANFModelParser {
const mind_ir::AttributeProto &attr_proto);
AnfNodePtr GetAnfNode(const std::string &node_name);
FuncGraphPtr top_graph_ = nullptr;
std::string producer_name_;
std::string model_version_;
std::string ir_version_;

@@ -133,6 +133,7 @@ message TensorProto {
repeated double double_data = 10;
repeated uint64 uint64_data = 11;
optional ExternalDataProto external_data = 12;
optional string ref_key = 13;
}
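
Because ref_key is declared optional with a new field number, protobuf generates has_ref_key(), ref_key(), and set_ref_key() accessors, and readers built before this change simply skip the unknown field, so existing MindIR files keep loading. A minimal usage sketch of the generated API, assuming the header generated from mind_ir.proto is included:

mind_ir::TensorProto tensor_proto;
tensor_proto.set_ref_key("ref_key_name");   // written by SetTensorProtoForRef on export
if (tensor_proto.has_ref_key()) {           // checked by MSANFModelParser on load
  const std::string &key = tensor_proto.ref_key();
  // The parser turns `key` back into a RefKey / abstract::AbstractRef.
}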