!21803 mindir: support control flow switch layer

Merge pull request !21803 from lanzhineng/mindir_control_flow
This commit is contained in:
i-robot 2021-08-16 02:15:16 +00:00 committed by Gitee
commit 223f500bab
5 changed files with 144 additions and 59 deletions

View File

@ -121,30 +121,6 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult; using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr; using mindspore::abstract::AnalysisContextPtr;
// Some operators are not defined.
inline bool ResetCNodeFromLoad(const AnfNodePtr &node) {
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
// Process partial("DeadNode",args) when the graph is loaded.
auto operatorPtr = node->cast<CNodePtr>()->input(0);
// Set abstract of switch(c,f,t) to null
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
node->set_abstract(nullptr);
return true;
}
// If the operator is not a primitive, the abstract will been set to null.
// Because there are not some operators in front end, the abstract of primitive should be reserved.
if (prim == nullptr) {
node->set_abstract(nullptr);
return true;
}
// Previous inferred value
return true;
}
return false;
}
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
const abstract::AbstractBasePtrList &args_spec, bool clear) { const abstract::AbstractBasePtrList &args_spec, bool clear) {
MS_LOG(DEBUG) << "AbstractAnalyze start"; MS_LOG(DEBUG) << "AbstractAnalyze start";
@ -156,17 +132,22 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
engine->Clear(); engine->Clear();
for (auto &node : manager->all_nodes()) { for (auto &node : manager->all_nodes()) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
const AbstractBasePtr &prev_inferred = node->abstract();
// Handle previous inferred value for CNode if is loaded from MindIR // Handle previous inferred value for CNode if is loaded from MindIR
if (res->is_load() && ResetCNodeFromLoad(node)) { if (res->is_load()) {
// If the primitive is not defined in front end,keep the inferred value loaded from MindIR.
auto primitive = GetCNodePrimitive(node);
if (primitive != nullptr && abstract::GetPrimEvaluator(primitive, engine) == nullptr) {
MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
continue; continue;
} }
}
const AbstractBasePtr &prev_inferred = node->abstract();
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction. // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) { if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
node->set_abstract(nullptr); node->set_abstract(nullptr);
MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr"; MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
} }
} }
} }
@ -275,7 +256,9 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
MS_EXCEPTION_IF_NULL(root_type); MS_EXCEPTION_IF_NULL(root_type);
MS_EXCEPTION_IF_NULL(loaded_type); MS_EXCEPTION_IF_NULL(loaded_type);
if (root_shape->shape() != loaded_shape->shape()) { auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
(root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
if (!shapeEqu) {
MS_EXCEPTION(ValueError) << "The " << index MS_EXCEPTION(ValueError) << "The " << index
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString() << " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
<< ", input shape of loaded graph: " << loaded_shape->ToString(); << ", input shape of loaded graph: " << loaded_shape->ToString();
@ -531,8 +514,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
auto func_graph = res->func_graph(); auto func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
func_graph->DumpFuncGraph(fg_name); func_graph->DumpFuncGraph(fg_name);
ExportIR(fg_name + ".dat", func_graph);
DumpIR(fg_name + ".ir", func_graph); DumpIR(fg_name + ".ir", func_graph);
ExportIR(fg_name + ".dat", func_graph);
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
} }
counter++; counter++;

View File

@ -359,7 +359,6 @@ void AnalysisEngine::Clear() {
root_context_ = nullptr; root_context_ = nullptr;
} }
namespace {
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) { EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
// Custom Primitive with python infer_shape, infer_type // Custom Primitive with python infer_shape, infer_type
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
@ -396,7 +395,8 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
engine->prim_py_evaluators_[prim_py] = evaluator; engine->prim_py_evaluators_[prim_py] = evaluator;
return evaluator; return evaluator;
} }
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
return nullptr;
} }
// return a default evaluator // return a default evaluator
@ -416,11 +416,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
} }
} }
if (evaluator == nullptr) { if (evaluator == nullptr) {
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ")."; MS_LOG(DEBUG) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
} }
return evaluator; return evaluator;
} }
} // namespace
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) { EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
MS_EXCEPTION_IF_NULL(func); MS_EXCEPTION_IF_NULL(func);
@ -430,6 +429,9 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbs
} }
auto primitive = func->prim(); auto primitive = func->prim();
auto evaluator = GetPrimEvaluator(primitive, shared_from_this()); auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
if (evaluator == nullptr) {
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
}
evaluators_[func] = evaluator; evaluators_[func] = evaluator;
return evaluator; return evaluator;
} }
@ -1012,7 +1014,9 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
auto evaluator = GetPrimEvaluator(primitive, nullptr); auto evaluator = GetPrimEvaluator(primitive, nullptr);
MS_EXCEPTION_IF_NULL(evaluator); if (evaluator == nullptr) {
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
}
if (!evaluator->isa<TrivialPrimEvaluator>()) { if (!evaluator->isa<TrivialPrimEvaluator>()) {
MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but " MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but "
<< evaluator->ToString(); << evaluator->ToString();

View File

@ -347,7 +347,7 @@ template <typename T>
AbstractBasePtr FromValue(const T &value, bool broaden = false) { AbstractBasePtr FromValue(const T &value, bool broaden = false) {
return FromValueInside(MakeValue(value), broaden); return FromValueInside(MakeValue(value), broaden);
} }
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine);
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -137,10 +137,11 @@ class IrExportBuilder {
mind_ir::ModelProto model_; mind_ir::ModelProto model_;
mind_ir::NodeProto *last_node_{nullptr}; mind_ir::NodeProto *last_node_{nullptr};
std::list<FuncGraphPtr> todo_; std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_; std::map<AnfNodePtr, std::string> node_index_map_;
std::set<std::string> nodeName_; std::set<std::string> nodeName_;
size_t node_index_{0}; size_t node_index_{0};
size_t shape_index_{0}; size_t shape_index_{0};
bool top_graph{true};
}; };
using IrExporterPtr = std::shared_ptr<IrExporter>; using IrExporterPtr = std::shared_ptr<IrExporter>;
@ -185,9 +186,11 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
nodeName_.clear(); nodeName_.clear();
// Build the main funcGraph // Build the main funcGraph
nodeName_.insert(func_graph->ToString()); nodeName_.insert(func_graph->ToString());
top_graph = true;
BuildFuncGraph(func_graph, graph_proto, save_tensor_data); BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
std::set<FuncGraphPtr> graphVisited; std::set<FuncGraphPtr> graphVisited;
graphVisited.insert(func_graph); graphVisited.insert(func_graph);
top_graph = false;
while (!todo_.empty()) { while (!todo_.empty()) {
FuncGraphPtr fg = todo_.back(); FuncGraphPtr fg = todo_.back();
todo_.pop_back(); todo_.pop_back();
@ -204,6 +207,7 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
} }
// Release resource // Release resource
nodeName_.clear(); nodeName_.clear();
node_index_map_.clear();
} }
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
@ -227,8 +231,8 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
} }
std::string param_name = GetUniqueNodeName(param); std::string param_name = GetUniqueNodeName(param);
if (param->has_default()) { if (top_graph && param->has_default()) {
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has default."; MS_LOG(DEBUG) << "Parameter: '" << item->DebugString() << "' has default. address: " << (size_t)param.get();
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter(); mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
parameter_proto->set_name(param_name); parameter_proto->set_name(param_name);
SetParamToTensorProto(param, parameter_proto); SetParamToTensorProto(param, parameter_proto);
@ -308,7 +312,7 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
} else if (type->isa<Tuple>()) { } else if (type->isa<Tuple>()) {
auto tup_shape = shape->cast<abstract::TupleShapePtr>(); auto tup_shape = shape->cast<abstract::TupleShapePtr>();
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
} else if (type->isa<Number>() || type->isa<String>()) { } else if (type->isa<Number>() || type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
value_proto->set_denotation(type->type_name()); value_proto->set_denotation(type->type_name());
} else { } else {
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
@ -541,28 +545,19 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
// Naming anfnode // Naming anfnode
// 1. parameter is unique in one func_graph // 1. parameter is unique in one func_graph
// 2. cnode and valuenode may be reduplicative, so add index to identify. // 2. cnode and valuenode may be reduplicative, so add index to identify.
std::string node_name = "";
if (node->isa<Parameter>()) {
node_name = GetNodeName(node);
} else if (node->isa<CNode>()) {
auto iter = node_index_map_.find(node); auto iter = node_index_map_.find(node);
if (iter != node_index_map_.end()) { if (iter != node_index_map_.end()) {
node_name = GetNodeName(node) + ":" + std::to_string(iter->second); return iter->second;
} else { } else {
std::string node_name = GetNodeName(node);
while (nodeName_.count(node_name) > 0) {
auto node_idx = GetNodeIndex(); auto node_idx = GetNodeIndex();
node_index_map_[node] = node_idx; node_name = node_name + ":" + std::to_string(node_idx);
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
} }
} else if (node->isa<ValueNode>()) { node_index_map_[node] = node_name;
auto node_idx = GetNodeIndex();
node_index_map_[node] = node_idx;
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
} else {
MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
}
MS_LOG(DEBUG) << "Node name: " << node_name;
return node_name; return node_name;
} }
}
std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) { std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
std::string node_name = ""; std::string node_name = "";

View File

@ -0,0 +1,103 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor, nn
from mindspore.common import dtype as mstype
from mindspore.train.serialization import export, load
class CaseNet(nn.Cell):
def __init__(self):
super(CaseNet, self).__init__()
self.conv = nn.Conv2d(1, 1, 3)
self.relu = nn.ReLU()
self.relu1 = nn.ReLU()
self.softmax = nn.Softmax()
self.layers1 = (self.relu, self.softmax)
self.layers2 = (self.conv, self.relu1)
def construct(self, x, index1, index2):
x = self.layers1[index1](x)
x = self.layers2[index2](x)
return x
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_mindir_switch_layer():
context.set_context(mode=context.GRAPH_MODE)
net = CaseNet()
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
idx = Tensor(0, mstype.int32)
idx2 = Tensor(-1, mstype.int32)
file_name = "switch_layer_net"
mindir_name = file_name + ".mindir"
export(net, data, idx, idx2, file_name=file_name, file_format='MINDIR')
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(data, idx, idx2)
relu = nn.ReLU()
true_value = relu(data)
ret = np.allclose(outputs_after_load.asnumpy(), true_value.asnumpy())
assert ret
@pytest.mark.skip(reason="depend on export")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_mindir_export():
context.set_context(mode=context.GRAPH_MODE)
net = CaseNet()
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
idx = Tensor(0, mstype.int32)
idx2 = Tensor(-1, mstype.int32)
file_name = "switch_layer_net"
mindir_name = file_name + ".mindir"
export(net, data, idx, idx2, file_name=file_name, file_format='MINDIR')
assert os.path.exists(mindir_name)
@pytest.mark.skip(reason="depend on export")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_mindir_load():
context.set_context(mode=context.GRAPH_MODE)
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
idx = Tensor(0, mstype.int32)
idx2 = Tensor(-1, mstype.int32)
file_name = "switch_layer_net"
mindir_name = file_name + ".mindir"
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(data, idx, idx2)
relu = nn.ReLU()
true_value = relu(data)
ret = np.allclose(outputs_after_load.asnumpy(), true_value.asnumpy())
assert ret