forked from mindspore-Ecosystem/mindspore
!41391 fix ms_function bug
Merge pull request !41391 from luochao60/feat_fix_msfunction_bug_20220903
This commit is contained in:
commit
10ee9cef2b
|
@ -17,7 +17,7 @@
|
|||
|
||||
"""Basic composite operations."""
|
||||
from functools import partial
|
||||
from types import FunctionType
|
||||
from types import FunctionType, MethodType
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
|
@ -410,7 +410,7 @@ class GradOperation(GradOperation_):
|
|||
else:
|
||||
new_kwargs = kwargs.copy()
|
||||
new_kwargs.pop('sens')
|
||||
if isinstance(fn, FunctionType):
|
||||
if isinstance(fn, (FunctionType, MethodType)):
|
||||
if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs):
|
||||
_pynative_executor.set_grad_flag(True)
|
||||
_pynative_executor.new_graph(fn, *args, **new_kwargs)
|
||||
|
@ -570,7 +570,7 @@ class _Grad(GradOperation_):
|
|||
new_kwargs.pop('sens')
|
||||
else:
|
||||
args = args[:-1]
|
||||
if isinstance(fn, FunctionType):
|
||||
if isinstance(fn, (FunctionType, MethodType)):
|
||||
if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
|
||||
_pynative_executor.set_grad_flag(True)
|
||||
_pynative_executor.new_graph(fn, *args, **new_kwargs)
|
||||
|
|
Loading…
Reference in New Issue