!32819 fix partial_eliminate recursive issue

Merge pull request !32819 from lanzhineng/func_closure
This commit is contained in:
i-robot 2022-04-12 01:23:11 +00:00 committed by Gitee
commit 0a1913b370
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 31 additions and 1 deletions

View File

@ -162,7 +162,8 @@ class ChoicePartialEliminater : public AnfVisitor {
MS_EXCEPTION_IF_NULL(fg);
if (fg->func_graph_cnodes_index().size() != 1) {
// If a graph is used by 2 or more partial nodes at the same time, clone the graph.
auto new_fg = BasicClone(fg);
// BasicClone should be replaced by TransformableClone to avoid recursive.
auto new_fg = TransformableClone(fg);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(new_fg);

View File

@ -477,3 +477,32 @@ def test_continue_stuck_in_vm():
grad_net = Grad(net)
grad = grad_net(x, y)
print(grad)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_partial_eliminate_while_for_if_break():
"""
Feature: nest control flow.
Description: nest control flow with while,for,if and break.
Expectation: Null.
"""
class NetWork(nn.Cell):
def construct(self, x):
while x < 3:
for _ in range(2):
if x <= 4:
x = x + 1
break
x = 1 + x
return x
x = np.array([0], np.float32)
net = NetWork()
grad_net = Grad(net)
grad = grad_net(Tensor(x))
print(grad)