forked from mindspore-Ecosystem/mindspore
fix while loop compile error when weight is used in while condition
This commit is contained in:
parent
114a82e258
commit
44bdcb101c
|
@ -138,6 +138,7 @@ class ParameterEliminator {
|
|||
}
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
|
||||
auto new_caller = caller->func_graph()->NewCNode(new_args);
|
||||
new_caller->set_abstract(caller->abstract());
|
||||
tr->Replace(caller, new_caller);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -750,6 +750,23 @@ def test_while_scalar():
|
|||
out = net(x, y)
|
||||
|
||||
|
||||
def test_while_with_weight_in_condition():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.loop = Parameter(Tensor(1, dtype=ms.float32), name="loop")
|
||||
|
||||
def construct(self, x):
|
||||
while self.loop < 5:
|
||||
self.loop += 1
|
||||
x += 1
|
||||
return x
|
||||
|
||||
net = Net()
|
||||
x = Tensor(-1, dtype=ms.float32)
|
||||
grad_all(net)(x)
|
||||
|
||||
|
||||
def test_mixed_precision_cast():
|
||||
x = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||
z = F.mixed_precision_cast(mstype.float16, x)
|
||||
|
|
Loading…
Reference in New Issue