mindir:support switchlayer

This commit is contained in:
lanzhineng 2021-08-13 16:16:10 +08:00
parent b69f492f25
commit 93a1956978
5 changed files with 144 additions and 59 deletions

View File

@ -121,30 +121,6 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;
// Some operators are not defined.
inline bool ResetCNodeFromLoad(const AnfNodePtr &node) {
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
// Process partial("DeadNode",args) when the graph is loaded.
auto operatorPtr = node->cast<CNodePtr>()->input(0);
// Set abstract of switch(c,f,t) to null
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
node->set_abstract(nullptr);
return true;
}
// If the operator is not a primitive, the abstract will been set to null.
// Because there are not some operators in front end, the abstract of primitive should be reserved.
if (prim == nullptr) {
node->set_abstract(nullptr);
return true;
}
// Previous inferred value
return true;
}
return false;
}
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
const abstract::AbstractBasePtrList &args_spec, bool clear) {
MS_LOG(DEBUG) << "AbstractAnalyze start";
@ -156,17 +132,22 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
engine->Clear();
for (auto &node : manager->all_nodes()) {
MS_EXCEPTION_IF_NULL(node);
const AbstractBasePtr &prev_inferred = node->abstract();
// Handle previous inferred value for CNode if is loaded from MindIR
if (res->is_load() && ResetCNodeFromLoad(node)) {
if (res->is_load()) {
// If the primitive is not defined in front end,keep the inferred value loaded from MindIR.
auto primitive = GetCNodePrimitive(node);
if (primitive != nullptr && abstract::GetPrimEvaluator(primitive, engine) == nullptr) {
MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
continue;
}
}
const AbstractBasePtr &prev_inferred = node->abstract();
// Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
node->set_abstract(nullptr);
MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
}
}
}
@ -275,7 +256,9 @@ void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &load
MS_EXCEPTION_IF_NULL(root_type);
MS_EXCEPTION_IF_NULL(loaded_type);
if (root_shape->shape() != loaded_shape->shape()) {
auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
(root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
if (!shapeEqu) {
MS_EXCEPTION(ValueError) << "The " << index
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
<< ", input shape of loaded graph: " << loaded_shape->ToString();
@ -531,8 +514,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
auto func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
func_graph->DumpFuncGraph(fg_name);
ExportIR(fg_name + ".dat", func_graph);
DumpIR(fg_name + ".ir", func_graph);
ExportIR(fg_name + ".dat", func_graph);
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
}
counter++;

View File

@ -359,7 +359,6 @@ void AnalysisEngine::Clear() {
root_context_ = nullptr;
}
namespace {
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
// Custom Primitive with python infer_shape, infer_type
MS_EXCEPTION_IF_NULL(prim);
@ -396,7 +395,8 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
engine->prim_py_evaluators_[prim_py] = evaluator;
return evaluator;
}
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
return nullptr;
}
// return a default evaluator
@ -416,11 +416,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
}
}
if (evaluator == nullptr) {
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
MS_LOG(DEBUG) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
}
return evaluator;
}
} // namespace
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
MS_EXCEPTION_IF_NULL(func);
@ -430,6 +429,9 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbs
}
auto primitive = func->prim();
auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
if (evaluator == nullptr) {
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
}
evaluators_[func] = evaluator;
return evaluator;
}
@ -1012,7 +1014,9 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
auto evaluator = GetPrimEvaluator(primitive, nullptr);
MS_EXCEPTION_IF_NULL(evaluator);
if (evaluator == nullptr) {
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << primitive->name() << ").";
}
if (!evaluator->isa<TrivialPrimEvaluator>()) {
MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but "
<< evaluator->ToString();

View File

@ -347,7 +347,7 @@ template <typename T>
AbstractBasePtr FromValue(const T &value, bool broaden = false) {
return FromValueInside(MakeValue(value), broaden);
}
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine);
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
} // namespace abstract
} // namespace mindspore

View File

