!41391 fix ms_function bug

Merge pull request !41391 from luochao60/feat_fix_msfunction_bug_20220903
This commit is contained in:
i-robot 2022-09-19 17:16:31 +00:00 committed by Gitee
commit 10ee9cef2b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 3 additions and 3 deletions

View File

@ -17,7 +17,7 @@
"""Basic composite operations.""" """Basic composite operations."""
from functools import partial from functools import partial
from types import FunctionType from types import FunctionType, MethodType
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
@ -410,7 +410,7 @@ class GradOperation(GradOperation_):
else: else:
new_kwargs = kwargs.copy() new_kwargs = kwargs.copy()
new_kwargs.pop('sens') new_kwargs.pop('sens')
if isinstance(fn, FunctionType): if isinstance(fn, (FunctionType, MethodType)):
if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs): if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs):
_pynative_executor.set_grad_flag(True) _pynative_executor.set_grad_flag(True)
_pynative_executor.new_graph(fn, *args, **new_kwargs) _pynative_executor.new_graph(fn, *args, **new_kwargs)
@ -570,7 +570,7 @@ class _Grad(GradOperation_):
new_kwargs.pop('sens') new_kwargs.pop('sens')
else: else:
args = args[:-1] args = args[:-1]
if isinstance(fn, FunctionType): if isinstance(fn, (FunctionType, MethodType)):
if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs): if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
_pynative_executor.set_grad_flag(True) _pynative_executor.set_grad_flag(True)
_pynative_executor.new_graph(fn, *args, **new_kwargs) _pynative_executor.new_graph(fn, *args, **new_kwargs)