mindIR support control flow

This commit is contained in:
lanzhineng 2021-08-08 16:04:21 +08:00
parent 5e1d8c1a87
commit dc63dea103
7 changed files with 402 additions and 60 deletions

View File

@ -121,6 +121,28 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;
inline bool ResetCNodeFromLoad(const AnfNodePtr &node) {
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
// Process partial("DeadNode",args) when the graph is loaded.
auto operatorPtr = node->cast<CNodePtr>()->input(0);
// Set abstract of switch(c,f,t) to null
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
node->set_abstract(nullptr);
return true;
}
// Set abstract of switch(c,f,t)() to null
prim = GetCNodePrimitive(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
node->set_abstract(nullptr);
return true;
}
// Previous inferred value
return true;
}
return false;
}
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
const abstract::AbstractBasePtrList &args_spec, bool clear) {
MS_LOG(DEBUG) << "AbstractAnalyze start";
@ -133,10 +155,19 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
for (auto &node : manager->all_nodes()) {
MS_EXCEPTION_IF_NULL(node);
const AbstractBasePtr &prev_inferred = node->abstract();
// Keep previous inferred value for CNode if is loaded from MindIR.
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
// AbstractFunction has context,but contexts in cache have been cleaned.
if (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>()) {
node->set_abstract(nullptr);
MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
continue;
}
// Handle previous inferred value for CNode if is loaded from MindIR
if (res->is_load() && ResetCNodeFromLoad(node)) {
continue;
}
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
node->set_abstract(nullptr);
@ -200,6 +231,7 @@ const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
if (graph->has_attr("is_load")) {
loaded_graph = graph;
loaded_graph_num += 1;
res->set_is_load(true);
}
}
if (loaded_graph_num == 0) {
@ -218,6 +250,8 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
FuncGraphPtr root_graph = *(manager->roots().begin());
auto root_inputs = root_graph->get_inputs();
auto loaded_inputs = loaded_graph->get_inputs();
MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString();
MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString();
size_t root_inputs_num = root_inputs.size();
size_t loaded_inputs_num = loaded_inputs.size();
@ -229,10 +263,18 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
auto root_input = root_inputs[index];
auto loaded_input = loaded_inputs[index];
MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1);
MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1);
MS_LOG(DEBUG) << "root_input abstract[" << index
<< "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL");
MS_LOG(DEBUG) << "loaded_input abstract [" << index
<< "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL");
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
MS_EXCEPTION_IF_NULL(root_shape);
MS_EXCEPTION_IF_NULL(loaded_shape);
MS_EXCEPTION_IF_NULL(root_type);
@ -454,6 +496,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
}
// Analyze
AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec);
// The top graph may be replaced by infer, update the top graph when the infer is done
parse::Parser::UpdateTopFuncGraph(result.context->func_graph());

View File

@ -79,6 +79,8 @@ class Resource : public ResourceBase {
gpu_loopsink_flag_ = flag;
gpu_loopsink_size_ = size;
}
void set_is_load(bool flag) { is_load_ = flag; }
bool is_load() { return is_load_; }
bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
// Reclaim resource and clear the cache.
@ -93,6 +95,8 @@ class Resource : public ResourceBase {
py::object input_;
bool is_cleaned_;
bool gpu_loopsink_flag_{false};
// The func_graph_ is loaded from mindir
bool is_load_{false};
int64_t gpu_loopsink_size_{1};
};

View File

