!21606 mindir: support control flow: while

Merge pull request !21606 from lanzhineng/mindir_control_flow
This commit is contained in:
i-robot 2021-08-11 01:26:35 +00:00 committed by Gitee
commit 9211814983
3 changed files with 136 additions and 15 deletions

View File

@ -121,19 +121,21 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult; using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr; using mindspore::abstract::AnalysisContextPtr;
// Some operators are not defined.
inline bool ResetCNodeFromLoad(const AnfNodePtr &node) { inline bool ResetCNodeFromLoad(const AnfNodePtr &node) {
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) { if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
// Process partial("DeadNode",args) when the graph is loaded. // Process partial("DeadNode",args) when the graph is loaded.
auto operatorPtr = node->cast<CNodePtr>()->input(0); auto operatorPtr = node->cast<CNodePtr>()->input(0);
// Set abstract of switch(c,f,t) to null // Set abstract of switch(c,f,t) to null
auto prim = GetValueNode<PrimitivePtr>(operatorPtr); auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) { if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
node->set_abstract(nullptr); node->set_abstract(nullptr);
return true; return true;
} }
// Set abstract of switch(c,f,t)() to null // If the operator is not a primitive, the abstract will be set to null.
prim = GetCNodePrimitive(operatorPtr); // Because some operators are missing in the front end, the abstract of a primitive should be preserved.
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) { if (prim == nullptr) {
node->set_abstract(nullptr); node->set_abstract(nullptr);
return true; return true;
} }
@ -156,13 +158,6 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
const AbstractBasePtr &prev_inferred = node->abstract(); const AbstractBasePtr &prev_inferred = node->abstract();
// AbstractFunction has context,but contexts in cache have been cleaned.
if (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>()) {
node->set_abstract(nullptr);
MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
continue;
}
// Handle previous inferred value for CNode if is loaded from MindIR // Handle previous inferred value for CNode if is loaded from MindIR
if (res->is_load() && ResetCNodeFromLoad(node)) { if (res->is_load() && ResetCNodeFromLoad(node)) {
continue; continue;
@ -536,8 +531,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
auto func_graph = res->func_graph(); auto func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
func_graph->DumpFuncGraph(fg_name); func_graph->DumpFuncGraph(fg_name);
DumpIR(fg_name + ".ir", func_graph);
ExportIR(fg_name + ".dat", func_graph); ExportIR(fg_name + ".dat", func_graph);
DumpIR(fg_name + ".ir", func_graph);
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
} }
counter++; counter++;

View File

@ -776,6 +776,10 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) { if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
auto it = anfnode_build_map_.find(node_type.substr(kOpTypeFlagSize)); auto it = anfnode_build_map_.find(node_type.substr(kOpTypeFlagSize));
if (it != anfnode_build_map_.end()) { if (it != anfnode_build_map_.end()) {
auto funcGraph = GetValueNode<FuncGraphPtr>(it->second);
if (funcGraph != nullptr) {
return NewValueNode(funcGraph);
}
return it->second; return it->second;
} }
MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type; MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
@ -824,9 +828,10 @@ void MSANFModelParser::SetCNodeAbastract(const mind_ir::NodeProto &node_proto, C
cnode_ptr->set_abstract(nullptr); cnode_ptr->set_abstract(nullptr);
return; return;
} }
// Set abstract of switch(c,f,t)() to null
prim = GetCNodePrimitive(operatorPtr); // If the operator is not a primitive, the abstract will be set to null.
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) { // Because some operators are missing in the front end, the abstract of a primitive should be preserved.
if (prim == nullptr) {
cnode_ptr->set_abstract(nullptr); cnode_ptr->set_abstract(nullptr);
return; return;
} }

View File

