Bypass specializing the FuncGraph which is a input of a Partial CNode and it been specialized in ProcessCNode in FirstPass

This commit is contained in:
zhousiyi 2022-01-19 02:14:04 +00:00
parent 377d6108f9
commit 784d4b4315
2 changed files with 48 additions and 3 deletions

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -58,7 +58,18 @@ const StringImmPtr kDeadNode = std::make_shared<StringImm>(kDeadNodeName);
const StringImmPtr kPolyNode = std::make_shared<StringImm>(kPolyNodeName);
inline bool CanSpecializeValueNode(const AnfNodePtr &node) {
if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
if (IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
return true;
}
if (IsValueNode<FuncGraph>(node)) {
if (node->abstract() != nullptr) {
auto abs_func = node->abstract()->cast<FuncGraphAbstractClosurePtr>();
// If this funcgraph had specialized in ProcessCNode of FirstPass,
// then ignore it.
if (abs_func != nullptr && abs_func->specialized()) {
return false;
}
}
return true;
}
return false;
@ -410,6 +421,8 @@ void FuncGraphSpecializer::FirstPass() {
parent->FirstPass();
AnfNodePtr new_node = parent->GetReplicatedNode(node);
if (new_node->isa<CNode>()) {
MS_LOG(INFO) << "ProcessCNode in FirstPass for " << func_graph_->ToString() << ", node: " << node->DebugString()
<< ", new_node: " << new_node->DebugString();
parent->ProcessCNode(new_node->cast<CNodePtr>());
}
continue;

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -95,3 +95,35 @@ def test_single_if_01():
expect1 = Tensor(26, mstype.int32)
expect2 = (Tensor(2, mstype.int32), Tensor(2, mstype.int32))
control_flow_single_if(SingleIfNet1, x, y, expect1, expect2)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_single_if_any():
"""
Feature: compile and run control flow with if statement
Description: true-branch func graph refer a CNode in construct as free variable.
That CNode and the inputs will be specialized before ProcessCNode
of true-branch func graph, so it's no need to specialize the inputs
of that CNode again if it's a specialized func graph.
Expectation: success
"""
x = Tensor([True, True, False])
y = Tensor([False])
class Net(nn.Cell):
def __init__(self, input1, input2):
super().__init__()
self.input1 = input1
self.input2 = input2
def construct(self):
if self.input1.all() == self.input2:
return self.input1.any()
return self.input2
context.set_context(mode=context.GRAPH_MODE)
net = Net(x, y)
output = net()
assert output