Fix the loss of const_arg attr when using ms_function

This commit is contained in:
yujianfeng 2022-09-16 14:35:38 +08:00
parent 66ba9e952b
commit af60c8833f
2 changed files with 7 additions and 2 deletions

View File

@ -217,7 +217,8 @@ def _restore_mutable_attr(args_list, compile_args):
"""Restore the mutable attr for every arg."""
new_compile_args = ()
for idx, arg in enumerate(args_list):
if hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__"):
if hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__") and \
not (hasattr(arg, "const_arg") and getattr(arg, "const_arg")):
new_compile_args += (mutable(compile_args[idx]),)
else:
new_compile_args += (compile_args[idx],)

View File

@ -212,7 +212,7 @@ def test_grad_constant_tensor():
def test_grad_constant_tensor_mixed_call():
"""
Feature: Set mutable tensor input to constant.
Description: Get gradient with respect to the constant tensor input.
Description: Get gradient with respect to the constant tensor input for mixed call of mutable and const_arg.
Expectation: Get an empty gradient.
"""
@ -243,6 +243,10 @@ def test_grad_constant_tensor_mixed_call():
output = grad_net(x, y)
assert isinstance(output, tuple)
assert output == ()
grad_net = GradOperation()(Net())
output = grad_net(x, y)
assert isinstance(output, tuple)
assert output == ()
def test_ms_function_grad_constant_tensor():