forked from mindspore-Ecosystem/mindspore
search_hook_for_all_sub_cell
This commit is contained in:
parent
a6454e02e4
commit
a3f967dc33
|
@ -11,6 +11,7 @@
|
|||
"mindspore/mindspore/python/mindspore/_check_deps_version.py" "broad-except"
|
||||
"mindspore/mindspore/python/mindspore/_check_version.py" "unused-import"
|
||||
"mindspore/mindspore/python/mindspore/_check_version.py" "broad-except"
|
||||
"mindspore/mindspore/python/mindspore/common/api.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/common/parameter.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/common/hook_handle.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/common/dtype.py" "undefined-all-variable"
|
||||
|
|
|
@ -268,6 +268,12 @@ class _MindsporeFunctionExecutor:
|
|||
|
||||
def compile(self, args_list, method_name):
|
||||
"""Returns pipeline for the given args."""
|
||||
# Check whether hook function registered on Cell object.
|
||||
if self.obj and hasattr(self.obj, "_hook_fn_registered"):
|
||||
if self.obj._hook_fn_registered():
|
||||
logger.warning(f"For 'Cell', it's not support hook function when using ms_function. If you want to "
|
||||
f"use hook function, please use context.set_context to set pynative mode and remove "
|
||||
f"`ms_function`.")
|
||||
# Verify the signature for both function and method
|
||||
if self.input_signature is not None:
|
||||
signatures = []
|
||||
|
|
|
@ -439,6 +439,14 @@ class Cell(Cell_):
|
|||
f"{default_args} default argument, total {positional_args + default_args}, "
|
||||
f"but got {len(inputs)}.")
|
||||
|
||||
def _hook_fn_registered(self):
|
||||
if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
|
||||
return True
|
||||
for cell in self.cells():
|
||||
if cell._hook_fn_registered():
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_prims_recursively(self):
|
||||
all_prims = list()
|
||||
for _, value in self._primitives.items():
|
||||
|
@ -555,9 +563,9 @@ class Cell(Cell_):
|
|||
# Run in Graph mode.
|
||||
if context._get_mode() == context.GRAPH_MODE:
|
||||
self._check_construct_args(*args, **kwargs)
|
||||
if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
|
||||
logger.warning(f"For 'Cell', it's not support hook function in graph mode, please use "
|
||||
f"context.set_context to set pynative mode.")
|
||||
if self._hook_fn_registered():
|
||||
logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
|
||||
f"function, please use context.set_context to set pynative mode.")
|
||||
out = self.compile_and_run(*args)
|
||||
return out
|
||||
|
||||
|
@ -1755,15 +1763,15 @@ class Cell(Cell_):
|
|||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
inputs = self._cell_backward_hook(*inputs)
|
||||
if isinstance(inputs, tuple):
|
||||
inputs = self.construct(*inputs)
|
||||
if len(inputs) > 1:
|
||||
inputs = self._cell_backward_hook(inputs)
|
||||
else:
|
||||
inputs = self.construct(inputs)
|
||||
inputs = self._cell_backward_hook(*inputs)
|
||||
if isinstance(inputs, tuple):
|
||||
outputs = self._cell_backward_hook(*inputs)
|
||||
outputs = self.construct(*inputs)
|
||||
else:
|
||||
outputs = self._cell_backward_hook(inputs)
|
||||
outputs = self.construct(inputs)
|
||||
outputs = self._cell_backward_hook(outputs)
|
||||
return outputs
|
||||
|
||||
def register_backward_hook(self, hook_fn):
|
||||
|
|
|
@ -20,11 +20,12 @@ import numpy as np
|
|||
from mindspore.common import Tensor
|
||||
from .. import signature as sig
|
||||
from ..operations.math_ops import _infer_shape_reduce
|
||||
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive
|
||||
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, _run_op
|
||||
from ... import context
|
||||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from ...common.parameter import Parameter
|
||||
from ...communication.management import GlobalComm
|
||||
|
||||
|
||||
|
@ -1764,6 +1765,14 @@ class CellBackwardHook(PrimitiveWithInfer):
|
|||
self.add_prim_attr("cell_id", cell_id)
|
||||
self.init_attrs["cell_id"] = cell_id
|
||||
|
||||
def __call__(self, args):
|
||||
if not isinstance(args, tuple):
|
||||
args = (args,)
|
||||
for arg in args:
|
||||
if isinstance(arg, Parameter) and arg.has_init:
|
||||
arg.init_data()
|
||||
return _run_op(self, self.name, args)
|
||||
|
||||
def infer_shape(self, *inputs_shape):
|
||||
if len(inputs_shape) == 1:
|
||||
return inputs_shape[0]
|
||||
|
|
|
@ -247,7 +247,8 @@ class Primitive(Primitive_):
|
|||
return super().get_attr_dict()[item]
|
||||
if item in self.attrs:
|
||||
return self.attrs[item]
|
||||
raise AttributeError(item)
|
||||
err_msg = "'{prim}' object has no attribute '{attr}'".format(prim=self.name, attr=item)
|
||||
raise AttributeError(err_msg)
|
||||
|
||||
def check_elim(self, *args):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue