fix while loop compile error when weight is used in while condition

This commit is contained in:
zhousiyi 2021-08-25 01:33:34 +00:00
parent 114a82e258
commit 44bdcb101c
2 changed files with 18 additions and 0 deletions

View File

@ -138,6 +138,7 @@ class ParameterEliminator {
} }
TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info())); TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
auto new_caller = caller->func_graph()->NewCNode(new_args); auto new_caller = caller->func_graph()->NewCNode(new_args);
new_caller->set_abstract(caller->abstract());
tr->Replace(caller, new_caller); tr->Replace(caller, new_caller);
} }
}; };

View File

@ -750,6 +750,23 @@ def test_while_scalar():
out = net(x, y) out = net(x, y)
def test_while_with_weight_in_condition():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.loop = Parameter(Tensor(1, dtype=ms.float32), name="loop")
def construct(self, x):
while self.loop < 5:
self.loop += 1
x += 1
return x
net = Net()
x = Tensor(-1, dtype=ms.float32)
grad_all(net)(x)
def test_mixed_precision_cast(): def test_mixed_precision_cast():
x = Tensor(np.ones([2, 3], dtype=np.float32)) x = Tensor(np.ones([2, 3], dtype=np.float32))
z = F.mixed_precision_cast(mstype.float16, x) z = F.mixed_precision_cast(mstype.float16, x)