@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import export, load
class SingleWhileNet(nn.Cell):
    """Network with a single `while` loop, used to exercise MindIR control-flow export.

    Runs under GRAPH_MODE, so the Python `while` is captured as a graph
    while-loop when the network is exported to MindIR.
    """

    def construct(self, x, y):
        x += 1
        # NOTE(review): indentation was lost in this paste; reconstructed so the
        # loop body is only `x += 1` and `y` is updated after the loop — confirm
        # against the upstream test file.
        while x < y:
            x += 1
        y += 2 * x
        return y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while():
    """
    Feature: MindIR control flow (while).
    Description: Export a single-while network to MindIR, load it back and run it.
    Expectation: The loaded graph produces the same output as the original network.
    """
    context.set_context(mode=context.GRAPH_MODE)
    network = SingleWhileNet()
    x = Tensor(np.array([1]).astype(np.float32))
    y = Tensor(np.array([2]).astype(np.float32))
    origin_out = network(x, y)
    file_name = "while_net"
    export(network, x, y, file_name=file_name, file_format='MINDIR')
    mindir_name = file_name + ".mindir"
    assert os.path.exists(mindir_name)
    try:
        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)
        outputs_after_load = loaded_net(x, y)
        # Compare through numpy with a tolerance: raw `Tensor ==` equality is
        # fragile for floating-point outputs.
        assert np.allclose(origin_out.asnumpy(), outputs_after_load.asnumpy())
    finally:
        # Remove the exported artifact so repeated runs start from a clean state.
        os.remove(mindir_name)
class SingleWhileInlineNet(nn.Cell):
    """Network with a single `while` loop, used for the inline MindIR export tests.

    Identical in shape to `SingleWhileNet` except the post-loop update is
    `y += x` instead of `y += 2 * x`.
    """

    def construct(self, x, y):
        x += 1
        # NOTE(review): indentation was lost in this paste; reconstructed to
        # mirror `SingleWhileNet` (loop body is only `x += 1`) — confirm against
        # the upstream test file.
        while x < y:
            x += 1
        y += x
        return y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while_inline_export():
    """
    Feature: MindIR control flow (while, inline form).
    Description: Export a single-while network to MindIR (export only; loading
        is covered by a separate test).
    Expectation: The .mindir file is produced on disk.
    """
    context.set_context(mode=context.GRAPH_MODE)
    network = SingleWhileInlineNet()
    x = Tensor(np.array([1]).astype(np.float32))
    y = Tensor(np.array([2]).astype(np.float32))
    file_name = "while_inline_net"
    export(network, x, y, file_name=file_name, file_format='MINDIR')
    mindir_name = file_name + ".mindir"
    assert os.path.exists(mindir_name)
    # Remove the exported artifact so repeated runs start from a clean state.
    os.remove(mindir_name)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while_inline_load():
    """
    Feature: MindIR control flow (while, inline form).
    Description: Export a single-while network to MindIR and load it back
        (load only; executing the loaded graph is covered by a skipped test
        until inline is supported).
    Expectation: Export succeeds and load() does not raise.
    """
    context.set_context(mode=context.GRAPH_MODE)
    network = SingleWhileInlineNet()
    x = Tensor(np.array([1]).astype(np.float32))
    y = Tensor(np.array([2]).astype(np.float32))
    file_name = "while_inline_net"
    export(network, x, y, file_name=file_name, file_format='MINDIR')
    mindir_name = file_name + ".mindir"
    assert os.path.exists(mindir_name)
    try:
        load(mindir_name)
    finally:
        # Remove the exported artifact so repeated runs start from a clean state.
        os.remove(mindir_name)
@pytest.mark.skip(reason="inline is not supported yet")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while_inline():
    """
    Feature: MindIR control flow (while, inline form).
    Description: Export an inline single-while network to MindIR, load it back
        and run it. Skipped until inline control flow is supported.
    Expectation: The loaded graph produces the same output as the original network.
    """
    context.set_context(mode=context.GRAPH_MODE)
    network = SingleWhileInlineNet()
    x = Tensor(np.array([1]).astype(np.float32))
    y = Tensor(np.array([2]).astype(np.float32))
    origin_out = network(x, y)
    file_name = "while_inline_net"
    export(network, x, y, file_name=file_name, file_format='MINDIR')
    mindir_name = file_name + ".mindir"
    assert os.path.exists(mindir_name)
    try:
        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)
        outputs_after_load = loaded_net(x, y)
        # Compare through numpy with a tolerance: raw `Tensor ==` equality is
        # fragile for floating-point outputs.
        assert np.allclose(origin_out.asnumpy(), outputs_after_load.asnumpy())
    finally:
        # Remove the exported artifact so repeated runs start from a clean state.
        os.remove(mindir_name)