forked from mindspore-Ecosystem/mindspore
fix while loop compile error when weight is used in while condition
This commit is contained in:
parent
114a82e258
commit
44bdcb101c
|
@ -138,6 +138,7 @@ class ParameterEliminator {
|
||||||
}
|
}
|
||||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
|
TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
|
||||||
auto new_caller = caller->func_graph()->NewCNode(new_args);
|
auto new_caller = caller->func_graph()->NewCNode(new_args);
|
||||||
|
new_caller->set_abstract(caller->abstract());
|
||||||
tr->Replace(caller, new_caller);
|
tr->Replace(caller, new_caller);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -750,6 +750,23 @@ def test_while_scalar():
|
||||||
out = net(x, y)
|
out = net(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_while_with_weight_in_condition():
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.loop = Parameter(Tensor(1, dtype=ms.float32), name="loop")
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
while self.loop < 5:
|
||||||
|
self.loop += 1
|
||||||
|
x += 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
x = Tensor(-1, dtype=ms.float32)
|
||||||
|
grad_all(net)(x)
|
||||||
|
|
||||||
|
|
||||||
def test_mixed_precision_cast():
|
def test_mixed_precision_cast():
|
||||||
x = Tensor(np.ones([2, 3], dtype=np.float32))
|
x = Tensor(np.ones([2, 3], dtype=np.float32))
|
||||||
z = F.mixed_precision_cast(mstype.float16, x)
|
z = F.mixed_precision_cast(mstype.float16, x)
|
||||||
|
|
Loading…
Reference in New Issue