From 281284d1b97c5313a8db885fe901a8c9bbd0d0a9 Mon Sep 17 00:00:00 2001 From: lianliguang Date: Wed, 15 Dec 2021 17:02:21 +0800 Subject: [PATCH] refactor get abstract for mindir export and load --- .../cxx_api/model/acl/acl_model_multi.cc | 18 +- .../transform/express_ir/mindir_exporter.cc | 95 +++-- .../core/load_mindir/anf_model_parser.cc | 347 +++++++++--------- mindspore/core/load_mindir/anf_model_parser.h | 12 +- mindspore/core/proto/mind_ir.proto | 3 + 5 files changed, 249 insertions(+), 226 deletions(-) diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc index 10754f83d5d..960fa09551e 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc @@ -164,10 +164,20 @@ void AclModelMulti::SetInputs() { for (const auto &in : inputs) { auto input_param = std::dynamic_pointer_cast(in); MS_EXCEPTION_IF_NULL(input_param); - MS_EXCEPTION_IF_NULL(input_param->abstract()); - auto input_value = input_param->abstract()->GetValueTrack(); - auto tensor = input_value->cast(); - MS_EXCEPTION_IF_NULL(tensor); + auto input_abs = input_param->abstract(); + MS_EXCEPTION_IF_NULL(input_abs); + auto tensor_abs = input_abs->cast(); + if (tensor_abs == nullptr) { + MS_LOG(EXCEPTION) << "The graph input type is not a tensor. input args info:" << input_abs->ToString(); + } + auto shape_ptr = tensor_abs->BuildShape(); + MS_EXCEPTION_IF_NULL(shape_ptr); + auto tensor_shape = shape_ptr->cast(); + MS_EXCEPTION_IF_NULL(tensor_shape); + auto elem = tensor_abs->element(); + MS_EXCEPTION_IF_NULL(elem); + auto type_id = elem->BuildType()->type_id(); + auto tensor = std::make_shared(type_id, tensor_shape->shape()); std::vector shape = tensor->shape_c(); auto input_tensor = MSTensor::CreateTensor(input_param->name(), static_cast(tensor->data_type_c()), diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc index ebb3b7baa15..0a75ad08162 100644 --- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc @@ -111,12 +111,11 @@ class IrExportBuilder { bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto); bool SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto); - bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto); - bool SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs, mind_ir::TensorProto *const tensor_proto); + bool SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto); bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto); bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto); - bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const abstract::AbstractBasePtr &abstract, - mind_ir::AttributeProto *const attr_proto, std::string *const seq_string); + bool SetShapeToNodeProto(const abstract::AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto, + std::string *const seq_string); bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); @@ -399,30 +398,18 @@ bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn return true; } if (type->isa() && shape->isa()) { - auto tensor = type->cast(); - MS_EXCEPTION_IF_NULL(tensor); - auto elem_type = tensor->element(); - const auto &dims = shape->cast()->shape(); mind_ir::TensorProto *tensor_proto = value_proto->add_tensor(); - auto data_type = GetMindirDataType(elem_type->type_id()); - if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { - return false; - } - tensor_proto->set_data_type(data_type); - if (dims.size() == 0) { - MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0."; - } else { - for (const auto &dim : dims) { - MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; - tensor_proto->add_dims(dim); - } - } - if (!SetTensorProtoForRef(type, node->abstract(), tensor_proto)) { + if (!SetTensorProto(node->abstract(), tensor_proto)) { return false; } } else if (type->isa()) { - auto tup_shape = shape->cast(); - value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); + std::string seq_string = "shape:"; + mind_ir::AttributeProto *attribute = value_proto->mutable_attr_info(); + if (!SetShapeToNodeProto(node->abstract(), attribute, &seq_string)) { + MS_LOG(ERROR) << "Set shape to Proto for " << node->DebugString() << " failed."; + return false; + } + attribute->set_ref_attr_name(seq_string); } else { value_proto->set_denotation(type->type_name()); } @@ -454,14 +441,16 @@ bool IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir:: return true; } -bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, - mind_ir::TensorProto *const tensor_proto) { +bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto) { + auto type = abstract->BuildType(); + auto shape = abstract->BuildShape(); if (!type->isa() || !shape->isa()) { MS_LOG(ERROR) << "Type or shape is not supported! " << type->ToString(); return false; } auto tensor = type->cast(); - const auto &dims = shape->cast()->shape(); + auto tensor_shape = shape->cast(); + const auto &dims = tensor_shape->shape(); auto data_type = GetMindirDataType(tensor->element()->type_id()); if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { return false; @@ -470,22 +459,29 @@ bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &sh for (const auto &dim : dims) { tensor_proto->add_dims(dim); } - return true; -} - -bool IrExportBuilder::SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs, - mind_ir::TensorProto *const tensor_proto) { + if (tensor_shape->IsDynamic()) { + auto min_shape = tensor_shape->min_shape(); + auto max_shape = tensor_shape->max_shape(); + for (auto item : min_shape) { + tensor_proto->add_min_dims(item); + } + for (auto item : max_shape) { + tensor_proto->add_max_dims(item); + } + } + // Deal Ref if (!type->isa()) { return true; } - auto abs_ref = abs->cast(); + + auto abs_ref = abstract->cast(); if (abs_ref == nullptr) { - MS_LOG(ERROR) << "The abstract " << abs->ToString() << " should be AbstractRef."; + MS_LOG(ERROR) << "The abstract " << abstract->ToString() << " should be AbstractRef."; return false; } auto ref_key_value = abs_ref->ref_key_value(); if (ref_key_value == nullptr) { - MS_LOG(INFO) << "The ref_key_value of abstract ref " << abs->ToString() << " is nullptr"; + MS_LOG(INFO) << "The ref_key_value of abstract ref " << abstract->ToString() << " is nullptr"; return true; } tensor_proto->set_ref_key(ref_key_value->name()); @@ -497,7 +493,7 @@ bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir:: MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!"; } MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString(); - return SetTensorProto(param->Type(), param->Shape(), tensor_proto); + return SetTensorProto(param->abstract(), tensor_proto); } bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { @@ -569,18 +565,16 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { return type_name; } -bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const AbstractBasePtr &abs, - mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) { - MS_EXCEPTION_IF_NULL(type); - MS_EXCEPTION_IF_NULL(shape); +bool IrExportBuilder::SetShapeToNodeProto(const AbstractBasePtr &abs, mind_ir::AttributeProto *const attr_proto, + std::string *const seq_string) { + auto type = abs->BuildType(); + auto shape = abs->BuildShape(); MS_EXCEPTION_IF_NULL(seq_string); if (type->isa()) { *seq_string += "Tuple["; - auto elements = type->cast()->elements(); - auto tuple_shape = shape->cast()->shape(); - auto tuple_abs = abs->cast()->elements(); - for (size_t i = 0; i < elements.size(); i++) { - if (!SetShapeToNodeProto(elements[i], tuple_shape[i], tuple_abs[i], attr_proto, seq_string)) { + auto tuple_abs = abs->cast(); + for (size_t i = 0; i < tuple_abs->size(); i++) { + if (!SetShapeToNodeProto((*tuple_abs)[i], attr_proto, seq_string)) { return false; } } @@ -590,7 +584,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt *seq_string += shape_name + ","; mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); tensor_proto->set_name(shape_name); - return SetTensorProto(type, shape, tensor_proto) && SetTensorProtoForRef(type, abs, tensor_proto); + return SetTensorProto(abs, tensor_proto); } else if (type->isa()) { if (type->isa()) { attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); @@ -608,9 +602,10 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt } else if (type->isa() || type->isa() || type->isa()) { *seq_string += type->type_name() + ","; } else if (type->isa()) { - auto csr_tensor_type = type->cast(); - MS_EXCEPTION_IF_NULL(csr_tensor_type); - if (!SetShapeToNodeProto(csr_tensor_type->element(), shape, abs, attr_proto, seq_string)) return false; + auto cst_tensor_abs = abs->cast(); + if (!SetShapeToNodeProto(cst_tensor_abs->element(), attr_proto, seq_string)) { + return false; + } } else { MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name(); return false; @@ -634,7 +629,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro ResetTupleIndex(); std::string seq_string = "shape:"; mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); - if (!SetShapeToNodeProto(type, shape, abs, attr_proto, &seq_string)) { + if (!SetShapeToNodeProto(abs, attr_proto, &seq_string)) { MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed."; return false; } diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc index d7148528f4c..c1319c3080c 100644 --- a/mindspore/core/load_mindir/anf_model_parser.cc +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -241,11 +241,107 @@ string GetTypeString(const std::string &ref_attr_name, size_t *pos) { return ""; } } // namespace +tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor, + bool need_load_data) { + ShapeVector shape; + const int attr_tensor_type = attr_tensor.data_type(); + for (int i = 0; i < attr_tensor.dims_size(); ++i) { + shape.push_back(attr_tensor.dims(i)); + } + tensor::TensorPtr tensor = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); -tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto) { + if (!IsIncLoad() || load_tensor_map_.find(attr_tensor.name()) == load_tensor_map_.end()) { + load_tensor_map_[attr_tensor.name()] = tensor; + } + + MS_EXCEPTION_IF_NULL(tensor); + const std::string &tensor_buf = attr_tensor.raw_data(); + if (attr_tensor.has_raw_data()) { + auto *tensor_data_buf = reinterpret_cast(tensor->data_c()); + auto ret = memcpy_s(tensor_data_buf, tensor->data().nbytes(), tensor_buf.data(), tensor_buf.size()); + if (ret != 0) { + MS_LOG(ERROR) << "Failed to get tensor form tensor proto."; + return nullptr; + } + } else if (need_load_data) { + MS_LOG(ERROR) << "Failed to get tensor form tensor proto."; + return nullptr; + } + return tensor; +} + +bool MSANFModelParser::SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto &attr_proto, + const AnfNodePtr &node_ptr) { + mindspore::HashMap kv; + string shape_ref_attr_name; + if (attr_proto.ref_attr_name().find("shape:") == string::npos) { + MS_LOG(ERROR) << "Cannot use a attr_proto " << attr_proto.ref_attr_name() << " to init shape."; + return false; + } + bool is_tuple_or_list = false; + + shape_ref_attr_name = attr_proto.ref_attr_name(); + is_tuple_or_list = + shape_ref_attr_name.find("Tuple[") != string::npos || shape_ref_attr_name.find("List[") != string::npos; + kv = GetAbstractForNode(attr_proto); + if (kv.empty()) { + return SetEmptyTensorProtoCNodeAbstract(node_ptr); + } else if (!is_tuple_or_list) { + auto iter = kv.begin(); + if (iter->second != nullptr) { + node_ptr->set_abstract(iter->second); + } + } else { + auto abstract = ParserAttrShape(shape_ref_attr_name, kv); + node_ptr->set_abstract(abstract); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Node's attribute is nullptr."; + return false; + } + } + return true; +} + +void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr) { + auto prim = GetCNodePrimitive(cnode_ptr); + auto prim_to_add_attr = prim; + if (prim_to_add_attr != nullptr) { + if (prim->isa()) { + auto func = prim->cast()->function(); + if (func != nullptr && func->isa()) { + prim_to_add_attr = func->cast(); + } + } + prim_to_add_attr->set_attr("is_load", MakeValue(true)); + } + for (int i = 0; i < node_proto.attribute_size(); ++i) { + const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i); + // CNode abstract + if (attr_proto.ref_attr_name().find("shape:") != string::npos) { + SetCNodeAbstract(attr_proto, cnode_ptr); + continue; + } + if (prim_to_add_attr != nullptr) { + if (!GetAttrValueForCNode(prim_to_add_attr, attr_proto)) { + MS_LOG(ERROR) << "Parser prim: " << prim->ToString() << " attributes error : " << attr_proto.DebugString(); + } + } + } +} + +abstract::AbstractTensorPtr MSANFModelParser::GetAbsTensorFromTensorProto(const mind_ir::TensorProto &tensor_proto) { ShapeVector shape; for (int i = 0; i < tensor_proto.dims_size(); ++i) { - shape.push_back(tensor_proto.dims(i)); + shape.emplace_back(tensor_proto.dims(i)); + } + ShapeVector min_shape; + for (int i = 0; i < tensor_proto.min_dims_size(); ++i) { + min_shape.emplace_back(tensor_proto.min_dims(i)); + } + + ShapeVector max_shape; + for (int i = 0; i < tensor_proto.max_dims_size(); ++i) { + max_shape.emplace_back(tensor_proto.min_dims(i)); } if (!tensor_proto.has_data_type()) { @@ -253,14 +349,20 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T MS_LOG(ERROR) << "mind_ir TensorProto has no data_type."; return nullptr; } - if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) { + auto iter = kDefaultValueSwitchMap.find(tensor_proto.data_type()); + if (iter == kDefaultValueSwitchMap.end()) { MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed"; MS_LOG(ERROR) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!"; return nullptr; } - - tensor::TensorPtr tensor_info = - std::make_shared(kDefaultValueSwitchMap[tensor_proto.data_type()], shape); + auto tensor_shape = std::make_shared(shape, min_shape, max_shape); + auto tensor_info = std::make_shared(TypeIdToType(iter->second), tensor_shape); + if (tensor_proto.has_ref_key()) { + auto ref_key = std::make_shared(tensor_proto.ref_key()); + auto abs_ref_key = ref_key->ToAbstract(); + auto abs_ref = std::make_shared(abs_ref_key, tensor_info); + return abs_ref; + } return tensor_info; } @@ -277,51 +379,34 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, node->set_debug_info(debug_info_ptr); node->set_name(debug_info_name); - tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto); - if (tensor_info == nullptr) { - return false; - } + ParamInfoPtr param_info = std::make_shared(); + param_info->set_name(debug_info_name); + MS_LOG(DEBUG) << "Load parameter name: " << parameter_proto.name(); - if (!IsIncLoad() || load_tensor_map_.find(parameter_proto.name()) == load_tensor_map_.end()) { - load_tensor_map_[parameter_proto.name()] = tensor_info; - } else { + if (IsIncLoad() && load_tensor_map_.find(parameter_proto.name()) != load_tensor_map_.end()) { MS_LOG(DEBUG) << "Parameter: " << parameter_proto.name() << " has been already loaded, use it again."; - tensor::TensorPtr load_tensor_info = load_tensor_map_[parameter_proto.name()]; - auto tensor_abstract = load_tensor_info->ToAbstract(); - MS_EXCEPTION_IF_NULL(tensor_abstract); - node->set_abstract(tensor_abstract); + auto load_tensor_info = load_tensor_map_[parameter_proto.name()]; + MS_EXCEPTION_IF_NULL(load_tensor_info); + load_tensor_info->set_param_info(param_info); node->set_default_param(load_tensor_info); + node->set_abstract(load_tensor_info->ToAbstract()); anfnode_build_map_[parameter_proto.name()] = node; return true; } - ParamInfoPtr param_info = std::make_shared(); - param_info->set_name(debug_info_name); - tensor_info->set_param_info(param_info); - - auto tensor_abstract = tensor_info->ToAbstract(); - MS_EXCEPTION_IF_NULL(tensor_abstract); - node->set_abstract(tensor_abstract); - + auto tensor = GenerateTensorPtrFromTensorProto(parameter_proto, false); + tensor->set_param_info(param_info); if (parameter_proto.has_raw_data()) { - std::string initial_data = parameter_proto.raw_data(); - auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); - MS_EXCEPTION_IF_NULL(tensor_data_buf); - auto ret = memcpy_s(tensor_data_buf, static_cast(tensor_info->data().nbytes()), initial_data.data(), - initial_data.size()); - if (ret != 0) { - MS_LOG(ERROR) << "Build parameter occur memcpy_s error."; - return false; - } - node->set_default_param(tensor_info); + node->set_default_param(tensor); } else if (parameter_proto.has_external_data()) { - auto ret = GetTensorDataFromExternal(parameter_proto, tensor_info); + auto ret = GetTensorDataFromExternal(parameter_proto, tensor); if (!ret) { return false; } - node->set_default_param(tensor_info); + node->set_default_param(tensor); } else { MS_LOG(DEBUG) << "Parameter will load initialized data."; } + node->set_abstract(tensor->ToAbstract()); anfnode_build_map_[parameter_proto.name()] = node; return true; @@ -398,7 +483,7 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi // Set abstract of the parameter if (value_proto.tensor_size() > 0) { const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0); - tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto); + auto tensor_info = GetAbsTensorFromTensorProto(tensor_proto); if (tensor_info == nullptr) { MS_LOG(ERROR) << "Get tensor_info fail."; return false; @@ -406,8 +491,7 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi if (tensor_proto.has_ref_key() && top_graph_ != nullptr) { auto ref_key = std::make_shared(tensor_proto.ref_key()); auto abs_ref_key = ref_key->ToAbstract(); - auto abs_value = tensor_info->ToAbstract()->Broaden()->cast(); - auto abs_ref = std::make_shared(abs_ref_key, abs_value); + auto abs_ref = std::make_shared(abs_ref_key, tensor_info); node->set_abstract(abs_ref); auto parameters = top_graph_->parameters(); for (auto parameter : parameters) { @@ -421,8 +505,7 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi } } } else { - auto tensor_abstract = tensor_info->ToAbstract(); - node->set_abstract(tensor_abstract); + node->set_abstract(tensor_info); } } else if (value_proto.has_denotation()) { if (value_proto.denotation() == "UMonadType") { @@ -432,6 +515,12 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi } MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation(); } + if (value_proto.has_attr_info()) { + if (!SetNodeAbstractFromAttrProto(value_proto.attr_info(), node)) { + MS_LOG(ERROR) << "Failed to get abstract from proto."; + return false; + } + } anfnode_build_map_[value_proto.name()] = node; return true; } @@ -589,19 +678,10 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) { MS_EXCEPTION_IF_NULL(prim); const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0); - const int attr_tensor_type = attr_tensor.data_type(); - ShapeVector shape; - for (int i = 0; i < attr_tensor.dims_size(); ++i) { - shape.push_back(attr_tensor.dims(i)); - } - tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); - MS_EXCEPTION_IF_NULL(tensor_info); - const std::string &tensor_buf = attr_tensor.raw_data(); - auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); - auto ret = - memcpy_s(tensor_data_buf, static_cast(tensor_info->data().nbytes()), tensor_buf.data(), tensor_buf.size()); - if (ret != 0) { - MS_LOG(ERROR) << "Obtain CNode in TensorForm occur memcpy_s error."; + auto tensor_info = GenerateTensorPtrFromTensorProto(attr_tensor); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Failed to get attr[" << attr_proto.name() << "] for node " << prim->ToString() + << " from the proto."; return false; } prim->AddAttr(attr_proto.name(), MakeValue(tensor_info)); @@ -682,21 +762,10 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name, const mind_ir::TensorProto &attr_tensor) { - const int attr_tensor_type = attr_tensor.data_type(); - ShapeVector shape; - for (int i = 0; i < attr_tensor.dims_size(); ++i) { - shape.push_back(attr_tensor.dims(i)); + tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(attr_tensor); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Failed to get the tensor for ValueNode."; } - tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); - MS_EXCEPTION_IF_NULL(tensor_info); - const std::string &tensor_buf = attr_tensor.raw_data(); - auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); - auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); - if (ret != 0) { - MS_LOG(ERROR) << "Obtain ValueNode in TensorForm occur memcpy_s error."; - return false; - } - auto new_value_node = NewValueNode(MakeValue(tensor_info)); MS_EXCEPTION_IF_NULL(new_value_node); auto tensor_abstract = tensor_info->ToAbstract(); @@ -740,9 +809,9 @@ bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_n MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; return false; } - auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); - abstract::AbstractTypePtr abs_type = std::make_shared(std::make_shared()); - new_value_node->set_abstract(abs_type); + auto value = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); + auto new_value_node = NewValueNode(value); + new_value_node->set_abstract(value->ToAbstract()); anfnode_build_map_[value_node_name] = new_value_node; return true; } @@ -872,29 +941,13 @@ bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node return GetAttrValueForValueNode(value_node_name, attr_proto); } -mindspore::HashMap MSANFModelParser::GetAbstractForCNode( +mindspore::HashMap MSANFModelParser::GetAbstractForNode( const mind_ir::AttributeProto &attr_proto) { mindspore::HashMap kv; for (int i = 0; i < attr_proto.tensors_size(); ++i) { - ShapeVector shape_vec; const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i); - for (int j = 0; j < attr_tensor.dims_size(); ++j) { - shape_vec.push_back(attr_tensor.dims(j)); - } - tensor::TensorPtr tensor_info = - std::make_shared(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec); - MS_EXCEPTION_IF_NULL(tensor_info); - if (attr_tensor.has_ref_key()) { - auto ref_key = std::make_shared(attr_tensor.ref_key()); - auto abs_ref_key = ref_key->ToAbstract(); - auto abs_value = tensor_info->ToAbstract()->Broaden()->cast(); - auto abs_ref = std::make_shared(abs_ref_key, abs_value); - (void)kv.emplace(attr_tensor.name(), abs_ref); - } else { - auto abstract = tensor_info->ToAbstract(); - MS_EXCEPTION_IF_NULL(abstract); - (void)kv.emplace(attr_tensor.name(), abstract); - } + auto tensor_info = GetAbsTensorFromTensorProto(attr_tensor); + (void)kv.emplace(attr_tensor.name(), tensor_info); } return kv; } @@ -946,26 +999,6 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr prim->set_instance_name(node_type); } } - MS_EXCEPTION_IF_NULL(prim); - auto prim_to_add_attr = prim; - if (prim->isa()) { - auto func = prim->cast()->function(); - if (func != nullptr && func->isa()) { - prim_to_add_attr = func->cast(); - prim_to_add_attr->set_attr("is_load", MakeValue(true)); - } - } - for (int i = 0; i < node_proto.attribute_size(); ++i) { - const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i); - // CNode abstract - if (attr_proto.ref_attr_name().find("shape:") != string::npos) { - continue; - } - if (!GetAttrValueForCNode(prim_to_add_attr, attr_proto)) { - MS_LOG(ERROR) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString(); - return nullptr; - } - } prim->set_attr("is_load", MakeValue(true)); return std::make_shared(prim); } @@ -991,70 +1024,47 @@ bool MSANFModelParser::CheckCNodePrim(CNodePtr cnode_ptr) { return false; } -void MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, const std::string &node_type) { - if (node_type == "UpdateState") { - cnode_ptr->set_abstract(kUMonad->ToAbstract()); - } else if (node_type == "Depend") { - cnode_ptr->set_abstract(kBool->ToAbstract()); - } else { - AbstractBasePtrList elem; - for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { - auto abs = cnode_ptr->input(index)->abstract(); - if (abs != nullptr) { - if (abs->GetValueTrack() == nullptr) { - abs->set_value(kAnyValue); +bool MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr) { + auto primitive = GetCNodePrimitive(node_ptr); + if (primitive != nullptr) { + auto node_type = primitive->name(); + if (node_type == "UpdateState") { + node_ptr->set_abstract(kUMonad->ToAbstract()); + } else if (node_type == "Depend") { + node_ptr->set_abstract(kBool->ToAbstract()); + } else { + auto cnode_ptr = node_ptr->cast(); + AbstractBasePtrList elem; + for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { + auto abs = cnode_ptr->input(index)->abstract(); + if (abs != nullptr) { + if (abs->GetValueTrack() == nullptr) { + abs->set_value(kAnyValue); + } + elem.push_back(abs); } - elem.push_back(abs); + } + if (!elem.empty()) { + node_ptr->set_abstract(std::make_shared(elem)); } } - if (!elem.empty()) { - cnode_ptr->set_abstract(std::make_shared(elem)); - } + } else { + MS_LOG(ERROR) << "Failed to get the abstract of node:" << node_ptr->DebugString(); + return false; } + return true; } // Set CNode abstract. -void MSANFModelParser::SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) { +bool MSANFModelParser::SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr) { if (CheckCNodePrim(cnode_ptr)) { - return; + return true; } - - mindspore::HashMap kv; - string shape_ref_attr_name; - - bool is_tuple_or_list = false; - for (int i = 0; i < node_proto.attribute_size(); ++i) { - const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i); - if (attr_proto.ref_attr_name().find("shape:") != string::npos) { - shape_ref_attr_name = attr_proto.ref_attr_name(); - is_tuple_or_list = - shape_ref_attr_name.find("Tuple[") != string::npos || shape_ref_attr_name.find("List[") != string::npos; - kv = GetAbstractForCNode(attr_proto); - break; - } - } - - // Because there is not context in unit test, - // abstract->broaden() is replaced by abstract->set_value(kAnyValue). - const std::string &node_type = node_proto.op_type(); - if (kv.size() == 0) { - SetEmptyTensorProtoCNodeAbstract(cnode_ptr, node_type); - } else if (kv.size() == 1 && !is_tuple_or_list) { - auto iter = kv.begin(); - if (iter->second != nullptr) { - iter->second->set_value(kAnyValue); - cnode_ptr->set_abstract(iter->second); - } - } else { - auto abstract = ParserAttrShape(shape_ref_attr_name, kv); - if (abstract == nullptr) { - cnode_ptr->set_abstract(nullptr); - MS_LOG(ERROR) << "Node's attribute is nullptr."; - } else { - abstract->set_value(kAnyValue); - cnode_ptr->set_abstract(abstract); - } + if (!SetNodeAbstractFromAttrProto(attr_proto, cnode_ptr)) { + MS_LOG(ERROR) << "Failed to get CNode abstract from proto."; + return false; } + return true; } CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, @@ -1085,8 +1095,8 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(cnode_ptr); - SetCNodeAbstract(node_proto, cnode_ptr); - + // Set Abstract and prim attr for CNode + SetCNodePrimAttrAndAbstract(node_proto, cnode_ptr); const std::string &fullname_with_scope = node_proto.domain(); string debug_info_name = ParseCNodeName(node_name); auto debug_info_ptr = std::make_shared(debug_info_name); @@ -1148,6 +1158,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra outputFuncGraph->set_return(return_node); MS_LOG(DEBUG) << "Construct funcgraph finined, all success!"; } + return true; } @@ -1386,7 +1397,7 @@ AnfNodePtr MSANFModelParser::GetAnfNode(const std::string &node_name) { return nullptr; } FuncGraphPtr func_graph_ptr = GetValueNode(it->second); - if (func_graph_ptr) { + if (func_graph_ptr != nullptr) { return NewValueNode(func_graph_ptr); } else { return it->second; diff --git a/mindspore/core/load_mindir/anf_model_parser.h b/mindspore/core/load_mindir/anf_model_parser.h index 095960d3e50..a9411b79fbe 100644 --- a/mindspore/core/load_mindir/anf_model_parser.h +++ b/mindspore/core/load_mindir/anf_model_parser.h @@ -63,7 +63,7 @@ class MSANFModelParser { bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map &weights); bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info); bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto); - tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto); + abstract::AbstractTensorPtr GetAbsTensorFromTensorProto(const mind_ir::TensorProto &tensor_proto); CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto); bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); @@ -76,8 +76,10 @@ class MSANFModelParser { bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto); AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto); bool CheckCNodePrim(CNodePtr cnode_ptr); - void SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, const std::string &node_type); - void SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr); + bool SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr); + bool SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr); + bool SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto &attr_proto, const AnfNodePtr &node_ptr); + void SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr); bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto); bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor); @@ -85,9 +87,11 @@ class MSANFModelParser { bool ObtainValueNodeInNoneForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); bool little_endian() { return little_endian_; } - mindspore::HashMap GetAbstractForCNode( + mindspore::HashMap GetAbstractForNode( const mind_ir::AttributeProto &attr_proto); AnfNodePtr GetAnfNode(const std::string &node_name); + tensor::TensorPtr GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor, + bool need_load_data = true); FuncGraphPtr top_graph_ = nullptr; std::string producer_name_; diff --git a/mindspore/core/proto/mind_ir.proto b/mindspore/core/proto/mind_ir.proto index b979955d065..e302021f9db 100644 --- a/mindspore/core/proto/mind_ir.proto +++ b/mindspore/core/proto/mind_ir.proto @@ -52,6 +52,7 @@ message ValueInfoProto { repeated TensorProto tensor = 2; optional string doc_string = 3; optional string denotation = 4; + optional AttributeProto attr_info = 5; // graph input info for other type } @@ -150,6 +151,8 @@ message TensorProto { repeated uint64 uint64_data = 11; optional ExternalDataProto external_data = 12; optional string ref_key = 13; + repeated int64 min_dims = 14; + repeated int64 max_dims = 15; } message ParallelProto {