diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h
index 01e417a1610..55d9a22e0eb 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h
@@ -138,6 +138,7 @@ class ParameterEliminator {
     }
     TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
     auto new_caller = caller->func_graph()->NewCNode(new_args);
+    new_caller->set_abstract(caller->abstract());
     tr->Replace(caller, new_caller);
   }
 };
diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py
index 880698e4980..5d1ecbae560 100644
--- a/tests/ut/python/ops/test_control_ops.py
+++ b/tests/ut/python/ops/test_control_ops.py
@@ -750,6 +750,23 @@ def test_while_scalar():
     out = net(x, y)
 
 
+def test_while_with_weight_in_condition():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.loop = Parameter(Tensor(1, dtype=ms.float32), name="loop")
+
+        def construct(self, x):
+            while self.loop < 5:
+                self.loop += 1
+                x += 1
+            return x
+
+    net = Net()
+    x = Tensor(-1, dtype=ms.float32)
+    grad_all(net)(x)
+
+
 def test_mixed_precision_cast():
     x = Tensor(np.ones([2, 3], dtype=np.float32))
     z = F.mixed_precision_cast(mstype.float16, x)