forked from mindspore-Ecosystem/mindspore
mindir:support control flow while
This commit is contained in:
parent
3b0c3e640b
commit
787b19c014
|
@ -121,19 +121,21 @@ using CompileGraphs = compile::CompileGraphs;
|
|||
using abstract::AnalysisResult;
|
||||
using mindspore::abstract::AnalysisContextPtr;
|
||||
|
||||
// Some operators are not defined.
|
||||
inline bool ResetCNodeFromLoad(const AnfNodePtr &node) {
|
||||
if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
|
||||
// Process partial("DeadNode",args) when the graph is loaded.
|
||||
auto operatorPtr = node->cast<CNodePtr>()->input(0);
|
||||
// Set abstract of switch(c,f,t) to null
|
||||
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
|
||||
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
|
||||
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
|
||||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
|
||||
node->set_abstract(nullptr);
|
||||
return true;
|
||||
}
|
||||
// Set abstract of switch(c,f,t)() to null
|
||||
prim = GetCNodePrimitive(operatorPtr);
|
||||
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
|
||||
// If the operator is not a primitive, the abstract will been set to null.
|
||||
// Because there are not some operators in front end, the abstract of primitive should be reserved.
|
||||
if (prim == nullptr) {
|
||||
node->set_abstract(nullptr);
|
||||
return true;
|
||||
}
|
||||
|
@ -156,13 +158,6 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
const AbstractBasePtr &prev_inferred = node->abstract();
|
||||
|
||||
// AbstractFunction has context,but contexts in cache have been cleaned.
|
||||
if (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>()) {
|
||||
node->set_abstract(nullptr);
|
||||
MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle previous inferred value for CNode if is loaded from MindIR
|
||||
if (res->is_load() && ResetCNodeFromLoad(node)) {
|
||||
continue;
|
||||
|
@ -536,8 +531,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
|
|||
auto func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
func_graph->DumpFuncGraph(fg_name);
|
||||
DumpIR(fg_name + ".ir", func_graph);
|
||||
ExportIR(fg_name + ".dat", func_graph);
|
||||
DumpIR(fg_name + ".ir", func_graph);
|
||||
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
|
||||
}
|
||||
counter++;
|
||||
|
|
|
@ -776,6 +776,10 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
|
|||
if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
|
||||
auto it = anfnode_build_map_.find(node_type.substr(kOpTypeFlagSize));
|
||||
if (it != anfnode_build_map_.end()) {
|
||||
auto funcGraph = GetValueNode<FuncGraphPtr>(it->second);
|
||||
if (funcGraph != nullptr) {
|
||||
return NewValueNode(funcGraph);
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
|
||||
|
@ -824,9 +828,10 @@ void MSANFModelParser::SetCNodeAbastract(const mind_ir::NodeProto &node_proto, C
|
|||
cnode_ptr->set_abstract(nullptr);
|
||||
return;
|
||||
}
|
||||
// Set abstract of switch(c,f,t)() to null
|
||||
prim = GetCNodePrimitive(operatorPtr);
|
||||
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim)) {
|
||||
|
||||
// If the operator is not a primitive, the abstract will been set to null.
|
||||
// Because there are not some operators in front end, the abstract of primitive should be reserved.
|
||||
if (prim == nullptr) {
|
||||
cnode_ptr->set_abstract(nullptr);
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.serialization import export, load
|
||||
|
||||
|
||||
class SingleWhileNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
x += 1
|
||||
while x < y:
|
||||
x += 1
|
||||
y += 2 * x
|
||||
return y
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_single_while():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
network = SingleWhileNet()
|
||||
|
||||
x = Tensor(np.array([1]).astype(np.float32))
|
||||
y = Tensor(np.array([2]).astype(np.float32))
|
||||
origin_out = network(x, y)
|
||||
|
||||
file_name = "while_net"
|
||||
export(network, x, y, file_name=file_name, file_format='MINDIR')
|
||||
mindir_name = file_name + ".mindir"
|
||||
assert os.path.exists(mindir_name)
|
||||
|
||||
graph = load(mindir_name)
|
||||
loaded_net = nn.GraphCell(graph)
|
||||
outputs_after_load = loaded_net(x, y)
|
||||
assert origin_out == outputs_after_load
|
||||
|
||||
|
||||
class SingleWhileInlineNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
x += 1
|
||||
while x < y:
|
||||
x += 1
|
||||
y += x
|
||||
return y
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_single_while_inline_export():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
network = SingleWhileInlineNet()
|
||||
|
||||
x = Tensor(np.array([1]).astype(np.float32))
|
||||
y = Tensor(np.array([2]).astype(np.float32))
|
||||
|
||||
file_name = "while_inline_net"
|
||||
export(network, x, y, file_name=file_name, file_format='MINDIR')
|
||||
mindir_name = file_name + ".mindir"
|
||||
assert os.path.exists(mindir_name)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_single_while_inline_load():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
network = SingleWhileInlineNet()
|
||||
|
||||
x = Tensor(np.array([1]).astype(np.float32))
|
||||
y = Tensor(np.array([2]).astype(np.float32))
|
||||
|
||||
file_name = "while_inline_net"
|
||||
export(network, x, y, file_name=file_name, file_format='MINDIR')
|
||||
mindir_name = file_name + ".mindir"
|
||||
assert os.path.exists(mindir_name)
|
||||
load(mindir_name)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="inline is not supported yet")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_single_while_inline():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
network = SingleWhileInlineNet()
|
||||
|
||||
x = Tensor(np.array([1]).astype(np.float32))
|
||||
y = Tensor(np.array([2]).astype(np.float32))
|
||||
origin_out = network(x, y)
|
||||
|
||||
file_name = "while_inline_net"
|
||||
export(network, x, y, file_name=file_name, file_format='MINDIR')
|
||||
mindir_name = file_name + ".mindir"
|
||||
assert os.path.exists(mindir_name)
|
||||
|
||||
graph = load(mindir_name)
|
||||
loaded_net = nn.GraphCell(graph)
|
||||
outputs_after_load = loaded_net(x, y)
|
||||
assert origin_out == outputs_after_load
|
Loading…
Reference in New Issue