!22344 fix compile error when weight is used in while condition

Merge pull request !22344 from xychow/fix-while-loop-with-weight-in-condition
This commit is contained in:
i-robot 2021-08-26 03:00:38 +00:00 committed by Gitee
commit 07eaa1969b
2 changed files with 18 additions and 0 deletions

View File

@ -138,6 +138,7 @@ class ParameterEliminator {
}
TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
auto new_caller = caller->func_graph()->NewCNode(new_args);
new_caller->set_abstract(caller->abstract());
tr->Replace(caller, new_caller);
}
};

View File

@ -750,6 +750,23 @@ def test_while_scalar():
out = net(x, y)
def test_while_with_weight_in_condition():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.loop = Parameter(Tensor(1, dtype=ms.float32), name="loop")
def construct(self, x):
while self.loop < 5:
self.loop += 1
x += 1
return x
net = Net()
x = Tensor(-1, dtype=ms.float32)
grad_all(net)(x)
def test_mixed_precision_cast():
x = Tensor(np.ones([2, 3], dtype=np.float32))
z = F.mixed_precision_cast(mstype.float16, x)