!41391 fix ms_function bug

Merge pull request !41391 from luochao60/feat_fix_msfunction_bug_20220903
This commit is contained in:
i-robot 2022-09-19 17:16:31 +00:00 committed by Gitee
commit 10ee9cef2b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 3 additions and 3 deletions

View File

@ -17,7 +17,7 @@
"""Basic composite operations."""
from functools import partial
from types import FunctionType
from types import FunctionType, MethodType
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
@ -410,7 +410,7 @@ class GradOperation(GradOperation_):
else:
new_kwargs = kwargs.copy()
new_kwargs.pop('sens')
if isinstance(fn, FunctionType):
if isinstance(fn, (FunctionType, MethodType)):
if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs):
_pynative_executor.set_grad_flag(True)
_pynative_executor.new_graph(fn, *args, **new_kwargs)
@ -570,7 +570,7 @@ class _Grad(GradOperation_):
new_kwargs.pop('sens')
else:
args = args[:-1]
if isinstance(fn, FunctionType):
if isinstance(fn, (FunctionType, MethodType)):
if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
_pynative_executor.set_grad_flag(True)
_pynative_executor.new_graph(fn, *args, **new_kwargs)