From ebbe807d65498f7afd0cd0694340884d0764a327 Mon Sep 17 00:00:00 2001 From: luochao Date: Sat, 3 Sep 2022 09:42:45 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=B1=BB=E5=86=85=E5=87=BD?= =?UTF-8?q?=E6=95=B0=E6=97=A0=E6=B3=95=E6=B1=82=E6=A2=AF=E5=BA=A6=E7=9A=84?= =?UTF-8?q?BUG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspore/python/mindspore/ops/composite/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/python/mindspore/ops/composite/base.py b/mindspore/python/mindspore/ops/composite/base.py index cab948f4ede..995dc0bd214 100644 --- a/mindspore/python/mindspore/ops/composite/base.py +++ b/mindspore/python/mindspore/ops/composite/base.py @@ -17,7 +17,7 @@ """Basic composite operations.""" from functools import partial -from types import FunctionType +from types import FunctionType, MethodType import mindspore as ms import mindspore.nn as nn from mindspore import context @@ -411,7 +411,7 @@ class GradOperation(GradOperation_): else: new_kwargs = kwargs.copy() new_kwargs.pop('sens') - if isinstance(fn, FunctionType): + if isinstance(fn, (FunctionType, MethodType)): if not _pynative_executor.check_run(grad, fn, self.weights_id, *args, **new_kwargs): _pynative_executor.set_grad_flag(True) _pynative_executor.new_graph(fn, *args, **new_kwargs) @@ -571,7 +571,7 @@ class _Grad(GradOperation_): new_kwargs.pop('sens') else: args = args[:-1] - if isinstance(fn, FunctionType): + if isinstance(fn, (FunctionType, MethodType)): if not _pynative_executor.check_run(grad, fn, self.grad_hash_id, *args, **new_kwargs): _pynative_executor.set_grad_flag(True) _pynative_executor.new_graph(fn, *args, **new_kwargs)