From af60c8833f135f613758585d79920a24a0438b4b Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Fri, 16 Sep 2022 14:35:38 +0800 Subject: [PATCH] Fix the loss of const_arg attr when using ms_function --- mindspore/python/mindspore/common/api.py | 3 ++- tests/ut/python/ir/test_const_arg_tensor.py | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mindspore/python/mindspore/common/api.py b/mindspore/python/mindspore/common/api.py index 25772bf9843..8e242eb1b2a 100644 --- a/mindspore/python/mindspore/common/api.py +++ b/mindspore/python/mindspore/common/api.py @@ -217,7 +217,8 @@ def _restore_mutable_attr(args_list, compile_args): """Restore the mutable attr for every arg.""" new_compile_args = () for idx, arg in enumerate(args_list): - if hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__"): + if hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__") and \ + not (hasattr(arg, "const_arg") and getattr(arg, "const_arg")): new_compile_args += (mutable(compile_args[idx]),) else: new_compile_args += (compile_args[idx],) diff --git a/tests/ut/python/ir/test_const_arg_tensor.py b/tests/ut/python/ir/test_const_arg_tensor.py index 8d29d7b8ae3..bc5d2c554d5 100644 --- a/tests/ut/python/ir/test_const_arg_tensor.py +++ b/tests/ut/python/ir/test_const_arg_tensor.py @@ -212,7 +212,7 @@ def test_grad_constant_tensor(): def test_grad_constant_tensor_mixed_call(): """ Feature: Set mutable tensor input to constant. - Description: Get gradient with respect to the constant tensor input. + Description: Get gradient with respect to the constant tensor input for mixed call of mutable and const_arg. Expectation: Get an empty gradient. """ @@ -243,6 +243,10 @@ def test_grad_constant_tensor_mixed_call(): output = grad_net(x, y) assert isinstance(output, tuple) assert output == () + grad_net = GradOperation()(Net()) + output = grad_net(x, y) + assert isinstance(output, tuple) + assert output == () def test_ms_function_grad_constant_tensor():