!21411 mindir: support control flow

Merge pull request !21411 from lanzhineng/mindir_control_flow
i-robot 2021-08-09 02:03:02 +00:00 committed by Gitee
commit 4d71b3bd76
7 changed files with 402 additions and 60 deletions
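For orientation before reading the diff: the user-visible flow this change enables is exporting a Cell that contains Python if/else control flow to MindIR, loading it back, and running it through nn.GraphCell. The sketch below is distilled from the new test case test_single_if at the end of this diff; the file name is illustrative.

import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import export, load

class SingleIfNet(nn.Cell):
    def construct(self, x, y):
        x += 1
        if x < y:  # compiled into Switch/Partial control-flow nodes
            y += x
        else:
            y -= x
        y += 5
        return y

context.set_context(mode=context.GRAPH_MODE)
net = SingleIfNet()
x = Tensor(np.array([1]).astype(np.float32))
y = Tensor(np.array([2]).astype(np.float32))
export(net, x, y, file_name="if_net", file_format="MINDIR")  # writes if_net.mindir
graph = load("if_net.mindir")  # parses the ModelProto, including the new functions graphs
out = nn.GraphCell(graph)(x, y)  # runs the loaded control-flow graph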


@ -121,6 +121,28 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;
inline bool ResetCNodeFromLoad(const AnfNodePtr &node) {
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
// Process partial("DeadNode",args) when the graph is loaded.
auto operatorPtr = node->cast<CNodePtr>()->input(0);
// Set abstract of switch(c,f,t) to null
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
node->set_abstract(nullptr);
return true;
}
// Set abstract of switch(c,f,t)() to null
prim = GetCNodePrimitive(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
node->set_abstract(nullptr);
return true;
}
// Previous inferred value
return true;
}
return false;
}
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
const abstract::AbstractBasePtrList &args_spec, bool clear) {
MS_LOG(DEBUG) << "AbstractAnalyze start";
@ -133,10 +155,19 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
for (auto &node : manager->all_nodes()) {
MS_EXCEPTION_IF_NULL(node);
const AbstractBasePtr &prev_inferred = node->abstract();
// Keep previous inferred value for CNode if is loaded from MindIR.
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
// AbstractFunction has context,but contexts in cache have been cleaned.
if (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>()) {
node->set_abstract(nullptr);
MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
continue;
}
// Handle previous inferred value for CNode if is loaded from MindIR
if (res->is_load() && ResetCNodeFromLoad(node)) {
continue;
}
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
node->set_abstract(nullptr);
@ -200,6 +231,7 @@ const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
if (graph->has_attr("is_load")) {
loaded_graph = graph;
loaded_graph_num += 1;
res->set_is_load(true);
}
}
if (loaded_graph_num == 0) {
@ -218,6 +250,8 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
FuncGraphPtr root_graph = *(manager->roots().begin());
auto root_inputs = root_graph->get_inputs();
auto loaded_inputs = loaded_graph->get_inputs();
MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString();
MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString();
size_t root_inputs_num = root_inputs.size();
size_t loaded_inputs_num = loaded_inputs.size();
@ -229,10 +263,18 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
auto root_input = root_inputs[index];
auto loaded_input = loaded_inputs[index];
MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1);
MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1);
MS_LOG(DEBUG) << "root_input abstract[" << index
<< "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL");
MS_LOG(DEBUG) << "loaded_input abstract [" << index
<< "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL");
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
MS_EXCEPTION_IF_NULL(root_shape);
MS_EXCEPTION_IF_NULL(loaded_shape);
MS_EXCEPTION_IF_NULL(root_type);
@ -454,6 +496,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
}
// Analyze
AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec);
// The top graph may be replaced by infer, update the top graph when the infer is done
parse::Parser::UpdateTopFuncGraph(result.context->func_graph());


@ -79,6 +79,8 @@ class Resource : public ResourceBase {
gpu_loopsink_flag_ = flag;
gpu_loopsink_size_ = size;
}
void set_is_load(bool flag) { is_load_ = flag; }
bool is_load() { return is_load_; }
bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
// Reclaim resource and clear the cache.
@ -93,6 +95,8 @@ class Resource : public ResourceBase {
py::object input_;
bool is_cleaned_;
bool gpu_loopsink_flag_{false};
// The func_graph_ is loaded from mindir
bool is_load_{false};
int64_t gpu_loopsink_size_{1};
};


@ -138,6 +138,7 @@ class IrExportBuilder {
mind_ir::NodeProto *last_node_{nullptr};
std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_;
std::set<std::string> nodeName_;
size_t node_index_{0};
size_t shape_index_{0};
};
@ -145,16 +146,7 @@ class IrExportBuilder {
using IrExporterPtr = std::shared_ptr<IrExporter>;
std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
if ((builder_ == nullptr) || (func_graph == nullptr)) {
MS_LOG(EXCEPTION) << "Input params is null.";
}
// Export model info
builder_->BuildModelInfo();
// Export model and return string
builder_->BuildModel(func_graph);
(void)GetDumpProto(func_graph);
return builder_->GetProtoString(func_graph);
}
@ -168,7 +160,6 @@ mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, boo
// Export model and return string
builder_->BuildModel(func_graph, save_tensor_data);
return builder_->Model();
}
@ -191,16 +182,34 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
graph_proto->set_bprop_hash(func_graph->bprop_hash());
ResetNodeIndex();
todo_.clear();
todo_.push_back(func_graph);
nodeName_.clear();
// Build the main funcGraph
nodeName_.insert(func_graph->ToString());
BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
std::set<FuncGraphPtr> graphVisited;
graphVisited.insert(func_graph);
while (!todo_.empty()) {
FuncGraphPtr fg = todo_.back();
todo_.pop_back();
BuildFuncGraph(fg, graph_proto, save_tensor_data);
if (graphVisited.count(fg) > 0) {
continue;
}
if (nodeName_.count(fg->ToString()) > 0) {
MS_LOG(EXCEPTION) << "There is a duplicate name: " << fg->ToString();
}
nodeName_.insert(fg->ToString());
graphVisited.insert(fg);
auto graph = model_.add_functions();
BuildFuncGraph(fg, graph, save_tensor_data);
}
// Release resource
nodeName_.clear();
}
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data) {
// Export funcGraph name.
graph_proto->set_name(func_graph->ToString());
// Export parameters
// 1. parameters should be mapped to ValueInfoProto
// 2. parameters with default value should be mapped to Initializer
@ -232,6 +241,10 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
input_proto->set_name(param_name);
SetValueInfoProto(param, input_proto);
}
if (nodeName_.count(param_name) > 0) {
MS_LOG(EXCEPTION) << "parameter name is duplicate:" << param_name;
}
nodeName_.insert(param_name);
}
}
@ -383,9 +396,13 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
} else if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
todo_.push_back(fg);
type_name = fg->ToString();
type_name = "REF::" + fg->ToString();
} else if (node->isa<CNode>() || node->isa<Parameter>()) {
type_name = node->ToString();
auto nodeName = GetUniqueNodeName(node);
type_name = "REF::" + nodeName;
if (nodeName_.count(nodeName) == 0) {
MS_LOG(EXCEPTION) << "There is not the name: " << nodeName;
}
} else {
MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
}
@ -424,6 +441,9 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
tensor_proto->set_data_type(mind_ir::TensorProto_DataType_UINT64);
tensor_proto->add_dims(1);
}
} else if (type->isa<Function>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_GRAPH);
*seq_string += type->type_name() + ",";
} else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
*seq_string += type->type_name() + ",";
} else {
@ -468,6 +488,10 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
// Build cnode
mind_ir::NodeProto *node_proto = graph_proto->add_node();
std::string output_name = GetUniqueNodeName(node);
if (nodeName_.count(output_name) > 0) {
MS_LOG(EXCEPTION) << "There is a duplicate name: " << output_name;
}
nodeName_.insert(output_name);
node_proto->add_output(output_name);
node_proto->set_name(output_name);
node_proto->set_domain(node->fullname_with_scope());
@ -475,7 +499,9 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
std::string type_name = GetOpTypeName(op);
node_proto->set_op_type(type_name);
last_node_ = node_proto;
// Maybe Tensor or Function or nullptr
SetShapeToNodeProto(node, node_proto);
(void)std::for_each(input_names.begin(), input_names.end(),
[&node_proto](const string &name) { node_proto->add_input(name); });
@ -490,13 +516,17 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
SetValueToAttributeProto(attr_value, attr_proto);
}
} else {
MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name();
}
}
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
std::string node_name = GetUniqueNodeName(node);
// FuncGraph will be added to functions and the input name is the function name.
if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
todo_.push_back(fg);
return fg->ToString();
}
if (node->isa<ValueNode>()) {
// When node input is a ValueNode, need to create a Constant Node
mind_ir::NodeProto *node_proto = graph_proto->add_node();
@ -539,7 +569,12 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
if ((node != nullptr) && (node->func_graph() != nullptr)) {
node_name = node->func_graph()->ToString() + ":";
}
node_name += node->ToString();
if (node->isa<ValueNode>()) {
// Needn't value
node_name += node->AnfNode::ToString();
} else {
node_name += node->ToString();
}
MS_LOG(DEBUG) << "GetNodeName: " << node_name;
return node_name;
}


@ -635,14 +635,12 @@ bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_
const mind_ir::AttributeProto &attr_proto) {
const std::string &ref_attr_name = attr_proto.ref_attr_name();
if (ref_attr_name.find("UMonad") != std::string::npos) {
const ValuePtr kUMonad = std::make_shared<UMonad>();
auto monad_abs = kUMonad->ToAbstract();
auto new_value_node = NewValueNode(kUMonad);
MS_EXCEPTION_IF_NULL(new_value_node);
new_value_node->set_abstract(monad_abs);
anfnode_build_map_[value_node_name] = new_value_node;
} else if (ref_attr_name.find("IOMonad") != std::string::npos) {
const ValuePtr kIOMonad = std::make_shared<IOMonad>();
auto monad_abs = kIOMonad->ToAbstract();
auto new_value_node = NewValueNode(kIOMonad);
MS_EXCEPTION_IF_NULL(new_value_node);
@ -768,17 +766,22 @@ std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::Get
return kv;
}
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::NodeProto &node_proto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
if (!node_proto.has_op_type()) {
MS_LOG(ERROR) << "Get CNode op_type failed!";
return nullptr;
}
const std::string &node_name = node_proto.output(0);
const std::string &fullname_with_scope = node_proto.domain();
AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_proto) {
const std::string kOperatorTypeFlag = std::string("REF::");
const size_t kOpTypeFlagSize = kOperatorTypeFlag.length();
const std::string &node_type = node_proto.op_type();
MS_LOG(DEBUG) << "Process Operator :" << node_type;
// Operator maybe CNode,FuncGraph or Parameter.
if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
auto it = anfnode_build_map_.find(node_type.substr(kOpTypeFlagSize));
if (it != anfnode_build_map_.end()) {
return it->second;
}
MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
}
// Operator is primitive.
std::shared_ptr<Primitive> prim;
auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
@ -794,51 +797,65 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
}
}
MS_EXCEPTION_IF_NULL(prim);
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
// CNode abstract
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
continue;
}
if (!GetAttrValueForCNode(prim, attr_proto)) {
MS_LOG(EXCEPTION) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString();
}
}
prim->set_attr("is_load", MakeValue(true));
return std::make_shared<ValueNode>(prim);
}
// Set CNode abstract.
void MSANFModelParser::SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) {
const std::string &node_type = node_proto.op_type();
// Handle control flow operator.
auto operatorPtr = cnode_ptr->input(0);
// Set abstract of switch(c,f,t),switchLayer(c,tup) and
// partial(func,args) to null
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
cnode_ptr->set_abstract(nullptr);
return;
}
// Set abstract of switch(c,f,t)() to null
prim = GetCNodePrimitive(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
cnode_ptr->set_abstract(nullptr);
return;
}
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
string shape_ref_attr_name;
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
shape_ref_attr_name = attr_proto.ref_attr_name();
kv = GetAbstractForCNode(attr_proto);
continue;
break;
}
if (!GetAttrValueForCNode(prim, attr_proto)) {
MS_LOG(ERROR) << "Get CNode attr failed!";
return nullptr;
}
}
std::vector<AnfNodePtr> inputs;
inputs.clear();
// Because there is not context in unit test,
// abstract->broaden() is replaced by abstract->set_value(kAnyValue).
for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
return nullptr;
}
inputs.push_back(anfnode_build_map_[input_name]);
}
prim->set_attr("is_load", MakeValue(true));
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
if (kv.size() == 0) {
if (node_type == "UpdateState") {
const ValuePtr kUMonad = std::make_shared<UMonad>();
auto monad_abs = kUMonad->ToAbstract();
cnode_ptr->set_abstract(monad_abs);
cnode_ptr->set_abstract(kUMonad->ToAbstract());
} else if (node_type == "Depend") {
const ValuePtr kBool = std::make_shared<BoolImm>(true);
cnode_ptr->set_abstract(kBool->ToAbstract());
} else {
AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
auto abs = cnode_ptr->input(index)->abstract();
if (abs != nullptr) {
abs->set_value(kAnyValue);
elem.push_back(abs);
}
}
@ -848,22 +865,56 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
}
} else if (kv.size() == 1) {
std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
cnode_ptr->set_abstract(iter->second);
if (iter->second != nullptr) {
iter->second->set_value(kAnyValue);
cnode_ptr->set_abstract(iter->second);
}
} else {
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
if (abstract == nullptr) {
cnode_ptr->set_abstract(nullptr);
MS_LOG(ERROR) << "Node's attribute is nullptr.";
} else {
abstract->set_value(kAnyValue);
cnode_ptr->set_abstract(abstract);
}
}
}
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::NodeProto &node_proto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
if (!node_proto.has_op_type()) {
MS_LOG(ERROR) << "Get CNode op_type failed!";
return nullptr;
}
const std::string &node_name = node_proto.output(0);
MS_LOG(DEBUG) << "Process CNode: " << node_name;
// Build inputs.
std::vector<AnfNodePtr> inputs;
inputs.push_back(BuildOperatorNode(node_proto));
for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
return nullptr;
}
cnode_ptr->set_abstract(abstract);
inputs.push_back(anfnode_build_map_[input_name]);
} }
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
SetCNodeAbastract(node_proto, cnode_ptr);
const std::string &fullname_with_scope = node_proto.domain();
string debug_info_name = ParseCNodeName(node_name);
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
cnode_ptr->set_debug_info(debug_info_ptr);
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
cnode_ptr->set_load_flag(true);
if (anfnode_build_map_.count(node_name) > 0) {
MS_LOG(EXCEPTION) << "Duplicate CNode name: " << node_name;
}
anfnode_build_map_[node_name] = cnode_ptr;
return cnode_ptr;
}
@ -991,11 +1042,41 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
}
const mind_ir::GraphProto &graphBuild = model_proto.graph();
// Forward declare FuncGraph name
// Compatible with the previous proto.
if (graphBuild.has_name()) {
anfnode_build_map_[graphBuild.name()] = std::make_shared<ValueNode>(dstGraph);
}
for (int i = 0; i < model_proto.functions_size(); ++i) {
FuncGraphPtr graph = std::make_shared<FuncGraph>();
const auto &graph_proto = model_proto.functions(i);
if (!graph_proto.has_name()) {
MS_LOG(EXCEPTION) << "The function has not a name. Please export mindIR again. ";
}
if (anfnode_build_map_.count(graph_proto.name()) > 0) {
MS_LOG(EXCEPTION) << "There is a duplication function graph name: " << graph_proto.name();
}
anfnode_build_map_[graph_proto.name()] = std::make_shared<ValueNode>(graph);
}
// Parser the proto.
if (!BuildFuncGraph(dstGraph, graphBuild)) {
MS_LOG(ERROR) << "Build funcgraph failed!";
return nullptr;
}
MS_LOG(INFO) << "Parse pb to build FuncGraph Success!";
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name();
for (int i = 0; i < model_proto.functions_size(); ++i) {
const auto &graph_proto = model_proto.functions(i);
FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[graph_proto.name()]);
if (!BuildFuncGraph(graph, graph_proto)) {
MS_LOG(ERROR) << "Build funcgraph failed!";
return nullptr;
}
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graph_proto.name();
}
// Release resource
anfnode_build_map_.clear();
return dstGraph;
}
} // namespace mindspore


@ -62,6 +62,8 @@ class MSANFModelParser {
ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
void SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto);
bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor);


@ -23,6 +23,9 @@ message AttributeProto {
TENSOR = 17;
GRAPH = 18;
TENSORS = 19;
TUPLE = 20; // tuple
LIST = 21; // list
DICT = 22; // dictionary
}
optional string name = 1;
optional float f = 2;
@ -40,6 +43,8 @@ message AttributeProto {
optional string doc_string = 14;
optional string ref_attr_name = 15;
optional AttributeType type = 16;
repeated AttributeProto values = 17; // tuple, list,dict of value
optional AttributeType type_val = 18; // type type info
}
@ -70,6 +75,7 @@ message ModelProto {
optional string model_version = 5;
optional string doc_string = 6;
optional GraphProto graph = 7;
repeated GraphProto functions = 8; // all the graphs without the main graph.
}
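The new repeated "functions" field and the "REF::" op_type convention introduced above can be inspected directly from an exported file. The sketch below is illustrative only: it assumes the generated protobuf bindings for mind_ir.proto are importable as mind_ir_pb2 (the real module name and path depend on how MindSpore is built and packaged) and that a .mindir file holds a serialized ModelProto.

import mind_ir_pb2  # assumption: bindings generated from mind_ir.proto; the real import path may differ

model = mind_ir_pb2.ModelProto()
with open("if_net.mindir", "rb") as f:
    model.ParseFromString(f.read())
# The entry graph stays in `graph`; every other FuncGraph reached through
# control flow is exported into the new `functions` list.
print("main graph:", model.graph.name)
print("functions:", [g.name for g in model.functions])
# Operators that reference another graph, CNode, or Parameter are encoded as
# op_type "REF::<unique name>" and resolved by name when the model is parsed.
for g in [model.graph] + list(model.functions):
    for node in g.node:
        if node.op_type.startswith("REF::"):
            print(g.name, node.name, "->", node.op_type)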


@ -0,0 +1,171 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import TruncatedNormal
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.train.serialization import export, load
def weight_variable():
return TruncatedNormal(0.02)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class WithLossCell(nn.Cell):
def __init__(self, network):
super(WithLossCell, self).__init__(auto_prefix=False)
self.loss = nn.SoftmaxCrossEntropyWithLogits()
self.network = network
def construct(self, x, label):
predict = self.network(x)
return self.loss(predict, label)
class TrainOneStepCell(nn.Cell):
def __init__(self, network):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_train()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation(get_by_list=True)
def construct(self, x, label):
weights = self.weights
grads = self.grad(self.network, weights)(x, label)
return self.optimizer(grads)
class SingleIfNet(nn.Cell):
def construct(self, x, y):
x += 1
if x < y:
y += x
else:
y -= x
y += 5
return y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_export_lenet_grad_mindir():
context.set_context(mode=context.GRAPH_MODE)
network = LeNet5()
network.set_train()
predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([32, 10]).astype(np.float32))
net = TrainOneStepCell(WithLossCell(network))
export(net, predict, label, file_name="lenet_grad", file_format='MINDIR')
verify_name = "lenet_grad.mindir"
assert os.path.exists(verify_name)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_load_mindir_and_run():
context.set_context(mode=context.GRAPH_MODE)
network = LeNet5()
network.set_train()
inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
outputs0 = network(inputs0)
inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
export(network, inputs, file_name="test_lenet_load", file_format='MINDIR')
mindir_name = "test_lenet_load.mindir"
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(inputs0)
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_if():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True, save_graphs_path="./ifir")
network = SingleIfNet()
x = Tensor(np.array([1]).astype(np.float32))
y = Tensor(np.array([2]).astype(np.float32))
origin_out = network(x, y)
file_name = "if_net"
export(network, x, y, file_name=file_name, file_format='MINDIR')
mindir_name = file_name + ".mindir"
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(x, y)
assert origin_out == outputs_after_load