search_hook_for_all_sub_cell

This commit is contained in:
7347157+joylvliang@user.noreply.gitee.com 2022-03-12 15:23:49 +08:00
parent a6454e02e4
commit a3f967dc33
5 changed files with 36 additions and 11 deletions

View File

@ -11,6 +11,7 @@
"mindspore/mindspore/python/mindspore/_check_deps_version.py" "broad-except" "mindspore/mindspore/python/mindspore/_check_deps_version.py" "broad-except"
"mindspore/mindspore/python/mindspore/_check_version.py" "unused-import" "mindspore/mindspore/python/mindspore/_check_version.py" "unused-import"
"mindspore/mindspore/python/mindspore/_check_version.py" "broad-except" "mindspore/mindspore/python/mindspore/_check_version.py" "broad-except"
"mindspore/mindspore/python/mindspore/common/api.py" "protected-access"
"mindspore/mindspore/python/mindspore/common/parameter.py" "protected-access" "mindspore/mindspore/python/mindspore/common/parameter.py" "protected-access"
"mindspore/mindspore/python/mindspore/common/hook_handle.py" "protected-access" "mindspore/mindspore/python/mindspore/common/hook_handle.py" "protected-access"
"mindspore/mindspore/python/mindspore/common/dtype.py" "undefined-all-variable" "mindspore/mindspore/python/mindspore/common/dtype.py" "undefined-all-variable"

View File

@ -268,6 +268,12 @@ class _MindsporeFunctionExecutor:
def compile(self, args_list, method_name): def compile(self, args_list, method_name):
"""Returns pipeline for the given args.""" """Returns pipeline for the given args."""
# Check whether hook function registered on Cell object.
if self.obj and hasattr(self.obj, "_hook_fn_registered"):
if self.obj._hook_fn_registered():
logger.warning(f"For 'Cell', it's not support hook function when using ms_function. If you want to "
f"use hook function, please use context.set_context to set pynative mode and remove "
f"`ms_function`.")
# Verify the signature for both function and method # Verify the signature for both function and method
if self.input_signature is not None: if self.input_signature is not None:
signatures = [] signatures = []

View File

@ -439,6 +439,14 @@ class Cell(Cell_):
f"{default_args} default argument, total {positional_args + default_args}, " f"{default_args} default argument, total {positional_args + default_args}, "
f"but got {len(inputs)}.") f"but got {len(inputs)}.")
def _hook_fn_registered(self):
if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
return True
for cell in self.cells():
if cell._hook_fn_registered():
return True
return False
def _get_prims_recursively(self): def _get_prims_recursively(self):
all_prims = list() all_prims = list()
for _, value in self._primitives.items(): for _, value in self._primitives.items():
@ -555,9 +563,9 @@ class Cell(Cell_):
# Run in Graph mode. # Run in Graph mode.
if context._get_mode() == context.GRAPH_MODE: if context._get_mode() == context.GRAPH_MODE:
self._check_construct_args(*args, **kwargs) self._check_construct_args(*args, **kwargs)
if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook: if self._hook_fn_registered():
logger.warning(f"For 'Cell', it's not support hook function in graph mode, please use " logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
f"context.set_context to set pynative mode.") f"function, please use context.set_context to set pynative mode.")
out = self.compile_and_run(*args) out = self.compile_and_run(*args)
return out return out
@ -1755,15 +1763,15 @@ class Cell(Cell_):
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
""" """
inputs = self._cell_backward_hook(*inputs) if len(inputs) > 1:
if isinstance(inputs, tuple): inputs = self._cell_backward_hook(inputs)
inputs = self.construct(*inputs)
else: else:
inputs = self.construct(inputs) inputs = self._cell_backward_hook(*inputs)
if isinstance(inputs, tuple): if isinstance(inputs, tuple):
outputs = self._cell_backward_hook(*inputs) outputs = self.construct(*inputs)
else: else:
outputs = self._cell_backward_hook(inputs) outputs = self.construct(inputs)
outputs = self._cell_backward_hook(outputs)
return outputs return outputs
def register_backward_hook(self, hook_fn): def register_backward_hook(self, hook_fn):

View File

@ -20,11 +20,12 @@ import numpy as np
from mindspore.common import Tensor from mindspore.common import Tensor
from .. import signature as sig from .. import signature as sig
from ..operations.math_ops import _infer_shape_reduce from ..operations.math_ops import _infer_shape_reduce
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, _run_op
from ... import context from ... import context
from ..._checkparam import Rel from ..._checkparam import Rel
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.parameter import Parameter
from ...communication.management import GlobalComm from ...communication.management import GlobalComm
@ -1764,6 +1765,14 @@ class CellBackwardHook(PrimitiveWithInfer):
self.add_prim_attr("cell_id", cell_id) self.add_prim_attr("cell_id", cell_id)
self.init_attrs["cell_id"] = cell_id self.init_attrs["cell_id"] = cell_id
def __call__(self, args):
if not isinstance(args, tuple):
args = (args,)
for arg in args:
if isinstance(arg, Parameter) and arg.has_init:
arg.init_data()
return _run_op(self, self.name, args)
def infer_shape(self, *inputs_shape): def infer_shape(self, *inputs_shape):
if len(inputs_shape) == 1: if len(inputs_shape) == 1:
return inputs_shape[0] return inputs_shape[0]

View File

@ -247,7 +247,8 @@ class Primitive(Primitive_):
return super().get_attr_dict()[item] return super().get_attr_dict()[item]
if item in self.attrs: if item in self.attrs:
return self.attrs[item] return self.attrs[item]
raise AttributeError(item) err_msg = "'{prim}' object has no attribute '{attr}'".format(prim=self.name, attr=item)
raise AttributeError(err_msg)
def check_elim(self, *args): def check_elim(self, *args):
""" """