mindir:support control flow while

This commit is contained in:
lanzhineng 2021-08-10 17:01:26 +08:00
parent 3b0c3e640b
commit 787b19c014
3 changed files with 136 additions and 15 deletions

View File

@ -121,19 +121,21 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;
// Some operators are not defined.
inline bool ResetCNodeFromLoad(const AnfNodePtr &node) {
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
// Process partial("DeadNode",args) when the graph is loaded.
auto operatorPtr = node->cast<CNodePtr>()->input(0);
// Set abstract of switch(c,f,t) to null
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
node->set_abstract(nullptr);
return true;
}
// Set abstract of switch(c,f,t)() to null
prim = GetCNodePrimitive(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
// If the operator is not a primitive, the abstract will been set to null.
// Because there are not some operators in front end, the abstract of primitive should be reserved.
if (prim == nullptr) {
node->set_abstract(nullptr);
return true;
}
@ -156,13 +158,6 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
MS_EXCEPTION_IF_NULL(node);
const AbstractBasePtr &prev_inferred = node->abstract();
// AbstractFunction has context,but contexts in cache have been cleaned.
if (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>()) {
node->set_abstract(nullptr);
MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
continue;
}
// Handle previous inferred value for CNode if is loaded from MindIR
if (res->is_load() && ResetCNodeFromLoad(node)) {
continue;
@ -536,8 +531,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
auto func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
func_graph->DumpFuncGraph(fg_name);
DumpIR(fg_name + ".ir", func_graph);
ExportIR(fg_name + ".dat", func_graph);
DumpIR(fg_name + ".ir", func_graph);
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
}
counter++;

View File

@ -776,6 +776,10 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
auto it = anfnode_build_map_.find(node_type.substr(kOpTypeFlagSize));
if (it != anfnode_build_map_.end()) {
auto funcGraph = GetValueNode<FuncGraphPtr>(it->second);
if (funcGraph != nullptr) {
return NewValueNode(funcGraph);
}
return it->second;
}
MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
@ -824,9 +828,10 @@ void MSANFModelParser::SetCNodeAbastract(const mind_ir::NodeProto &node_proto, C
cnode_ptr->set_abstract(nullptr);
return;
}
// Set abstract of switch(c,f,t)() to null
prim = GetCNodePrimitive(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
// If the operator is not a primitive, the abstract will been set to null.
// Because there are not some operators in front end, the abstract of primitive should be reserved.
if (prim == nullptr) {
cnode_ptr->set_abstract(nullptr);
return;
}

View File

@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import export, load
class SingleWhileNet(nn.Cell):
def construct(self, x, y):
x += 1
while x < y:
x += 1
y += 2 * x
return y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while():
context.set_context(mode=context.GRAPH_MODE)
network = SingleWhileNet()
x = Tensor(np.array([1]).astype(np.float32))
y = Tensor(np.array([2]).astype(np.float32))
origin_out = network(x, y)
file_name = "while_net"
export(network, x, y, file_name=file_name, file_format='MINDIR')
mindir_name = file_name + ".mindir"
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(x, y)
assert origin_out == outputs_after_load
class SingleWhileInlineNet(nn.Cell):
def construct(self, x, y):
x += 1
while x < y:
x += 1
y += x
return y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while_inline_export():
context.set_context(mode=context.GRAPH_MODE)
network = SingleWhileInlineNet()
x = Tensor(np.array([1]).astype(np.float32))
y = Tensor(np.array([2]).astype(np.float32))
file_name = "while_inline_net"
export(network, x, y, file_name=file_name, file_format='MINDIR')
mindir_name = file_name + ".mindir"
assert os.path.exists(mindir_name)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while_inline_load():
context.set_context(mode=context.GRAPH_MODE)
network = SingleWhileInlineNet()
x = Tensor(np.array([1]).astype(np.float32))
y = Tensor(np.array([2]).astype(np.float32))
file_name = "while_inline_net"
export(network, x, y, file_name=file_name, file_format='MINDIR')
mindir_name = file_name + ".mindir"
assert os.path.exists(mindir_name)
load(mindir_name)
@pytest.mark.skip(reason="inline is not supported yet")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while_inline():
context.set_context(mode=context.GRAPH_MODE)
network = SingleWhileInlineNet()
x = Tensor(np.array([1]).astype(np.float32))
y = Tensor(np.array([2]).astype(np.float32))
origin_out = network(x, y)
file_name = "while_inline_net"
export(network, x, y, file_name=file_name, file_format='MINDIR')
mindir_name = file_name + ".mindir"
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(x, y)
assert origin_out == outputs_after_load