Fix the loss of const_arg attr when using ms_function
This commit is contained in:
parent
66ba9e952b
commit
af60c8833f
|
@ -217,7 +217,8 @@ def _restore_mutable_attr(args_list, compile_args):
|
|||
"""Restore the mutable attr for every arg."""
|
||||
new_compile_args = ()
|
||||
for idx, arg in enumerate(args_list):
|
||||
if hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__"):
|
||||
if hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__") and \
|
||||
not (hasattr(arg, "const_arg") and getattr(arg, "const_arg")):
|
||||
new_compile_args += (mutable(compile_args[idx]),)
|
||||
else:
|
||||
new_compile_args += (compile_args[idx],)
|
||||
|
|
|
@ -212,7 +212,7 @@ def test_grad_constant_tensor():
|
|||
def test_grad_constant_tensor_mixed_call():
|
||||
"""
|
||||
Feature: Set mutable tensor input to constant.
|
||||
Description: Get gradient with respect to the constant tensor input.
|
||||
Description: Get gradient with respect to the constant tensor input for mixed call of mutable and const_arg.
|
||||
Expectation: Get an empty gradient.
|
||||
"""
|
||||
|
||||
|
@ -243,6 +243,10 @@ def test_grad_constant_tensor_mixed_call():
|
|||
output = grad_net(x, y)
|
||||
assert isinstance(output, tuple)
|
||||
assert output == ()
|
||||
grad_net = GradOperation()(Net())
|
||||
output = grad_net(x, y)
|
||||
assert isinstance(output, tuple)
|
||||
assert output == ()
|
||||
|
||||
|
||||
def test_ms_function_grad_constant_tensor():
|
||||
|
|
Loading…
Reference in New Issue