add_warning_info_for_cell_id_in_hookbackward_op

This commit is contained in:
7347157+joylvliang@user.noreply.gitee.com 2022-02-26 17:58:58 +08:00
parent 8b547ef653
commit 26a232e8f2
2 changed files with 45 additions and 36 deletions

View File

@ -1500,10 +1500,10 @@ class Cell(Cell_):
def run_forward_pre_hook(self, inputs):
"""
Running forward pre hook function registered in Cell object.
Running forward pre hook function registered on cell object.
Args:
inputs: The input objects of Cell object.
inputs: The input objects of cell object.
Returns:
- **outputs** - New input objects or none.
@ -1523,16 +1523,17 @@ class Cell(Cell_):
def register_forward_pre_hook(self, hook_fn):
"""
Register forward pre hook function for Cell object. Note that this function is only supported in pynative mode.
Register forward pre hook function for cell object. Note that this function is only supported in pynative mode.
Note:
'hook_fn' must be defined as the following code.
`cell_id` is the information of registered Cell object. `inputs` is the forward input objects passed to
the Cell. The 'hook_fn' can modify the forward input objects by returning new forward input objects.
It should have the following signature:
hook_fn(cell_id, inputs) -> new input objects or none.
In order to prevent running failed when switching to graph mode, it is not recommended to write in the
construct.
- 'hook_fn' must be defined as the following code.
`cell_id` is the information of registered cell object, including name and ID. `inputs` is the forward
input objects passed to the cell. The 'hook_fn' can modify the forward input objects by returning new
forward input objects.
- It should have the following signature:
hook_fn(cell_id, inputs) -> new input objects or none.
- In order to prevent running failed when switching to graph mode, it is not recommended to write in the
construct.
Args:
hook_fn (function): Python function. Forward pre hook function.
@ -1598,11 +1599,11 @@ class Cell(Cell_):
def run_forward_hook(self, inputs, output):
"""
Running forward hook function registered in Cell object.
Running forward hook function registered on cell object.
Args:
inputs: The input objects of Cell object.
output: The output object of Cell object.
inputs: The input objects of cell object.
output: The output object of cell object.
Returns:
- **output** - New output object or none.
@ -1622,14 +1623,14 @@ class Cell(Cell_):
Set the cell forward hook function. Note that this function is only supported in pynative mode.
Note:
'hook_fn' must be defined as the following code.
`cell_id` is the information of registered Cell object. `inputs` is the forward input objects passed to
the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can modify the forward output
object by returning new forward output object.
It should have the following signature:
hook_fn(cell_id, inputs, output) -> new output object or none.
In order to prevent running failed when switching to graph mode, it is not recommended to write in the
construct.
- 'hook_fn' must be defined as the following code.
`cell_id` is the information of registered cell object, including name and ID. `inputs` is the forward
input objects passed to the cell. `output` is the forward output object of the cell. The 'hook_fn' can
modify the forward output object by returning new forward output object.
- It should have the following signature:
hook_fn(cell_id, inputs, output) -> new output object or none.
- In order to prevent running failed when switching to graph mode, it is not recommended to write in the
construct.
Args:
hook_fn (function): Python function. Forward hook function.
@ -1700,10 +1701,10 @@ class Cell(Cell_):
Backward hook construct method to replace original construct method.
Args:
inputs: The input objects of Cell object.
inputs: The input objects of cell object.
Returns:
- **outputs** - The output objects of Cell object.
- **outputs** - The output objects of cell object.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -1724,15 +1725,14 @@ class Cell(Cell_):
Register the backward hook function. Note that this function is only supported in pynative mode.
Note:
The 'hook_fn' must be defined as the following code.
`cell_id` is the information of registered cell. `grad_input` is the gradient passed to the cell.
`grad_output` is the gradient computed and passed to the next cell or primitive, which may be modified by
returning a new output gradient.
The 'hook_fn' should have the following signature:
hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
The 'hook_fn' is executed in the python environment.
In order to prevent running failed when switching to graph mode, it is not recommended to write in the
construct.
- The 'hook_fn' must be defined as the following code.
`cell_id` is the information of registered cell, including name and ID. `grad_input` is the gradient
passed to the cell. `grad_output` is the gradient computed and passed to the next cell or primitive,
which may be modified by returning a new output gradient.
- The 'hook_fn' should have the following signature:
hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
- The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to
graph mode, it is not recommended to write in the construct.
Args:
hook_fn (function): Python function. Backward hook function.

View File

@ -16,6 +16,7 @@
from types import FunctionType, MethodType
from mindspore import context
from mindspore import log as logger
from mindspore._c_expression import security
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
@ -334,6 +335,11 @@ class HookBackward(PrimitiveWithInfer):
Args:
hook_fn (Function): Python function. hook function.
cell_id (str): Used to identify whether the function registered by the hook is actually registered on
the specified cell object. For example, 'nn.Conv2d' is a cell object.
The default value of cell_id is empty string(""), in this case, the system will automatically
register a value of cell_id.
The value of cell_id currently does not support custom values.
Inputs:
- **input** (Tensor) - The variable to hook.
@ -375,15 +381,18 @@ class HookBackward(PrimitiveWithInfer):
(Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
"""
def __init__(self, hook_fn):
def __init__(self, hook_fn, cell_id=""):
"""Initialize HookBackward."""
super(HookBackward, self).__init__(self.__class__.__name__)
if not isinstance(hook_fn, (FunctionType, MethodType)):
raise TypeError(f"For '{self.name}', the type of 'hook_fn' should be python function, "
f"but got {type(hook_fn)}.")
self.add_prim_attr("cell_id", "")
self.init_attrs["cell_id"] = ""
self.cell_id = ""
if cell_id != "":
logger.warning(f"The args 'cell_id' of HookBackward will be removed in a future version. If the value of "
f"'cell_id' is set, the hook function will not work.")
self.add_prim_attr("cell_id", cell_id)
self.init_attrs["cell_id"] = cell_id
self.cell_id = cell_id
self.add_backward_hook_fn(hook_fn)
def infer_shape(self, *inputs_shape):