transfer_interface_to_inner_interface

This commit is contained in:
7347157+joylvliang@user.noreply.gitee.com 2022-03-04 10:56:37 +08:00
parent d2c23394d8
commit f4e9517220
4 changed files with 65 additions and 47 deletions

View File

@ -12,6 +12,7 @@
"mindspore/mindspore/python/mindspore/_check_version.py" "unused-import"
"mindspore/mindspore/python/mindspore/_check_version.py" "broad-except"
"mindspore/mindspore/python/mindspore/common/parameter.py" "protected-access"
"mindspore/mindspore/python/mindspore/common/hook_handle.py" "protected-access"
"mindspore/mindspore/python/mindspore/common/dtype.py" "undefined-all-variable"
"mindspore/mindspore/python/mindspore/context.py" "protected-access"
"mindspore/mindspore/python/mindspore/ops/operations" "super-init-not-called"
@ -28,6 +29,7 @@
"mindspore/mindspore/python/mindspore/nn/cell.py" "assignment-from-none"
"mindspore/mindspore/python/mindspore/_extends/parse/resources.py" "bad-whitespace"
"mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "broad-except"
"mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "protected-access"
"mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "eval-used"
"mindspore/mindspore/python/mindspore/nn/cell.py" "protected-access"
"mindspore/mindspore/python/mindspore/nn/optim/ftrl.py" "unused-import"

View File

@ -124,8 +124,8 @@ def get_parse_method_of_class(obj, parse_method=None):
if parse_method is not None:
method_name = parse_method
elif isinstance(obj, nn.Cell):
if obj.enable_backward_hook:
method_name = "run_backward_hook"
if obj._enable_backward_hook:
method_name = "_backward_hook_construct"
else:
method_name = "construct"
if method_name is not None:

View File