@ -138,6 +138,7 @@ class IrExportBuilder {
mind_ir::NodeProto *last_node_{nullptr};
std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_;
std::set<std::string> nodeName_;
size_t node_index_{0};
size_t shape_index_{0};
};
@ -145,16 +146,7 @@ class IrExportBuilder {
using IrExporterPtr = std::shared_ptr<IrExporter>;
std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
if ((builder_ == nullptr) || (func_graph == nullptr)) {
MS_LOG(EXCEPTION) << "Input params is null.";
}
// Export model info
builder_->BuildModelInfo();
// Export model and return string
builder_->BuildModel(func_graph);
(void)GetDumpProto(func_graph);
return builder_->GetProtoString(func_graph);
}
@ -168,7 +160,6 @@ mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, boo
// Export model and return string
builder_->BuildModel(func_graph, save_tensor_data);
return builder_->Model();
}
@ -191,16 +182,34 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
graph_proto->set_bprop_hash(func_graph->bprop_hash());
ResetNodeIndex();
todo_.clear();
todo_.push_back(func_graph);
nodeName_.clear();
// Build the main funcGraph
nodeName_.insert(func_graph->ToString());
BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
std::set<FuncGraphPtr> graphVisited;
graphVisited.insert(func_graph);
while (!todo_.empty()) {
FuncGraphPtr fg = todo_.back();
todo_.pop_back();
BuildFuncGraph(fg, graph_proto, save_tensor_data);
if (graphVisited.count(fg) > 0) {
continue;
}
if (nodeName_.count(fg->ToString()) > 0) {
MS_LOG(EXCEPTION) << "There is a duplicate name: " << fg->ToString();
}
nodeName_.insert(fg->ToString());
graphVisited.insert(fg);
auto graph = model_.add_functions();
BuildFuncGraph(fg, graph, save_tensor_data);
}
// Release resource
nodeName_.clear();
}
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
bool save_tensor_data) {
// Export funcGraph name.
graph_proto->set_name(func_graph->ToString());
// Export parameters
// 1. parameters should be mapped to ValueInfoProto
// 2. parameters with default value should be mapped to Initializer
@ -232,6 +241,10 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
input_proto->set_name(param_name);
SetValueInfoProto(param, input_proto);
}
if (nodeName_.count(param_name) > 0) {
MS_LOG(EXCEPTION) << "parameter name is duplicate:" << param_name;
}
nodeName_.insert(param_name);
}
}
@ -383,9 +396,13 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
} else if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
todo_.push_back(fg);
type_name = fg->ToString();
type_name = "REF::" + fg->ToString();
} else if (node->isa<CNode>() || node->isa<Parameter>()) {
type_name = node->ToString();
auto nodeName = GetUniqueNodeName(node);
type_name = "REF::" + nodeName;
if (nodeName_.count(nodeName) == 0) {
MS_LOG(EXCEPTION) << "There is not the name: " << nodeName;
}
} else {
MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
}
@ -424,6 +441,9 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
tensor_proto->set_data_type(mind_ir::TensorProto_DataType_UINT64);
tensor_proto->add_dims(1);
}
} else if (type->isa<Function>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_GRAPH);
*seq_string += type->type_name() + ",";
} else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
*seq_string += type->type_name() + ",";
} else {
@ -468,6 +488,10 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
// Build cnode
mind_ir::NodeProto *node_proto = graph_proto->add_node();
std::string output_name = GetUniqueNodeName(node);
if (nodeName_.count(output_name) > 0) {
MS_LOG(EXCEPTION) << "There is a duplicate name: " << output_name;
}
nodeName_.insert(output_name);
node_proto->add_output(output_name);
node_proto->set_name(output_name);
node_proto->set_domain(node->fullname_with_scope());
@ -475,7 +499,9 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
std::string type_name = GetOpTypeName(op);
node_proto->set_op_type(type_name);
last_node_ = node_proto;
// Maybe Tensor or Function or nullptr
SetShapeToNodeProto(node, node_proto);
(void)std::for_each(input_names.begin(), input_names.end(),
[&node_proto](const string &name) { node_proto->add_input(name); });
@ -490,13 +516,17 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
SetValueToAttributeProto(attr_value, attr_proto);
}
} else {
MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name();
}
}
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
std::string node_name = GetUniqueNodeName(node);
// FuncGraph will be added to functions and the input name is the function name.
if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
todo_.push_back(fg);
return fg->ToString();
}
if (node->isa<ValueNode>()) {
// When node input is a ValueNode, need to create a Constant Node
mind_ir::NodeProto *node_proto = graph_proto->add_node();
@ -539,7 +569,12 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
if ((node != nullptr) && (node->func_graph() != nullptr)) {
node_name = node->func_graph()->ToString() + ":";
}
node_name += node->ToString();
if (node->isa<ValueNode>()) {
// Needn't value
node_name += node->AnfNode::ToString();
} else {
node_name += node->ToString();
}
MS_LOG(DEBUG) << "GetNodeName: " << node_name;
return node_name;
}

View File

