forked from mindspore-Ecosystem/mindspore
!21803 mindir: support control flow switch layer
Merge pull request !21803 from lanzhineng/mindir_control_flow
This commit is contained in:
commit
223f500bab
|
@ -121,30 +121,6 @@ using CompileGraphs = compile::CompileGraphs;
|
|||
using abstract::AnalysisResult;
|
||||
using mindspore::abstract::AnalysisContextPtr;
|
||||
|
||||
// Some operators are not defined.
|
||||
inline bool ResetCNodeFromLoad(const AnfNodePtr &node) {
|
||||
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
|
||||
// Process partial("DeadNode",args) when the graph is loaded.
|
||||
auto operatorPtr = node->cast<CNodePtr>()->input(0);
|
||||
// Set abstract of switch(c,f,t) to null
|
||||
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
|
||||
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
|
||||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
|
||||
node->set_abstract(nullptr);
|
||||
return true;
|
||||
}
|
||||
// If the operator is not a primitive, the abstract will been set to null.
|
||||
// Because there are not some operators in front end, the abstract of primitive should be reserved.
|
||||
if (prim == nullptr) {
|
||||
node->set_abstract(nullptr);
|
||||
return true;
|
||||
}
|
||||
// Previous inferred value
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
|
||||
const abstract::AbstractBasePtrList &args_spec, bool clear) {
|
||||
MS_LOG(DEBUG) << "AbstractAnalyze start";
|
||||
|
@ -156,17 +132,22 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
|
|||
engine->Clear();
|
||||
for (auto &node : manager->all_nodes()) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const AbstractBasePtr &prev_inferred = node->abstract();
|
||||
|
||||
// Handle previous inferred value for CNode if is loaded from MindIR
|
||||
if (res->is_load() && ResetCNodeFromLoad(node)) {
|
||||
continue;
|
||||
if (res->is_load()) {
|
||||
// If the primitive is not defined in front end,keep the inferred value loaded from MindIR.
|
||||
auto primitive = GetCNodePrimitive(node);
|
||||
if (primitive != nullptr && abstract::GetPrimEvaluator(primitive, engine) == nullptr) {
|
||||
MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const AbstractBasePtr &prev_inferred = node->abstract();
|
||||
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
|
||||
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
|
||||
node->set_abstract(nullptr);
|
||||
MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
|
||||
MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -275,7 +256,9 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
|
|||
MS_EXCEPTION_IF_NULL(root_type);
|
||||
MS_EXCEPTION_IF_NULL(loaded_type);
|
||||
|
||||
if (root_shape->shape() != loaded_shape->shape()) {
|
||||
auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
|
||||
(root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
|
||||
if (!shapeEqu) {
|
||||
MS_EXCEPTION(ValueError) << "The " << index
|
||||
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
|
||||
<< ", input shape of loaded graph: " << loaded_shape->ToString();
|
||||
|
@ -531,8 +514,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
|
|||
auto func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
func_graph->DumpFuncGraph(fg_name);
|
||||
ExportIR(fg_name + ".dat", func_graph);
|
||||
DumpIR(fg_name + ".ir", func_graph);
|
||||
ExportIR(fg_name + ".dat", func_graph);
|
||||
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
|
||||
}
|
||||
counter++;
|
||||
|
|
|
@ -359,7 +359,6 @@ void AnalysisEngine::Clear() {
|
|||
root_context_ = nullptr;
|
||||
}
|
||||
|
||||
namespace {
|
||||
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
|
||||
// Custom Primitive with python infer_shape, infer_type
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
@ -396,7 +395,8 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
engine->prim_py_evaluators_[prim_py] = evaluator;
|
||||
return evaluator;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
|
||||
MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// return a default evaluator
|
||||
|
@ -416,11 +416,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
}
|
||||
}
|
||||
if (evaluator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
|
||||
MS_LOG(DEBUG) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
|
||||
}
|
||||
return evaluator;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
|
@ -430,6 +429,9 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbs
|
|||
}
|
||||
auto primitive = func->prim();
|
||||
auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
|
||||
if (evaluator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
|
||||
}
|
||||
evaluators_[func] = evaluator;
|
||||
return evaluator;
|
||||
}
|
||||
|
@ -1012,7 +1014,9 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
|
|||
|
||||
EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
|
||||
auto evaluator = GetPrimEvaluator(primitive, nullptr);
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
if (evaluator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
|
||||
}
|
||||
if (!evaluator->isa<TrivialPrimEvaluator>()) {
|
||||
MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but "
|
||||
<< evaluator->ToString();
|
||||
|
|
|
@ -347,7 +347,7 @@ template <typename T>
|
|||
AbstractBasePtr FromValue(const T &value, bool broaden = false) {
|
||||
return FromValueInside(MakeValue(value), broaden);
|
||||
}
|
||||
|
||||
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine);
|
||||
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -137,10 +137,11 @@ class IrExportBuilder {
|
|||
mind_ir::ModelProto model_;
|
||||
mind_ir::NodeProto *last_node_{nullptr};
|
||||
std::list<FuncGraphPtr> todo_;
|
||||
std::map<AnfNodePtr, size_t> node_index_map_;
|
||||
std::map<AnfNodePtr, std::string> node_index_map_;
|
||||
std::set<std::string> nodeName_;
|
||||
size_t node_index_{0};
|
||||
size_t shape_index_{0};
|
||||
bool top_graph{true};
|
||||
};
|
||||
|
||||
using IrExporterPtr = std::shared_ptr<IrExporter>;
|
||||
|
@ -185,9 +186,11 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
|
|||
nodeName_.clear();
|
||||
// Build the main funcGraph
|
||||
nodeName_.insert(func_graph->ToString());
|
||||
top_graph = true;
|
||||
BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
|
||||
std::set<FuncGraphPtr> graphVisited;
|
||||
graphVisited.insert(func_graph);
|
||||
top_graph = false;
|
||||
while (!todo_.empty()) {
|
||||
FuncGraphPtr fg = todo_.back();
|
||||
todo_.pop_back();
|
||||
|
@ -204,6 +207,7 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
|
|||
}
|
||||
// Release resource
|
||||
nodeName_.clear();
|
||||
node_index_map_.clear();
|
||||
}
|
||||
|
||||
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
|
||||
|
@ -227,8 +231,8 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
|
|||
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
|
||||
}
|
||||
std::string param_name = GetUniqueNodeName(param);
|
||||
if (param->has_default()) {
|
||||
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has default.";
|
||||
if (top_graph && param->has_default()) {
|
||||
MS_LOG(DEBUG) << "Parameter: '" << item->DebugString() << "' has default. address: " << (size_t)param.get();
|
||||
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
|
||||
parameter_proto->set_name(param_name);
|
||||
SetParamToTensorProto(param, parameter_proto);
|
||||
|
@ -308,7 +312,7 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
|
|||
} else if (type->isa<Tuple>()) {
|
||||
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
|
||||
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
|
||||
} else if (type->isa<Number>() || type->isa<String>()) {
|
||||
} else if (type->isa<Number>() || type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
|
||||
value_proto->set_denotation(type->type_name());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
|
||||
|
@ -541,27 +545,18 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
|
|||
// Naming anfnode
|
||||
// 1. parameter is unique in one func_graph
|
||||
// 2. cnode and valuenode may be reduplicative, so add index to identify.
|
||||
std::string node_name = "";
|
||||
if (node->isa<Parameter>()) {
|
||||
node_name = GetNodeName(node);
|
||||
} else if (node->isa<CNode>()) {
|
||||
auto iter = node_index_map_.find(node);
|
||||
if (iter != node_index_map_.end()) {
|
||||
node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
|
||||
} else {
|
||||
auto node_idx = GetNodeIndex();
|
||||
node_index_map_[node] = node_idx;
|
||||
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
|
||||
}
|
||||
} else if (node->isa<ValueNode>()) {
|
||||
auto node_idx = GetNodeIndex();
|
||||
node_index_map_[node] = node_idx;
|
||||
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
|
||||
auto iter = node_index_map_.find(node);
|
||||
if (iter != node_index_map_.end()) {
|
||||
return iter->second;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
|
||||
std::string node_name = GetNodeName(node);
|
||||
while (nodeName_.count(node_name) > 0) {
|
||||
auto node_idx = GetNodeIndex();
|
||||
node_name = node_name + ":" + std::to_string(node_idx);
|
||||
}
|
||||
node_index_map_[node] = node_name;
|
||||
return node_name;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Node name: " << node_name;
|
||||
return node_name;
|
||||
}
|
||||
|
||||
std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import export, load
|
||||
|
||||
|
||||
class CaseNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CaseNet, self).__init__()
|
||||
self.conv = nn.Conv2d(1, 1, 3)
|
||||
self.relu = nn.ReLU()
|
||||
self.relu1 = nn.ReLU()
|
||||
self.softmax = nn.Softmax()
|
||||
self.layers1 = (self.relu, self.softmax)
|
||||
self.layers2 = (self.conv, self.relu1)
|
||||
|
||||
def construct(self, x, index1, index2):
|
||||
x = self.layers1[index1](x)
|
||||
x = self.layers2[index2](x)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_mindir_switch_layer():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = CaseNet()
|
||||
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
|
||||
idx = Tensor(0, mstype.int32)
|
||||
idx2 = Tensor(-1, mstype.int32)
|
||||
|
||||
file_name = "switch_layer_net"
|
||||
mindir_name = file_name + ".mindir"
|
||||
export(net, data, idx, idx2, file_name=file_name, file_format='MINDIR')
|
||||
assert os.path.exists(mindir_name)
|
||||
|
||||
graph = load(mindir_name)
|
||||
loaded_net = nn.GraphCell(graph)
|
||||
outputs_after_load = loaded_net(data, idx, idx2)
|
||||
relu = nn.ReLU()
|
||||
true_value = relu(data)
|
||||
ret = np.allclose(outputs_after_load.asnumpy(), true_value.asnumpy())
|
||||
assert ret
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="depend on export")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_mindir_export():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = CaseNet()
|
||||
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
|
||||
idx = Tensor(0, mstype.int32)
|
||||
idx2 = Tensor(-1, mstype.int32)
|
||||
|
||||
file_name = "switch_layer_net"
|
||||
mindir_name = file_name + ".mindir"
|
||||
export(net, data, idx, idx2, file_name=file_name, file_format='MINDIR')
|
||||
assert os.path.exists(mindir_name)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="depend on export")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_mindir_load():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
|
||||
idx = Tensor(0, mstype.int32)
|
||||
idx2 = Tensor(-1, mstype.int32)
|
||||
|
||||
file_name = "switch_layer_net"
|
||||
mindir_name = file_name + ".mindir"
|
||||
graph = load(mindir_name)
|
||||
loaded_net = nn.GraphCell(graph)
|
||||
outputs_after_load = loaded_net(data, idx, idx2)
|
||||
relu = nn.ReLU()
|
||||
true_value = relu(data)
|
||||
ret = np.allclose(outputs_after_load.asnumpy(), true_value.asnumpy())
|
||||
assert ret
|
Loading…
Reference in New Issue