From 2fc470c2cb3ac978a50fab8792d9096fa97df2e7 Mon Sep 17 00:00:00 2001
From: l00591931
Date: Mon, 1 Nov 2021 14:33:46 +0800
Subject: [PATCH] Enable mindir of ref

---
 .../transform/express_ir/mindir_exporter.cc   | 37 +++++++++++++---
 .../core/load_mindir/anf_model_parser.cc      | 44 ++++++++++++++++---
 mindspore/core/load_mindir/anf_model_parser.h |  1 +
 mindspore/core/proto/mind_ir.proto            |  1 +
 4 files changed, 72 insertions(+), 11 deletions(-)

diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
index 9346e08be96..7b0e67755b2 100644
--- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
+++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
@@ -107,10 +107,11 @@ class IrExportBuilder {
   bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
   bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
   bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
+  bool SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs, mind_ir::TensorProto *const tensor_proto);
   bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
   bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
-  bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
-                           std::string *const seq_string);
+  bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const abstract::AbstractBasePtr &abstract,
+                           mind_ir::AttributeProto *const attr_proto, std::string *const seq_string);
   bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
   bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
   bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
@@ -349,6 +350,9 @@ bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
         tensor_proto->add_dims(dim);
       }
     }
+    if (!SetTensorProtoForRef(type, node->abstract(), tensor_proto)) {
+      return false;
+    }
   } else if (type->isa<Tuple>()) {
     auto tup_shape = shape->cast<abstract::TupleShapePtr>();
     value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
@@ -402,6 +406,25 @@ bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &sh
   return true;
 }
 
