From dc63dea10393ba82879948d868dacce147e649ba Mon Sep 17 00:00:00 2001 From: lanzhineng Date: Sun, 8 Aug 2021 16:04:21 +0800 Subject: [PATCH] mindIR support control flow --- mindspore/ccsrc/pipeline/jit/action.cc | 47 ++++- mindspore/ccsrc/pipeline/jit/resource.h | 4 + .../transform/express_ir/mindir_exporter.cc | 71 ++++++-- .../core/load_mindir/anf_model_parser.cc | 161 +++++++++++++---- mindspore/core/load_mindir/anf_model_parser.h | 2 + mindspore/core/proto/mind_ir.proto | 6 + tests/st/control/test_if_mindir.py | 171 ++++++++++++++++++ 7 files changed, 402 insertions(+), 60 deletions(-) create mode 100644 tests/st/control/test_if_mindir.py diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 6460b9786dc..df4f77a8f18 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -121,6 +121,28 @@ using CompileGraphs = compile::CompileGraphs; using abstract::AnalysisResult; using mindspore::abstract::AnalysisContextPtr; +inline bool ResetCNodeFromLoad(const AnfNodePtr &node) { + if (node->isa() && node->cast()->get_load_flag()) { + // Process partial("DeadNode",args) when the graph is loaded. + auto operatorPtr = node->cast()->input(0); + // Set abstract of switch(c,f,t) to null + auto prim = GetValueNode(operatorPtr); + if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) { + node->set_abstract(nullptr); + return true; + } + // Set abstract of switch(c,f,t)() to null + prim = GetCNodePrimitive(operatorPtr); + if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) { + node->set_abstract(nullptr); + return true; + } + // Previous inferred value + return true; + } + return false; +} + abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, const abstract::AbstractBasePtrList &args_spec, bool clear) { MS_LOG(DEBUG) << "AbstractAnalyze start"; @@ -133,10 +155,19 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph for (auto &node : manager->all_nodes()) { MS_EXCEPTION_IF_NULL(node); const AbstractBasePtr &prev_inferred = node->abstract(); - // Keep previous inferred value for CNode if is loaded from MindIR. - if (node->isa() && node->cast()->get_load_flag()) { + + // AbstractFunction has context,but contexts in cache have been cleaned. + if (prev_inferred != nullptr && prev_inferred->isa()) { + node->set_abstract(nullptr); + MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr"; continue; } + + // Handle previous inferred value for CNode if is loaded from MindIR + if (res->is_load() && ResetCNodeFromLoad(node)) { + continue; + } + // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction. if (!node->isa() || (prev_inferred != nullptr && prev_inferred->isa())) { node->set_abstract(nullptr); @@ -200,6 +231,7 @@ const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) { if (graph->has_attr("is_load")) { loaded_graph = graph; loaded_graph_num += 1; + res->set_is_load(true); } } if (loaded_graph_num == 0) { @@ -218,6 +250,8 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load FuncGraphPtr root_graph = *(manager->roots().begin()); auto root_inputs = root_graph->get_inputs(); auto loaded_inputs = loaded_graph->get_inputs(); + MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString(); + MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString(); size_t root_inputs_num = root_inputs.size(); size_t loaded_inputs_num = loaded_inputs.size(); @@ -229,10 +263,18 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load auto root_input = root_inputs[index]; auto loaded_input = loaded_inputs[index]; + MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1); + MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1); + MS_LOG(DEBUG) << "root_input abstract[" << index + << "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL"); + MS_LOG(DEBUG) << "loaded_input abstract [" << index + << "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL"); + auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast(root_input->Shape()); auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast(loaded_input->Shape()); auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast(root_input->Type()); auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast(loaded_input->Type()); + MS_EXCEPTION_IF_NULL(root_shape); MS_EXCEPTION_IF_NULL(loaded_shape); MS_EXCEPTION_IF_NULL(root_type); @@ -454,6 +496,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { } // Analyze AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec); + // The top graph may be replaced by infer, update the top graph when the infer is done parse::Parser::UpdateTopFuncGraph(result.context->func_graph()); diff --git a/mindspore/ccsrc/pipeline/jit/resource.h b/mindspore/ccsrc/pipeline/jit/resource.h index 8981d825acf..f31bf37376c 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.h +++ b/mindspore/ccsrc/pipeline/jit/resource.h @@ -79,6 +79,8 @@ class Resource : public ResourceBase { gpu_loopsink_flag_ = flag; gpu_loopsink_size_ = size; } + void set_is_load(bool flag) { is_load_ = flag; } + bool is_load() { return is_load_; } bool gpu_loopsink_flag() { return gpu_loopsink_flag_; } int64_t gpu_loopsink_size() { return gpu_loopsink_size_; } // Reclaim resource and clear the cache. @@ -93,6 +95,8 @@ class Resource : public ResourceBase { py::object input_; bool is_cleaned_; bool gpu_loopsink_flag_{false}; + // The func_graph_ is loaded from mindir + bool is_load_{false}; int64_t gpu_loopsink_size_{1}; }; diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc index 53626814add..4ee7217e8a9 100644 --- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc @@ -138,6 +138,7 @@ class IrExportBuilder { mind_ir::NodeProto *last_node_{nullptr}; std::list todo_; std::map node_index_map_; + std::set nodeName_; size_t node_index_{0}; size_t shape_index_{0}; }; @@ -145,16 +146,7 @@ class IrExportBuilder { using IrExporterPtr = std::shared_ptr; std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) { - if ((builder_ == nullptr) || (func_graph == nullptr)) { - MS_LOG(EXCEPTION) << "Input params is null."; - } - - // Export model info - builder_->BuildModelInfo(); - - // Export model and return string - builder_->BuildModel(func_graph); - + (void)GetDumpProto(func_graph); return builder_->GetProtoString(func_graph); } @@ -168,7 +160,6 @@ mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, boo // Export model and return string builder_->BuildModel(func_graph, save_tensor_data); - return builder_->Model(); } @@ -191,16 +182,34 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso graph_proto->set_bprop_hash(func_graph->bprop_hash()); ResetNodeIndex(); todo_.clear(); - todo_.push_back(func_graph); + nodeName_.clear(); + // Build the main funcGraph + nodeName_.insert(func_graph->ToString()); + BuildFuncGraph(func_graph, graph_proto, save_tensor_data); + std::set graphVisited; + graphVisited.insert(func_graph); while (!todo_.empty()) { FuncGraphPtr fg = todo_.back(); todo_.pop_back(); - BuildFuncGraph(fg, graph_proto, save_tensor_data); + if (graphVisited.count(fg) > 0) { + continue; + } + if (nodeName_.count(fg->ToString()) > 0) { + MS_LOG(EXCEPTION) << "There is a duplicate name: " << fg->ToString(); + } + nodeName_.insert(fg->ToString()); + graphVisited.insert(fg); + auto graph = model_.add_functions(); + BuildFuncGraph(fg, graph, save_tensor_data); } + // Release resource + nodeName_.clear(); } void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, bool save_tensor_data) { + // Export funcGraph name. + graph_proto->set_name(func_graph->ToString()); // Export parameters // 1. parameters should be mapped to ValueInfoProto // 2. parameters with default value should be mapped to Initializer @@ -232,6 +241,10 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G input_proto->set_name(param_name); SetValueInfoProto(param, input_proto); } + if (nodeName_.count(param_name) > 0) { + MS_LOG(EXCEPTION) << "parameter name is duplicate:" << param_name; + } + nodeName_.insert(param_name); } } @@ -383,9 +396,13 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { } else if (IsValueNode(node)) { FuncGraphPtr fg = GetValueNode(node); todo_.push_back(fg); - type_name = fg->ToString(); + type_name = "REF::" + fg->ToString(); } else if (node->isa() || node->isa()) { - type_name = node->ToString(); + auto nodeName = GetUniqueNodeName(node); + type_name = "REF::" + nodeName; + if (nodeName_.count(nodeName) == 0) { + MS_LOG(EXCEPTION) << "There is not the name: " << nodeName; + } } else { MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name(); } @@ -424,6 +441,9 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt tensor_proto->set_data_type(mind_ir::TensorProto_DataType_UINT64); tensor_proto->add_dims(1); } + } else if (type->isa()) { + attr_proto->set_type(mind_ir::AttributeProto_AttributeType_GRAPH); + *seq_string += type->type_name() + ","; } else if (type->isa() || type->isa() || type->isa()) { *seq_string += type->type_name() + ","; } else { @@ -468,6 +488,10 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons // Build cnode mind_ir::NodeProto *node_proto = graph_proto->add_node(); std::string output_name = GetUniqueNodeName(node); + if (nodeName_.count(output_name) > 0) { + MS_LOG(EXCEPTION) << "There is a duplicate name: " << output_name; + } + nodeName_.insert(output_name); node_proto->add_output(output_name); node_proto->set_name(output_name); node_proto->set_domain(node->fullname_with_scope()); @@ -475,7 +499,9 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons std::string type_name = GetOpTypeName(op); node_proto->set_op_type(type_name); last_node_ = node_proto; + // Maybe Tensor or Function or nullptr SetShapeToNodeProto(node, node_proto); + (void)std::for_each(input_names.begin(), input_names.end(), [&node_proto](const string &name) { node_proto->add_input(name); }); @@ -490,13 +516,17 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value); SetValueToAttributeProto(attr_value, attr_proto); } - } else { - MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name(); } } std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) { std::string node_name = GetUniqueNodeName(node); + // FuncGraph will be added to functions and the input name is the function name. + if (IsValueNode(node)) { + FuncGraphPtr fg = GetValueNode(node); + todo_.push_back(fg); + return fg->ToString(); + } if (node->isa()) { // When node input is a ValueNode, need to create a Constant Node mind_ir::NodeProto *node_proto = graph_proto->add_node(); @@ -539,7 +569,12 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) { if ((node != nullptr) && (node->func_graph() != nullptr)) { node_name = node->func_graph()->ToString() + ":"; } - node_name += node->ToString(); + if (node->isa()) { + // Needn't value + node_name += node->AnfNode::ToString(); + } else { + node_name += node->ToString(); + } MS_LOG(DEBUG) << "GetNodeName: " << node_name; return node_name; } diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc index d44d8eac9c0..c38868d0d42 100644 --- a/mindspore/core/load_mindir/anf_model_parser.cc +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -635,14 +635,12 @@ bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_ const mind_ir::AttributeProto &attr_proto) { const std::string &ref_attr_name = attr_proto.ref_attr_name(); if (ref_attr_name.find("UMonad") != std::string::npos) { - const ValuePtr kUMonad = std::make_shared(); auto monad_abs = kUMonad->ToAbstract(); auto new_value_node = NewValueNode(kUMonad); MS_EXCEPTION_IF_NULL(new_value_node); new_value_node->set_abstract(monad_abs); anfnode_build_map_[value_node_name] = new_value_node; } else if (ref_attr_name.find("IOMonad") != std::string::npos) { - const ValuePtr kIOMonad = std::make_shared(); auto monad_abs = kIOMonad->ToAbstract(); auto new_value_node = NewValueNode(kIOMonad); MS_EXCEPTION_IF_NULL(new_value_node); @@ -768,17 +766,22 @@ std::unordered_map MSANFModelParser::Get return kv; } -CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, - const mind_ir::NodeProto &node_proto) { - MS_EXCEPTION_IF_NULL(outputFuncGraph); - if (!node_proto.has_op_type()) { - MS_LOG(ERROR) << "Get CNode op_type failed!"; - return nullptr; - } - const std::string &node_name = node_proto.output(0); - const std::string &fullname_with_scope = node_proto.domain(); +AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_proto) { + const std::string kOperatorTypeFlag = std::string("REF::"); + const size_t kOpTypeFlagSize = kOperatorTypeFlag.length(); const std::string &node_type = node_proto.op_type(); + MS_LOG(DEBUG) << "Process Operator :" << node_type; + // Operator maybe CNode,FuncGraph or Parameter. + if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) { + auto it = anfnode_build_map_.find(node_type.substr(kOpTypeFlagSize)); + if (it != anfnode_build_map_.end()) { + return it->second; + } + MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type; + } + + // Operator is primitive. std::shared_ptr prim; auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap(); if (op_primc_fns.find(node_type) != op_primc_fns.end()) { @@ -794,51 +797,65 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc } } MS_EXCEPTION_IF_NULL(prim); + for (int i = 0; i < node_proto.attribute_size(); ++i) { + const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i); + // CNode abstract + if (attr_proto.ref_attr_name().find("shape:") != string::npos) { + continue; + } + if (!GetAttrValueForCNode(prim, attr_proto)) { + MS_LOG(EXCEPTION) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString(); + } + } + prim->set_attr("is_load", MakeValue(true)); + return std::make_shared(prim); +} + +// Set CNode abstract. +void MSANFModelParser::SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) { + const std::string &node_type = node_proto.op_type(); + // Handle control flow operator. + auto operatorPtr = cnode_ptr->input(0); + // Set abstract of switch(c,f,t),switchLayer(c,tup) and + // partial(func,args) to null + auto prim = GetValueNode(operatorPtr); + if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) || + IsPrimitiveEquals(prim::kPrimPartial, prim)) { + cnode_ptr->set_abstract(nullptr); + return; + } + // Set abstract of switch(c,f,t)() to null + prim = GetCNodePrimitive(operatorPtr); + if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) { + cnode_ptr->set_abstract(nullptr); + return; + } std::unordered_map kv; string shape_ref_attr_name; + for (int i = 0; i < node_proto.attribute_size(); ++i) { const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i); if (attr_proto.ref_attr_name().find("shape:") != string::npos) { shape_ref_attr_name = attr_proto.ref_attr_name(); kv = GetAbstractForCNode(attr_proto); - continue; - } - - if (!GetAttrValueForCNode(prim, attr_proto)) { - MS_LOG(ERROR) << "Get CNode attr failed!"; - return nullptr; + break; } } - std::vector inputs; - inputs.clear(); - for (int i = 0; i < node_proto.input_size(); ++i) { - const std::string &input_name = node_proto.input(i); - if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { - MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; - return nullptr; - } - - inputs.push_back(anfnode_build_map_[input_name]); - } - prim->set_attr("is_load", MakeValue(true)); - CNodePtr cnode_ptr = outputFuncGraph->NewCNode(prim, inputs); - MS_EXCEPTION_IF_NULL(cnode_ptr); - + // Because there is not context in unit test, + // abstract->broaden() is replaced by abstract->set_value(kAnyValue). if (kv.size() == 0) { if (node_type == "UpdateState") { - const ValuePtr kUMonad = std::make_shared(); - auto monad_abs = kUMonad->ToAbstract(); - cnode_ptr->set_abstract(monad_abs); + cnode_ptr->set_abstract(kUMonad->ToAbstract()); } else if (node_type == "Depend") { - const ValuePtr kBool = std::make_shared(true); cnode_ptr->set_abstract(kBool->ToAbstract()); } else { AbstractBasePtrList elem; for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { auto abs = cnode_ptr->input(index)->abstract(); if (abs != nullptr) { + abs->set_value(kAnyValue); elem.push_back(abs); } } @@ -848,22 +865,56 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc } } else if (kv.size() == 1) { std::unordered_map::iterator iter = kv.begin(); - cnode_ptr->set_abstract(iter->second); + if (iter->second != nullptr) { + iter->second->set_value(kAnyValue); + cnode_ptr->set_abstract(iter->second); + } } else { auto abstract = ParserAttrShape(shape_ref_attr_name, kv); if (abstract == nullptr) { + cnode_ptr->set_abstract(nullptr); MS_LOG(ERROR) << "Node's attribute is nullptr."; + } else { + abstract->set_value(kAnyValue); + cnode_ptr->set_abstract(abstract); + } + } +} + +CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const mind_ir::NodeProto &node_proto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + if (!node_proto.has_op_type()) { + MS_LOG(ERROR) << "Get CNode op_type failed!"; + return nullptr; + } + const std::string &node_name = node_proto.output(0); + MS_LOG(DEBUG) << "Process CNode: " << node_name; + // Build inputs. + std::vector inputs; + inputs.push_back(BuildOperatorNode(node_proto)); + for (int i = 0; i < node_proto.input_size(); ++i) { + const std::string &input_name = node_proto.input(i); + if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { + MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; return nullptr; } - cnode_ptr->set_abstract(abstract); + inputs.push_back(anfnode_build_map_[input_name]); } + CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(cnode_ptr); + SetCNodeAbastract(node_proto, cnode_ptr); + + const std::string &fullname_with_scope = node_proto.domain(); string debug_info_name = ParseCNodeName(node_name); auto debug_info_ptr = std::make_shared(debug_info_name); cnode_ptr->set_debug_info(debug_info_ptr); cnode_ptr->set_fullname_with_scope(fullname_with_scope); cnode_ptr->set_load_flag(true); - + if (anfnode_build_map_.count(node_name) > 0) { + MS_LOG(EXCEPTION) << "Duplicate CNode name: " << node_name; + } anfnode_build_map_[node_name] = cnode_ptr; return cnode_ptr; } @@ -991,11 +1042,41 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) { MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; } const mind_ir::GraphProto &graphBuild = model_proto.graph(); + + // Forward declare FuncGraph name + // Compatible with the previous proto. + if (graphBuild.has_name()) { + anfnode_build_map_[graphBuild.name()] = std::make_shared(dstGraph); + } + for (int i = 0; i < model_proto.functions_size(); ++i) { + FuncGraphPtr graph = std::make_shared(); + const auto &graph_proto = model_proto.functions(i); + if (!graph_proto.has_name()) { + MS_LOG(EXCEPTION) << "The function has not a name. Please export mindIR again. "; + } + if (anfnode_build_map_.count(graph_proto.name()) > 0) { + MS_LOG(EXCEPTION) << "There is a duplication function graph name: " << graph_proto.name(); + } + anfnode_build_map_[graph_proto.name()] = std::make_shared(graph); + } + + // Parser the proto. if (!BuildFuncGraph(dstGraph, graphBuild)) { MS_LOG(ERROR) << "Build funcgraph failed!"; return nullptr; } - MS_LOG(INFO) << "Parse pb to build FuncGraph Success!"; + MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name(); + for (int i = 0; i < model_proto.functions_size(); ++i) { + const auto &graph_proto = model_proto.functions(i); + FuncGraphPtr graph = GetValueNode(anfnode_build_map_[graph_proto.name()]); + if (!BuildFuncGraph(graph, graph_proto)) { + MS_LOG(ERROR) << "Build funcgraph failed!"; + return nullptr; + } + MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graph_proto.name(); + } + // Release resource + anfnode_build_map_.clear(); return dstGraph; } } // namespace mindspore diff --git a/mindspore/core/load_mindir/anf_model_parser.h b/mindspore/core/load_mindir/anf_model_parser.h index 4d7ce1adecb..dffc78deeff 100644 --- a/mindspore/core/load_mindir/anf_model_parser.h +++ b/mindspore/core/load_mindir/anf_model_parser.h @@ -62,6 +62,8 @@ class MSANFModelParser { ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto); bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto); + AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto); + void SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr); bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto); bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor); diff --git a/mindspore/core/proto/mind_ir.proto b/mindspore/core/proto/mind_ir.proto index cd6182b9e15..8d9c9ecc434 100644 --- a/mindspore/core/proto/mind_ir.proto +++ b/mindspore/core/proto/mind_ir.proto @@ -23,6 +23,9 @@ message AttributeProto { TENSOR = 17; GRAPH = 18; TENSORS = 19; + TUPLE = 20; // tuple + LIST = 21; // list + DICT = 22; // dictionary } optional string name = 1; optional float f = 2; @@ -40,6 +43,8 @@ message AttributeProto { optional string doc_string = 14; optional string ref_attr_name = 15; optional AttributeType type = 16; + repeated AttributeProto values = 17; // tuple, list,dict of value + optional AttributeType type_val = 18; // type type info } @@ -70,6 +75,7 @@ message ModelProto { optional string model_version = 5; optional string doc_string = 6; optional GraphProto graph = 7; + repeated GraphProto functions = 8; // all the graphs without the main graph. } diff --git a/tests/st/control/test_if_mindir.py b/tests/st/control/test_if_mindir.py new file mode 100644 index 00000000000..42f8a2929a2 --- /dev/null +++ b/tests/st/control/test_if_mindir.py @@ -0,0 +1,171 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import context +from mindspore.common.tensor import Tensor +from mindspore.common.initializer import TruncatedNormal +from mindspore.common.parameter import ParameterTuple +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.train.serialization import export, load + + +def weight_variable(): + return TruncatedNormal(0.02) + + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + + +def fc_with_initialize(input_channels, out_channels): + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +class LeNet5(nn.Cell): + def __init__(self): + super(LeNet5, self).__init__() + self.batch_size = 32 + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, 10) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.reshape = P.Reshape() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.reshape(x, (self.batch_size, -1)) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + + +class WithLossCell(nn.Cell): + def __init__(self, network): + super(WithLossCell, self).__init__(auto_prefix=False) + self.loss = nn.SoftmaxCrossEntropyWithLogits() + self.network = network + + def construct(self, x, label): + predict = self.network(x) + return self.loss(predict, label) + + +class TrainOneStepCell(nn.Cell): + def __init__(self, network): + super(TrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_train() + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = nn.Momentum(self.weights, 0.1, 0.9) + self.hyper_map = C.HyperMap() + self.grad = C.GradOperation(get_by_list=True) + + def construct(self, x, label): + weights = self.weights + grads = self.grad(self.network, weights)(x, label) + return self.optimizer(grads) + + +class SingleIfNet(nn.Cell): + def construct(self, x, y): + x += 1 + if x < y: + y += x + else: + y -= x + y += 5 + return y + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_export_lenet_grad_mindir(): + context.set_context(mode=context.GRAPH_MODE) + network = LeNet5() + network.set_train() + predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.zeros([32, 10]).astype(np.float32)) + net = TrainOneStepCell(WithLossCell(network)) + export(net, predict, label, file_name="lenet_grad", file_format='MINDIR') + verify_name = "lenet_grad.mindir" + assert os.path.exists(verify_name) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_load_mindir_and_run(): + context.set_context(mode=context.GRAPH_MODE) + network = LeNet5() + network.set_train() + + inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) + outputs0 = network(inputs0) + + inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32)) + export(network, inputs, file_name="test_lenet_load", file_format='MINDIR') + mindir_name = "test_lenet_load.mindir" + assert os.path.exists(mindir_name) + + graph = load(mindir_name) + loaded_net = nn.GraphCell(graph) + outputs_after_load = loaded_net(inputs0) + assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_single_if(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True, save_graphs_path="./ifir") + network = SingleIfNet() + + x = Tensor(np.array([1]).astype(np.float32)) + y = Tensor(np.array([2]).astype(np.float32)) + origin_out = network(x, y) + + file_name = "if_net" + export(network, x, y, file_name=file_name, file_format='MINDIR') + mindir_name = file_name + ".mindir" + assert os.path.exists(mindir_name) + + graph = load(mindir_name) + loaded_net = nn.GraphCell(graph) + outputs_after_load = loaded_net(x, y) + assert origin_out == outputs_after_load