forked from mindspore-Ecosystem/mindspore
!41391 fix ms_function bug
Merge pull request !41391 from luochao60/feat_fix_msfunction_bug_20220903
This commit is contained in:
commit
10ee9cef2b
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
"""Basic composite operations."""
|
"""Basic composite operations."""
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from types import FunctionType
|
from types import FunctionType, MethodType
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
@ -410,7 +410,7 @@ class GradOperation(GradOperation_):
|
||||||
else:
|
else:
|
||||||
new_kwargs = kwargs.copy()
|
new_kwargs = kwargs.copy()
|
||||||
new_kwargs.pop('sens')
|
new_kwargs.pop('sens')
|
||||||
if isinstance(fn, FunctionType):
|
if isinstance(fn, (FunctionType, MethodType)):
|
||||||
if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs):
|
if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs):
|
||||||
_pynative_executor.set_grad_flag(True)
|
_pynative_executor.set_grad_flag(True)
|
||||||
_pynative_executor.new_graph(fn, *args, **new_kwargs)
|
_pynative_executor.new_graph(fn, *args, **new_kwargs)
|
||||||
|
@ -570,7 +570,7 @@ class _Grad(GradOperation_):
|
||||||
new_kwargs.pop('sens')
|
new_kwargs.pop('sens')
|
||||||
else:
|
else:
|
||||||
args = args[:-1]
|
args = args[:-1]
|
||||||
if isinstance(fn, FunctionType):
|
if isinstance(fn, (FunctionType, MethodType)):
|
||||||
if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
|
if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs):
|
||||||
_pynative_executor.set_grad_flag(True)
|
_pynative_executor.set_grad_flag(True)
|
||||||
_pynative_executor.new_graph(fn, *args, **new_kwargs)
|
_pynative_executor.new_graph(fn, *args, **new_kwargs)
|
||||||
|
|
Loading…
Reference in New Issue