+bool IrExportBuilder::SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs,
+                                           mind_ir::TensorProto *const tensor_proto) {
+  if (!type->isa<RefType>()) {
+    return true;
+  }
+  auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
+  if (abs_ref == nullptr) {
+    MS_LOG(ERROR) << "The abstract " << abs->ToString() << " should be AbstractRef.";
+    return false;
+  }
+  auto ref_key_value = abs_ref->ref_key_value();
+  if (ref_key_value == nullptr) {
+    MS_LOG(ERROR) << "The ref_key_value of abstract ref " << abs->ToString() << " is nullptr";
+    return false;
+  }
+  tensor_proto->set_ref_key(ref_key_value->name());
+  return true;
+}
+
 bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto) {
   if (param == nullptr || tensor_proto == nullptr) {
     MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
@@ -479,7 +502,7 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
   return type_name;
 }
 
-bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
+bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const AbstractBasePtr &abs,
                                          mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
   MS_EXCEPTION_IF_NULL(type);
   MS_EXCEPTION_IF_NULL(shape);
@@ -488,8 +511,9 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
     *seq_string += "Tuple[";
     auto elements = type->cast<TuplePtr>()->elements();
     auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
+    auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>()->elements();
     for (size_t i = 0; i < elements.size(); i++) {
-      if (!SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string)) {
+      if (!SetShapeToNodeProto(elements[i], tuple_shape[i], tuple_abs[i], attr_proto, seq_string)) {
         return false;
       }
     }
@@ -499,7 +523,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
     *seq_string += shape_name + ",";
     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
     tensor_proto->set_name(shape_name);
-    return SetTensorProto(type, shape, tensor_proto);
+    return SetTensorProto(type, shape, tensor_proto) && SetTensorProtoForRef(type, abs, tensor_proto);
   } else if (type->isa<Number>()) {
     if (type->isa<Bool>()) {
       attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
@@ -531,6 +555,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
   MS_EXCEPTION_IF_NULL(node);
   auto type = node->Type();
   auto shape = node->Shape();
+  auto abs = node->abstract();
   // For the bprop fg which has not been renormalized.
   if (type == nullptr || shape == nullptr) {
     return true;
@@ -538,7 +563,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
   ResetTupleIndex();
   std::string seq_string = "shape:";
   mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
-  if (!SetShapeToNodeProto(type, shape, attr_proto, &seq_string)) {
+  if (!SetShapeToNodeProto(type, shape, abs, attr_proto, &seq_string)) {
     MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed.";
     return false;
   }
diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc
index 685aa327186..74ca3688ef4 100644
--- a/mindspore/core/load_mindir/anf_model_parser.cc
+++ b/mindspore/core/load_mindir/anf_model_parser.cc
@@ -394,11 +394,36 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
     const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
     tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
     if (tensor_info == nullptr) {
+      MS_LOG(ERROR) << "Get tensor_info fail.";
       return false;
     }
-    auto tensor_abstract = tensor_info->ToAbstract();
-    node->set_abstract(tensor_abstract);
+    if (tensor_proto.has_ref_key() && top_graph_ != nullptr) {
+      auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
+      auto abs_ref_key = ref_key->ToAbstract();
+      auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
+      auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
+      node->set_abstract(abs_ref);
+      auto parameters = top_graph_->parameters();
+      for (auto parameter : parameters) {
+        auto parameter_abs = parameter->abstract();
+        if (parameter_abs->isa<abstract::AbstractRef>()) {
+          auto parameter_abs_value = parameter_abs->cast<abstract::AbstractRefPtr>()->ref_key_value();
+          if (parameter_abs_value->name() == tensor_proto.ref_key()) {
+            node->set_default_param(parameter->cast<ParameterPtr>()->default_param());
+            break;
+          }
+        }
+      }
+    } else {
+      auto tensor_abstract = tensor_info->ToAbstract();
+      node->set_abstract(tensor_abstract);
+    }
   } else if (value_proto.has_denotation()) {
+    if (value_proto.denotation() == "UMonadType") {
+      node->set_abstract(kUMonad->ToAbstract());
+    } else if (value_proto.denotation() == "IOMonadType") {
+      node->set_abstract(kIOMonad->ToAbstract());
+    }
     MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation();
   }
   anfnode_build_map_[value_proto.name()] = node;
@@ -849,9 +874,17 @@ std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::Get
     tensor::TensorPtr tensor_info =
       std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
     MS_EXCEPTION_IF_NULL(tensor_info);
-    auto abstract = tensor_info->ToAbstract();
-    MS_EXCEPTION_IF_NULL(abstract);
-    kv.insert(std::pair(attr_tensor.name(), abstract));
+    if (attr_tensor.has_ref_key()) {
+      auto ref_key = std::make_shared<RefKey>(attr_tensor.ref_key());
+      auto abs_ref_key = ref_key->ToAbstract();
+      auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
+      auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
+      kv.insert(std::pair(attr_tensor.name(), abs_ref));
+    } else {
+      auto abstract = tensor_info->ToAbstract();
+      MS_EXCEPTION_IF_NULL(abstract);
+      kv.insert(std::pair(attr_tensor.name(), abstract));
+    }
   }
   return kv;
 }
@@ -1198,6 +1231,7 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
     return nullptr;
   }
   MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name();
+  top_graph_ = dstGraph;
   for (int i = 0; i < model_proto.functions_size(); ++i) {
     const auto &graph_proto = model_proto.functions(i);
     FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[graph_proto.name()]);
diff --git a/mindspore/core/load_mindir/anf_model_parser.h b/mindspore/core/load_mindir/anf_model_parser.h
index 2b7a7eef34e..d92be64f053 100644
--- a/mindspore/core/load_mindir/anf_model_parser.h
+++ b/mindspore/core/load_mindir/anf_model_parser.h
@@ -81,6 +81,7 @@ class MSANFModelParser {
                                 const mind_ir::AttributeProto &attr_proto);
   AnfNodePtr GetAnfNode(const std::string &node_name);
 
+  FuncGraphPtr top_graph_ = nullptr;
   std::string producer_name_;
   std::string model_version_;
   std::string ir_version_;
diff --git a/mindspore/core/proto/mind_ir.proto b/mindspore/core/proto/mind_ir.proto
index dd25812e63f..8e32424ef7f 100644
--- a/mindspore/core/proto/mind_ir.proto
+++ b/mindspore/core/proto/mind_ir.proto
@@ -133,6 +133,7 @@ message TensorProto {
   repeated double double_data = 10;
   repeated uint64 uint64_data = 11;
   optional ExternalDataProto external_data = 12;
+  optional string ref_key = 13;
 }