@ -19,8 +19,8 @@ from .api import _pynative_executor
class HookHandle:
r"""
It is the return object of Cell forward pre hook function, forward hook function and backward hook function.
It corresponds to the cell hook function and is used to remove the cell hook function by calling 'remove()'.
It is the return object of forward pre hook function, forward hook function and backward hook function of Cell
object. It corresponds to the cell hook function and is used to remove the cell hook function by calling 'remove()'.
Note:
It is only supported in pynative mode and works when registering or removing hook function for Cell object.
@ -29,12 +29,9 @@ class HookHandle:
hook_cell (Cell): The Cell object with hook function registered on. Default value: None.
hook_key (int): The key of cell hook function in dict. It is generated during cell hook function registration.
Default value: -1.
hook_type (str): The type of cell hook function: 'forward_pre_hook', 'forward_hook' or 'cell_backward_hook'.
hook_type (str): The type of cell hook function: '_forward_pre_hook', '_forward_hook' or '_cell_backward_hook'.
Default value: "".
Returns:
None.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@ -49,8 +46,8 @@ class HookHandle:
def remove(self):
"""
Remove the cell hook function, which corresponds to this 'HookHandle' object.
In order to prevent running failed when switching to graph mode, it is not recommended to write in the
construct.
In order to prevent running failed when switching to graph mode, it is not recommended to call the `remove()`
function in the construct function of Cell object.
Args:
None.
@ -95,14 +92,14 @@ class HookHandle:
"""
if self._hook_cell is not None:
hook_cell = self._hook_cell()
if self._hook_type == "forward_pre_hook" and self._hook_key in hook_cell.forward_pre_hook:
del hook_cell.forward_pre_hook[self._hook_key]
if self._hook_type == "_forward_pre_hook" and self._hook_key in hook_cell._forward_pre_hook:
del hook_cell._forward_pre_hook[self._hook_key]
_pynative_executor.set_hook_changed(hook_cell)
elif self._hook_type == "forward_hook" and self._hook_key in hook_cell.forward_hook:
del hook_cell.forward_hook[self._hook_key]
elif self._hook_type == "_forward_hook" and self._hook_key in hook_cell._forward_hook:
del hook_cell._forward_hook[self._hook_key]
_pynative_executor.set_hook_changed(hook_cell)
elif self._hook_type == "cell_backward_hook":
hook_cell.cell_backward_hook.remove_backward_hook(self._hook_key)
elif self._hook_type == "_cell_backward_hook":
hook_cell._cell_backward_hook.remove_backward_hook(self._hook_key)
def __del__(self):
self._hook_cell = None

View File

@ -84,8 +84,8 @@ class Cell(Cell_):
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
'_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode',
'forward_pre_hook', 'forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
'_bprop_debug', 'enable_backward_hook', 'cell_backward_hook', '_is_run', '_param_prefix',
'_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
'_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
'_attr_synced', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type']
def __init__(self, auto_prefix=True, flags=None):
@ -127,12 +127,12 @@ class Cell(Cell_):
if flags:
self.add_flags(**flags)
self._bprop_debug = False
self.forward_pre_hook = OrderedDict()
self.forward_hook = OrderedDict()
self._forward_pre_hook = OrderedDict()
self._forward_hook = OrderedDict()
self._enable_forward_pre_hook = False
self._enable_forward_hook = False
self.enable_backward_hook = False
self.cell_backward_hook = None
self._enable_backward_hook = False
self._cell_backward_hook = None
self.cell_type = None
self._auto_parallel_compile_and_run = False
self.cast = Cast()
@ -383,17 +383,36 @@ class Cell(Cell_):
self.parameter_broadcast_done = True
def run_construct(self, cast_inputs, kwargs):
"""
Run the construct function.
Note:
This function will be removed in a future version. It is not recommended to call this function.
Args:
cast_inputs (tuple): The input objects of Cell.
kwargs (dict): Provide keyword arguments.
Returns:
output, the output object of Cell.
"""
logger.warning(f"The 'run_construct' function of '{self.cls_name}' will be removed in a future version. "
f"Calling this function is not recommended.")
output = self._run_construct(cast_inputs, kwargs)
return output
def _run_construct(self, cast_inputs, kwargs):
"""Run the construct function"""
if self._enable_forward_pre_hook:
cast_inputs = self.run_forward_pre_hook(cast_inputs)
if self.enable_backward_hook:
output = self.run_backward_hook(*cast_inputs)
cast_inputs = self._run_forward_pre_hook(cast_inputs)
if self._enable_backward_hook:
output = self._backward_hook_construct(*cast_inputs)
elif hasattr(self, "_shard_fn"):
output = self._shard_fn(*cast_inputs, **kwargs)
else:
output = self.construct(*cast_inputs, **kwargs)
if self._enable_forward_hook:
output = self.run_forward_hook(cast_inputs, output)
output = self._run_forward_hook(cast_inputs, output)
return output
def _check_construct_args(self, *inputs, **kwargs):
@ -536,7 +555,7 @@ class Cell(Cell_):
# Run in Graph mode.
if context._get_mode() == context.GRAPH_MODE:
self._check_construct_args(*args, **kwargs)
if self._enable_forward_pre_hook or self._enable_forward_hook or self.enable_backward_hook:
if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
logger.warning(f"For 'Cell', it's not support hook function in graph mode, please use "
f"context.set_context to set pynative mode.")
out = self.compile_and_run(*args)
@ -562,7 +581,7 @@ class Cell(Cell_):
with self.CellGuard():
try:
output = self.run_construct(cast_inputs, kwargs)
output = self._run_construct(cast_inputs, kwargs)
except Exception as err:
_pynative_executor.clear_res()
raise err
@ -1525,7 +1544,7 @@ class Cell(Cell_):
self.add_flags(auto_parallel=True)
self._get_construct_inputs_number_and_name()
def run_forward_pre_hook(self, inputs):
def _run_forward_pre_hook(self, inputs):
"""
Running forward pre hook function registered on cell object.
@ -1539,7 +1558,7 @@ class Cell(Cell_):
``Ascend`` ``GPU`` ``CPU``
"""
cell_id = self.cls_name + "(" + str(id(self)) + ")"
for fn in self.forward_pre_hook.values():
for fn in self._forward_pre_hook.values():
ret = fn(cell_id, inputs)
if ret is not None:
if not isinstance(ret, tuple):
@ -1620,11 +1639,11 @@ class Cell(Cell_):
if not hasattr(self, '_forward_pre_hook_key'):
self._forward_pre_hook_key = -1
self._forward_pre_hook_key += 1
self.forward_pre_hook[self._forward_pre_hook_key] = hook_fn
handle = HookHandle(self, self._forward_pre_hook_key, "forward_pre_hook")
self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn
handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook")
return handle
def run_forward_hook(self, inputs, output):
def _run_forward_hook(self, inputs, output):
"""
Running forward hook function registered on cell object.
@ -1639,7 +1658,7 @@ class Cell(Cell_):
``Ascend`` ``GPU`` ``CPU``
"""
cell_id = self.cls_name + "(" + str(id(self)) + ")"
for fn in self.forward_hook.values():
for fn in self._forward_hook.values():
ret = fn(cell_id, inputs, output)
if ret is not None:
output = ret
@ -1719,11 +1738,11 @@ class Cell(Cell_):
if not hasattr(self, '_forward_hook_key'):
self._forward_hook_key = -1
self._forward_hook_key += 1
self.forward_hook[self._forward_hook_key] = hook_fn
handle = HookHandle(self, self._forward_hook_key, "forward_hook")
self._forward_hook[self._forward_hook_key] = hook_fn
handle = HookHandle(self, self._forward_hook_key, "_forward_hook")
return handle
def run_backward_hook(self, *inputs):
def _backward_hook_construct(self, *inputs):
"""
Backward hook construct method to replace original construct method.
@ -1736,15 +1755,15 @@ class Cell(Cell_):
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
inputs = self.cell_backward_hook(*inputs)
inputs = self._cell_backward_hook(*inputs)
if isinstance(inputs, tuple):
inputs = self.construct(*inputs)
else:
inputs = self.construct(inputs)
if isinstance(inputs, tuple):
outputs = self.cell_backward_hook(*inputs)
outputs = self._cell_backward_hook(*inputs)
else:
outputs = self.cell_backward_hook(inputs)
outputs = self._cell_backward_hook(inputs)
return outputs
def register_backward_hook(self, hook_fn):
@ -1811,14 +1830,14 @@ class Cell(Cell_):
if not isinstance(hook_fn, (FunctionType, MethodType)):
raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' should be python "
f"function, but got {type(hook_fn)}.")
if self.cell_backward_hook is None:
self.enable_backward_hook = True
self.cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
backward_hook_key = self.cell_backward_hook.register_backward_hook(hook_fn)
handle = HookHandle(self, backward_hook_key, "cell_backward_hook")
if self._cell_backward_hook is None:
self._enable_backward_hook = True
self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
else:
backward_hook_key = self.cell_backward_hook.register_backward_hook(hook_fn)
handle = HookHandle(self, backward_hook_key, "cell_backward_hook")
backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
return handle
def set_param_ps(self, recurse=True, init_in_server=False):