!32819 fix partial_eliminate recursive issue
Merge pull request !32819 from lanzhineng/func_closure
This commit is contained in:
commit
0a1913b370
|
@ -162,7 +162,8 @@ class ChoicePartialEliminater : public AnfVisitor {
|
|||
MS_EXCEPTION_IF_NULL(fg);
|
||||
if (fg->func_graph_cnodes_index().size() != 1) {
|
||||
// If a graph is used by 2 or more partial nodes at the same time, clone the graph.
|
||||
auto new_fg = BasicClone(fg);
|
||||
// BasicClone should be replaced by TransformableClone to avoid recursive.
|
||||
auto new_fg = TransformableClone(fg);
|
||||
auto manager = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->AddFuncGraph(new_fg);
|
||||
|
|
|
@ -477,3 +477,32 @@ def test_continue_stuck_in_vm():
|
|||
grad_net = Grad(net)
|
||||
grad = grad_net(x, y)
|
||||
print(grad)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_partial_eliminate_while_for_if_break():
|
||||
"""
|
||||
Feature: nest control flow.
|
||||
Description: nest control flow with while,for,if and break.
|
||||
Expectation: Null.
|
||||
"""
|
||||
|
||||
class NetWork(nn.Cell):
|
||||
def construct(self, x):
|
||||
while x < 3:
|
||||
for _ in range(2):
|
||||
if x <= 4:
|
||||
x = x + 1
|
||||
break
|
||||
x = 1 + x
|
||||
return x
|
||||
|
||||
x = np.array([0], np.float32)
|
||||
net = NetWork()
|
||||
grad_net = Grad(net)
|
||||
grad = grad_net(Tensor(x))
|
||||
print(grad)
|
||||
|
|
Loading…
Reference in New Issue