diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index e222706e1d0..807b01d9646 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -1500,10 +1500,10 @@ class Cell(Cell_): def run_forward_pre_hook(self, inputs): """ - Running forward pre hook function registered in Cell object. + Running forward pre hook function registered on cell object. Args: - inputs: The input objects of Cell object. + inputs: The input objects of cell object. Returns: - **outputs** - New input objects or none. @@ -1523,16 +1523,17 @@ class Cell(Cell_): def register_forward_pre_hook(self, hook_fn): """ - Register forward pre hook function for Cell object. Note that this function is only supported in pynative mode. + Register forward pre hook function for cell object. Note that this function is only supported in pynative mode. Note: - 'hook_fn' must be defined as the following code. - `cell_id` is the information of registered Cell object. `inputs` is the forward input objects passed to - the Cell. The 'hook_fn' can modify the forward input objects by returning new forward input objects. - It should have the following signature: - hook_fn(cell_id, inputs) -> new input objects or none. - In order to prevent running failed when switching to graph mode, it is not recommended to write in the - construct. + - 'hook_fn' must be defined as the following code. + `cell_id` is the information of registered cell object, including name and ID. `inputs` is the forward + input objects passed to the cell. The 'hook_fn' can modify the forward input objects by returning new + forward input objects. + - It should have the following signature: + hook_fn(cell_id, inputs) -> new input objects or none. + - In order to prevent running failed when switching to graph mode, it is not recommended to write in the + construct. Args: hook_fn (function): Python function. Forward pre hook function. @@ -1598,11 +1599,11 @@ class Cell(Cell_): def run_forward_hook(self, inputs, output): """ - Running forward hook function registered in Cell object. + Running forward hook function registered on cell object. Args: - inputs: The input objects of Cell object. - output: The output object of Cell object. + inputs: The input objects of cell object. + output: The output object of cell object. Returns: - **output** - New output object or none. @@ -1622,14 +1623,14 @@ class Cell(Cell_): Set the cell forward hook function. Note that this function is only supported in pynative mode. Note: - 'hook_fn' must be defined as the following code. - `cell_id` is the information of registered Cell object. `inputs` is the forward input objects passed to - the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can modify the forward output - object by returning new forward output object. - It should have the following signature: - hook_fn(cell_id, inputs, output) -> new output object or none. - In order to prevent running failed when switching to graph mode, it is not recommended to write in the - construct. + - 'hook_fn' must be defined as the following code. + `cell_id` is the information of registered cell object, including name and ID. `inputs` is the forward + input objects passed to the cell. `output` is the forward output object of the cell. The 'hook_fn' can + modify the forward output object by returning new forward output object. + - It should have the following signature: + hook_fn(cell_id, inputs, output) -> new output object or none. + - In order to prevent running failed when switching to graph mode, it is not recommended to write in the + construct. Args: hook_fn (function): Python function. Forward hook function. @@ -1700,10 +1701,10 @@ class Cell(Cell_): Backward hook construct method to replace original construct method. Args: - inputs: The input objects of Cell object. + inputs: The input objects of cell object. Returns: - - **outputs** - The output objects of Cell object. + - **outputs** - The output objects of cell object. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -1724,15 +1725,14 @@ class Cell(Cell_): Register the backward hook function. Note that this function is only supported in pynative mode. Note: - The 'hook_fn' must be defined as the following code. - `cell_id` is the information of registered cell. `grad_input` is the gradient passed to the cell. - `grad_output` is the gradient computed and passed to the next cell or primitive, which may be modified by - returning a new output gradient. - The 'hook_fn' should have the following signature: - hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none. - The 'hook_fn' is executed in the python environment. - In order to prevent running failed when switching to graph mode, it is not recommended to write in the - construct. + - The 'hook_fn' must be defined as the following code. + `cell_id` is the information of registered cell, including name and ID. `grad_input` is the gradient + passed to the cell. `grad_output` is the gradient computed and passed to the next cell or primitive, + which may be modified by returning a new output gradient. + - The 'hook_fn' should have the following signature: + hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none. + - The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to + graph mode, it is not recommended to write in the construct. Args: hook_fn (function): Python function. Backward hook function. diff --git a/mindspore/python/mindspore/ops/operations/debug_ops.py b/mindspore/python/mindspore/ops/operations/debug_ops.py index 49d85730b57..ac338b735e3 100644 --- a/mindspore/python/mindspore/ops/operations/debug_ops.py +++ b/mindspore/python/mindspore/ops/operations/debug_ops.py @@ -16,6 +16,7 @@ from types import FunctionType, MethodType from mindspore import context +from mindspore import log as logger from mindspore._c_expression import security from ..._checkparam import Validator as validator from ..._checkparam import Rel @@ -334,6 +335,11 @@ class HookBackward(PrimitiveWithInfer): Args: hook_fn (Function): Python function. hook function. + cell_id (str): Used to identify whether the function registered by the hook is actually registered on + the specified cell object. For example, 'nn.Conv2d' is a cell object. + The default value of cell_id is empty string(""), in this case, the system will automatically + register a value of cell_id. + The value of cell_id currently does not support custom values. Inputs: - **input** (Tensor) - The variable to hook. @@ -375,15 +381,18 @@ class HookBackward(PrimitiveWithInfer): (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4)) """ - def __init__(self, hook_fn): + def __init__(self, hook_fn, cell_id=""): """Initialize HookBackward.""" super(HookBackward, self).__init__(self.__class__.__name__) if not isinstance(hook_fn, (FunctionType, MethodType)): raise TypeError(f"For '{self.name}', the type of 'hook_fn' should be python function, " f"but got {type(hook_fn)}.") - self.add_prim_attr("cell_id", "") - self.init_attrs["cell_id"] = "" - self.cell_id = "" + if cell_id != "": + logger.warning(f"The args 'cell_id' of HookBackward will be removed in a future version. If the value of " + f"'cell_id' is set, the hook function will not work.") + self.add_prim_attr("cell_id", cell_id) + self.init_attrs["cell_id"] = cell_id + self.cell_id = cell_id self.add_backward_hook_fn(hook_fn) def infer_shape(self, *inputs_shape):