@ -635,14 +635,12 @@ bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_
const mind_ir::AttributeProto &attr_proto) {
const std::string &ref_attr_name = attr_proto.ref_attr_name();
if (ref_attr_name.find("UMonad") != std::string::npos) {
const ValuePtr kUMonad = std::make_shared<UMonad>();
auto monad_abs = kUMonad->ToAbstract();
auto new_value_node = NewValueNode(kUMonad);
MS_EXCEPTION_IF_NULL(new_value_node);
new_value_node->set_abstract(monad_abs);
anfnode_build_map_[value_node_name] = new_value_node;
} else if (ref_attr_name.find("IOMonad") != std::string::npos) {
const ValuePtr kIOMonad = std::make_shared<IOMonad>();
auto monad_abs = kIOMonad->ToAbstract();
auto new_value_node = NewValueNode(kIOMonad);
MS_EXCEPTION_IF_NULL(new_value_node);
@ -768,17 +766,22 @@ std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::Get
return kv;
}
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::NodeProto &node_proto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
if (!node_proto.has_op_type()) {
MS_LOG(ERROR) << "Get CNode op_type failed!";
return nullptr;
}
const std::string &node_name = node_proto.output(0);
const std::string &fullname_with_scope = node_proto.domain();
AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_proto) {
const std::string kOperatorTypeFlag = std::string("REF::");
const size_t kOpTypeFlagSize = kOperatorTypeFlag.length();
const std::string &node_type = node_proto.op_type();
MS_LOG(DEBUG) << "Process Operator :" << node_type;
// Operator maybe CNode,FuncGraph or Parameter.
if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
auto it = anfnode_build_map_.find(node_type.substr(kOpTypeFlagSize));
if (it != anfnode_build_map_.end()) {
return it->second;
}
MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
}
// Operator is primitive.
std::shared_ptr<Primitive> prim;
auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
@ -794,51 +797,65 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
}
}
MS_EXCEPTION_IF_NULL(prim);
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
// CNode abstract
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
continue;
}
if (!GetAttrValueForCNode(prim, attr_proto)) {
MS_LOG(EXCEPTION) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString();
}
}
prim->set_attr("is_load", MakeValue(true));
return std::make_shared<ValueNode>(prim);
}
// Set CNode abstract.
void MSANFModelParser::SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) {
const std::string &node_type = node_proto.op_type();
// Handle control flow operator.
auto operatorPtr = cnode_ptr->input(0);
// Set abstract of switch(c,f,t),switchLayer(c,tup) and
// partial(func,args) to null
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
cnode_ptr->set_abstract(nullptr);
return;
}
// Set abstract of switch(c,f,t)() to null
prim = GetCNodePrimitive(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
cnode_ptr->set_abstract(nullptr);
return;
}
std::unordered_map<std::string, abstract::AbstractBasePtr> kv;
string shape_ref_attr_name;
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
shape_ref_attr_name = attr_proto.ref_attr_name();
kv = GetAbstractForCNode(attr_proto);
continue;
}
if (!GetAttrValueForCNode(prim, attr_proto)) {
MS_LOG(ERROR) << "Get CNode attr failed!";
return nullptr;
break;
}
}
std::vector<AnfNodePtr> inputs;
inputs.clear();
for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
return nullptr;
}
inputs.push_back(anfnode_build_map_[input_name]);
}
prim->set_attr("is_load", MakeValue(true));
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
// Because there is not context in unit test,
// abstract->broaden() is replaced by abstract->set_value(kAnyValue).
if (kv.size() == 0) {
if (node_type == "UpdateState") {
const ValuePtr kUMonad = std::make_shared<UMonad>();
auto monad_abs = kUMonad->ToAbstract();
cnode_ptr->set_abstract(monad_abs);
cnode_ptr->set_abstract(kUMonad->ToAbstract());
} else if (node_type == "Depend") {
const ValuePtr kBool = std::make_shared<BoolImm>(true);
cnode_ptr->set_abstract(kBool->ToAbstract());
} else {
AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
auto abs = cnode_ptr->input(index)->abstract();
if (abs != nullptr) {
abs->set_value(kAnyValue);
elem.push_back(abs);
}
}
@ -848,22 +865,56 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
}
} else if (kv.size() == 1) {
std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
cnode_ptr->set_abstract(iter->second);
if (iter->second != nullptr) {
iter->second->set_value(kAnyValue);
cnode_ptr->set_abstract(iter->second);
}
} else {
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
if (abstract == nullptr) {
cnode_ptr->set_abstract(nullptr);
MS_LOG(ERROR) << "Node's attribute is nullptr.";
} else {
abstract->set_value(kAnyValue);
cnode_ptr->set_abstract(abstract);
}
}
}
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::NodeProto &node_proto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
if (!node_proto.has_op_type()) {
MS_LOG(ERROR) << "Get CNode op_type failed!";
return nullptr;
}
const std::string &node_name = node_proto.output(0);
MS_LOG(DEBUG) << "Process CNode: " << node_name;
// Build inputs.
std::vector<AnfNodePtr> inputs;
inputs.push_back(BuildOperatorNode(node_proto));
for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
return nullptr;
}
cnode_ptr->set_abstract(abstract);
inputs.push_back(anfnode_build_map_[input_name]);
}
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
SetCNodeAbastract(node_proto, cnode_ptr);
const std::string &fullname_with_scope = node_proto.domain();
string debug_info_name = ParseCNodeName(node_name);
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
cnode_ptr->set_debug_info(debug_info_ptr);
cnode_ptr->set_fullname_with_scope(fullname_with_scope);
cnode_ptr->set_load_flag(true);
if (anfnode_build_map_.count(node_name) > 0) {
MS_LOG(EXCEPTION) << "Duplicate CNode name: " << node_name;
}
anfnode_build_map_[node_name] = cnode_ptr;
return cnode_ptr;
}
@ -991,11 +1042,41 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
}
const mind_ir::GraphProto &graphBuild = model_proto.graph();
// Forward declare FuncGraph name
// Compatible with the previous proto.
if (graphBuild.has_name()) {
anfnode_build_map_[graphBuild.name()] = std::make_shared<ValueNode>(dstGraph);
}
for (int i = 0; i < model_proto.functions_size(); ++i) {
FuncGraphPtr graph = std::make_shared<FuncGraph>();
const auto &graph_proto = model_proto.functions(i);
if (!graph_proto.has_name()) {
MS_LOG(EXCEPTION) << "The function has not a name. Please export mindIR again. ";
}
if (anfnode_build_map_.count(graph_proto.name()) > 0) {
MS_LOG(EXCEPTION) << "There is a duplication function graph name: " << graph_proto.name();
}
anfnode_build_map_[graph_proto.name()] = std::make_shared<ValueNode>(graph);
}
// Parser the proto.
if (!BuildFuncGraph(dstGraph, graphBuild)) {
MS_LOG(ERROR) << "Build funcgraph failed!";
return nullptr;
}
MS_LOG(INFO) << "Parse pb to build FuncGraph Success!";
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name();
for (int i = 0; i < model_proto.functions_size(); ++i) {
const auto &graph_proto = model_proto.functions(i);
FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[graph_proto.name()]);
if (!BuildFuncGraph(graph, graph_proto)) {
MS_LOG(ERROR) << "Build funcgraph failed!";
return nullptr;
}
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graph_proto.name();
}
// Release resource
anfnode_build_map_.clear();
return dstGraph;
}
} // namespace mindspore

View File

@ -62,6 +62,8 @@ class MSANFModelParser {
ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
void SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto);
bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor);

View File

@ -23,6 +23,9 @@ message AttributeProto {
TENSOR = 17;
GRAPH = 18;
TENSORS = 19;
TUPLE = 20; // tuple
LIST = 21; // list
DICT = 22; // dictionary
}
optional string name = 1;
optional float f = 2;
@ -40,6 +43,8 @@ message AttributeProto {
optional string doc_string = 14;
optional string ref_attr_name = 15;
optional AttributeType type = 16;
repeated AttributeProto values = 17; // tuple, list,dict of value
optional AttributeType type_val = 18; // type type info
}
@ -70,6 +75,7 @@ message ModelProto {
optional string model_version = 5;
optional string doc_string = 6;
optional GraphProto graph = 7;
repeated GraphProto functions = 8; // all the graphs without the main graph.
}

View File

@ -0,0 +1,171 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import TruncatedNormal
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.train.serialization import export, load
def weight_variable():
return TruncatedNormal(0.02)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class WithLossCell(nn.Cell):
def __init__(self, network):
super(WithLossCell, self).__init__(auto_prefix=False)
self.loss = nn.SoftmaxCrossEntropyWithLogits()
self.network = network
def construct(self, x, label):
predict = self.network(x)
return self.loss(predict, label)
class TrainOneStepCell(nn.Cell):
def __init__(self, network):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_train()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation(get_by_list=True)
def construct(self, x, label):
weights = self.weights
grads = self.grad(self.network, weights)(x, label)
return self.optimizer(grads)
class SingleIfNet(nn.Cell):
def construct(self, x, y):
x += 1
if x < y:
y += x
else:
y -= x
y += 5
return y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_export_lenet_grad_mindir():
context.set_context(mode=context.GRAPH_MODE)
network = LeNet5()
network.set_train()
predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([32, 10]).astype(np.float32))
net = TrainOneStepCell(WithLossCell(network))
export(net, predict, label, file_name="lenet_grad", file_format='MINDIR')
verify_name = "lenet_grad.mindir"
assert os.path.exists(verify_name)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_load_mindir_and_run():
context.set_context(mode=context.GRAPH_MODE)
network = LeNet5()
network.set_train()
inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
outputs0 = network(inputs0)
inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
export(network, inputs, file_name="test_lenet_load", file_format='MINDIR')
mindir_name = "test_lenet_load.mindir"
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(inputs0)
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_if():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True, save_graphs_path="./ifir")
network = SingleIfNet()
x = Tensor(np.array([1]).astype(np.float32))
y = Tensor(np.array([2]).astype(np.float32))
origin_out = network(x, y)
file_name = "if_net"
export(network, x, y, file_name=file_name, file_format='MINDIR')
mindir_name = file_name + ".mindir"
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(x, y)
assert origin_out == outputs_after_load