@ -137,10 +137,11 @@ class IrExportBuilder {
mind_ir::ModelProto model_;
mind_ir::NodeProto *last_node_{nullptr};
std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_;
std::map<AnfNodePtr, std::string> node_index_map_;
std::set<std::string> nodeName_;
size_t node_index_{0};
size_t shape_index_{0};
bool top_graph{true};
};
using IrExporterPtr = std::shared_ptr<IrExporter>;
@ -185,9 +186,11 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
nodeName_.clear();
// Build the main funcGraph
nodeName_.insert(func_graph->ToString());
top_graph = true;
BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
std::set<FuncGraphPtr> graphVisited;
graphVisited.insert(func_graph);
top_graph = false;
while (!todo_.empty()) {
FuncGraphPtr fg = todo_.back();
todo_.pop_back();
@ -204,6 +207,7 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso
}
// Release resource
nodeName_.clear();
node_index_map_.clear();
}
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
@ -227,8 +231,8 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
}
std::string param_name = GetUniqueNodeName(param);
if (param->has_default()) {
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has default.";
if (top_graph && param->has_default()) {
MS_LOG(DEBUG) << "Parameter: '" << item->DebugString() << "' has default. address: " << (size_t)param.get();
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
parameter_proto->set_name(param_name);
SetParamToTensorProto(param, parameter_proto);
@ -308,7 +312,7 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
} else if (type->isa<Tuple>()) {
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
} else if (type->isa<Number>() || type->isa<String>()) {
} else if (type->isa<Number>() || type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
value_proto->set_denotation(type->type_name());
} else {
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
@ -541,27 +545,18 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
// Naming anfnode
// 1. parameter is unique in one func_graph
// 2. cnode and valuenode may be reduplicative, so add index to identify.
std::string node_name = "";
if (node->isa<Parameter>()) {
node_name = GetNodeName(node);
} else if (node->isa<CNode>()) {
auto iter = node_index_map_.find(node);
if (iter != node_index_map_.end()) {
node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
return iter->second;
} else {
std::string node_name = GetNodeName(node);
while (nodeName_.count(node_name) > 0) {
auto node_idx = GetNodeIndex();
node_index_map_[node] = node_idx;
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
node_name = node_name + ":" + std::to_string(node_idx);
}
} else if (node->isa<ValueNode>()) {
auto node_idx = GetNodeIndex();
node_index_map_[node] = node_idx;
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
} else {
MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
}
MS_LOG(DEBUG) << "Node name: " << node_name;
node_index_map_[node] = node_name;
return node_name;
}
}
std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {

View File

@ -0,0 +1,103 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor, nn
from mindspore.common import dtype as mstype
from mindspore.train.serialization import export, load
class CaseNet(nn.Cell):
def __init__(self):
super(CaseNet, self).__init__()
self.conv = nn.Conv2d(1, 1, 3)
self.relu = nn.ReLU()
self.relu1 = nn.ReLU()
self.softmax = nn.Softmax()
self.layers1 = (self.relu, self.softmax)
self.layers2 = (self.conv, self.relu1)
def construct(self, x, index1, index2):
x = self.layers1[index1](x)
x = self.layers2[index2](x)
return x
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_mindir_switch_layer():
context.set_context(mode=context.GRAPH_MODE)
net = CaseNet()
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
idx = Tensor(0, mstype.int32)
idx2 = Tensor(-1, mstype.int32)
file_name = "switch_layer_net"
mindir_name = file_name + ".mindir"
export(net, data, idx, idx2, file_name=file_name, file_format='MINDIR')
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(data, idx, idx2)
relu = nn.ReLU()
true_value = relu(data)
ret = np.allclose(outputs_after_load.asnumpy(), true_value.asnumpy())
assert ret
@pytest.mark.skip(reason="depend on export")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_mindir_export():
context.set_context(mode=context.GRAPH_MODE)
net = CaseNet()
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
idx = Tensor(0, mstype.int32)
idx2 = Tensor(-1, mstype.int32)
file_name = "switch_layer_net"
mindir_name = file_name + ".mindir"
export(net, data, idx, idx2, file_name=file_name, file_format='MINDIR')
assert os.path.exists(mindir_name)
@pytest.mark.skip(reason="depend on export")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_mindir_load():
context.set_context(mode=context.GRAPH_MODE)
data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32)
idx = Tensor(0, mstype.int32)
idx2 = Tensor(-1, mstype.int32)
file_name = "switch_layer_net"
mindir_name = file_name + ".mindir"
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(data, idx, idx2)
relu = nn.ReLU()
true_value = relu(data)
ret = np.allclose(outputs_after_load.asnumpy(), true_value.asnumpy())
assert ret