forked from mindspore-Ecosystem/mindspore
add_warning_info_for_cell_id_in_hookbackward_op
This commit is contained in:
parent
8b547ef653
commit
26a232e8f2
|
@ -1500,10 +1500,10 @@ class Cell(Cell_):
|
|||
|
||||
def run_forward_pre_hook(self, inputs):
|
||||
"""
|
||||
Running forward pre hook function registered in Cell object.
|
||||
Running forward pre hook function registered on cell object.
|
||||
|
||||
Args:
|
||||
inputs: The input objects of Cell object.
|
||||
inputs: The input objects of cell object.
|
||||
|
||||
Returns:
|
||||
- **outputs** - New input objects or none.
|
||||
|
@ -1523,16 +1523,17 @@ class Cell(Cell_):
|
|||
|
||||
def register_forward_pre_hook(self, hook_fn):
|
||||
"""
|
||||
Register forward pre hook function for Cell object. Note that this function is only supported in pynative mode.
|
||||
Register forward pre hook function for cell object. Note that this function is only supported in pynative mode.
|
||||
|
||||
Note:
|
||||
'hook_fn' must be defined as the following code.
|
||||
`cell_id` is the information of registered Cell object. `inputs` is the forward input objects passed to
|
||||
the Cell. The 'hook_fn' can modify the forward input objects by returning new forward input objects.
|
||||
It should have the following signature:
|
||||
hook_fn(cell_id, inputs) -> new input objects or none.
|
||||
In order to prevent running failed when switching to graph mode, it is not recommended to write in the
|
||||
construct.
|
||||
- 'hook_fn' must be defined as the following code.
|
||||
`cell_id` is the information of registered cell object, including name and ID. `inputs` is the forward
|
||||
input objects passed to the cell. The 'hook_fn' can modify the forward input objects by returning new
|
||||
forward input objects.
|
||||
- It should have the following signature:
|
||||
hook_fn(cell_id, inputs) -> new input objects or none.
|
||||
- In order to prevent running failed when switching to graph mode, it is not recommended to write in the
|
||||
construct.
|
||||
|
||||
Args:
|
||||
hook_fn (function): Python function. Forward pre hook function.
|
||||
|
@ -1598,11 +1599,11 @@ class Cell(Cell_):
|
|||
|
||||
def run_forward_hook(self, inputs, output):
|
||||
"""
|
||||
Running forward hook function registered in Cell object.
|
||||
Running forward hook function registered on cell object.
|
||||
|
||||
Args:
|
||||
inputs: The input objects of Cell object.
|
||||
output: The output object of Cell object.
|
||||
inputs: The input objects of cell object.
|
||||
output: The output object of cell object.
|
||||
|
||||
Returns:
|
||||
- **output** - New output object or none.
|
||||
|
@ -1622,14 +1623,14 @@ class Cell(Cell_):
|
|||
Set the cell forward hook function. Note that this function is only supported in pynative mode.
|
||||
|
||||
Note:
|
||||
'hook_fn' must be defined as the following code.
|
||||
`cell_id` is the information of registered Cell object. `inputs` is the forward input objects passed to
|
||||
the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can modify the forward output
|
||||
object by returning new forward output object.
|
||||
It should have the following signature:
|
||||
hook_fn(cell_id, inputs, output) -> new output object or none.
|
||||
In order to prevent running failed when switching to graph mode, it is not recommended to write in the
|
||||
construct.
|
||||
- 'hook_fn' must be defined as the following code.
|
||||
`cell_id` is the information of registered cell object, including name and ID. `inputs` is the forward
|
||||
input objects passed to the cell. `output` is the forward output object of the cell. The 'hook_fn' can
|
||||
modify the forward output object by returning new forward output object.
|
||||
- It should have the following signature:
|
||||
hook_fn(cell_id, inputs, output) -> new output object or none.
|
||||
- In order to prevent running failed when switching to graph mode, it is not recommended to write in the
|
||||
construct.
|
||||
|
||||
Args:
|
||||
hook_fn (function): Python function. Forward hook function.
|
||||
|
@ -1700,10 +1701,10 @@ class Cell(Cell_):
|
|||
Backward hook construct method to replace original construct method.
|
||||
|
||||
Args:
|
||||
inputs: The input objects of Cell object.
|
||||
inputs: The input objects of cell object.
|
||||
|
||||
Returns:
|
||||
- **outputs** - The output objects of Cell object.
|
||||
- **outputs** - The output objects of cell object.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
@ -1724,15 +1725,14 @@ class Cell(Cell_):
|
|||
Register the backward hook function. Note that this function is only supported in pynative mode.
|
||||
|
||||
Note:
|
||||
The 'hook_fn' must be defined as the following code.
|
||||
`cell_id` is the information of registered cell. `grad_input` is the gradient passed to the cell.
|
||||
`grad_output` is the gradient computed and passed to the next cell or primitive, which may be modified by
|
||||
returning a new output gradient.
|
||||
The 'hook_fn' should have the following signature:
|
||||
hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
|
||||
The 'hook_fn' is executed in the python environment.
|
||||
In order to prevent running failed when switching to graph mode, it is not recommended to write in the
|
||||
construct.
|
||||
- The 'hook_fn' must be defined as the following code.
|
||||
`cell_id` is the information of registered cell, including name and ID. `grad_input` is the gradient
|
||||
passed to the cell. `grad_output` is the gradient computed and passed to the next cell or primitive,
|
||||
which may be modified by returning a new output gradient.
|
||||
- The 'hook_fn' should have the following signature:
|
||||
hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
|
||||
- The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to
|
||||
graph mode, it is not recommended to write in the construct.
|
||||
|
||||
Args:
|
||||
hook_fn (function): Python function. Backward hook function.
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from types import FunctionType, MethodType
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore._c_expression import security
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
|
@ -334,6 +335,11 @@ class HookBackward(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
hook_fn (Function): Python function. hook function.
|
||||
cell_id (str): Used to identify whether the function registered by the hook is actually registered on
|
||||
the specified cell object. For example, 'nn.Conv2d' is a cell object.
|
||||
The default value of cell_id is empty string(""), in this case, the system will automatically
|
||||
register a value of cell_id.
|
||||
The value of cell_id currently does not support custom values.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The variable to hook.
|
||||
|
@ -375,15 +381,18 @@ class HookBackward(PrimitiveWithInfer):
|
|||
(Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
|
||||
"""
|
||||
|
||||
def __init__(self, hook_fn):
|
||||
def __init__(self, hook_fn, cell_id=""):
|
||||
"""Initialize HookBackward."""
|
||||
super(HookBackward, self).__init__(self.__class__.__name__)
|
||||
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
||||
raise TypeError(f"For '{self.name}', the type of 'hook_fn' should be python function, "
|
||||
f"but got {type(hook_fn)}.")
|
||||
self.add_prim_attr("cell_id", "")
|
||||
self.init_attrs["cell_id"] = ""
|
||||
self.cell_id = ""
|
||||
if cell_id != "":
|
||||
logger.warning(f"The args 'cell_id' of HookBackward will be removed in a future version. If the value of "
|
||||
f"'cell_id' is set, the hook function will not work.")
|
||||
self.add_prim_attr("cell_id", cell_id)
|
||||
self.init_attrs["cell_id"] = cell_id
|
||||
self.cell_id = cell_id
|
||||
self.add_backward_hook_fn(hook_fn)
|
||||
|
||||
def infer_shape(self, *inputs_shape):
|
||||
|
|
Loading…
Reference in New Issue