forked from mindspore-Ecosystem/mindspore
transfer_interface_to_inner_interface
This commit is contained in:
parent
d2c23394d8
commit
f4e9517220
|
@ -12,6 +12,7 @@
|
|||
"mindspore/mindspore/python/mindspore/_check_version.py" "unused-import"
|
||||
"mindspore/mindspore/python/mindspore/_check_version.py" "broad-except"
|
||||
"mindspore/mindspore/python/mindspore/common/parameter.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/common/hook_handle.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/common/dtype.py" "undefined-all-variable"
|
||||
"mindspore/mindspore/python/mindspore/context.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/ops/operations" "super-init-not-called"
|
||||
|
@ -28,6 +29,7 @@
|
|||
"mindspore/mindspore/python/mindspore/nn/cell.py" "assignment-from-none"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/resources.py" "bad-whitespace"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "broad-except"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "eval-used"
|
||||
"mindspore/mindspore/python/mindspore/nn/cell.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/nn/optim/ftrl.py" "unused-import"
|
||||
|
|
|
@ -124,8 +124,8 @@ def get_parse_method_of_class(obj, parse_method=None):
|
|||
if parse_method is not None:
|
||||
method_name = parse_method
|
||||
elif isinstance(obj, nn.Cell):
|
||||
if obj.enable_backward_hook:
|
||||
method_name = "run_backward_hook"
|
||||
if obj._enable_backward_hook:
|
||||
method_name = "_backward_hook_construct"
|
||||
else:
|
||||
method_name = "construct"
|
||||
if method_name is not None:
|
||||
|
|
|
@ -19,8 +19,8 @@ from .api import _pynative_executor
|
|||
|
||||
class HookHandle:
|
||||
r"""
|
||||
It is the return object of Cell forward pre hook function, forward hook function and backward hook function.
|
||||
It corresponds to the cell hook function and is used to remove the cell hook function by calling 'remove()'.
|
||||
It is the return object of forward pre hook function, forward hook function and backward hook function of Cell
|
||||
object. It corresponds to the cell hook function and is used to remove the cell hook function by calling 'remove()'.
|
||||
|
||||
Note:
|
||||
It is only supported in pynative mode and works when registering or removing hook function for Cell object.
|
||||
|
@ -29,12 +29,9 @@ class HookHandle:
|
|||
hook_cell (Cell): The Cell object with hook function registered on. Default value: None.
|
||||
hook_key (int): The key of cell hook function in dict. It is generated during cell hook function registration.
|
||||
Default value: -1.
|
||||
hook_type (str): The type of cell hook function: 'forward_pre_hook', 'forward_hook' or 'cell_backward_hook'.
|
||||
hook_type (str): The type of cell hook function: '_forward_pre_hook', '_forward_hook' or '_cell_backward_hook'.
|
||||
Default value: "".
|
||||
|
||||
Returns:
|
||||
None.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
@ -49,8 +46,8 @@ class HookHandle:
|
|||
def remove(self):
|
||||
"""
|
||||
Remove the cell hook function, which corresponds to this 'HookHandle' object.
|
||||
In order to prevent running failed when switching to graph mode, it is not recommended to write in the
|
||||
construct.
|
||||
In order to prevent running failed when switching to graph mode, it is not recommended to call the `remove()`
|
||||
function in the construct function of Cell object.
|
||||
|
||||
Args:
|
||||
None.
|
||||
|
@ -95,14 +92,14 @@ class HookHandle:
|
|||
"""
|
||||
if self._hook_cell is not None:
|
||||
hook_cell = self._hook_cell()
|
||||
if self._hook_type == "forward_pre_hook" and self._hook_key in hook_cell.forward_pre_hook:
|
||||
del hook_cell.forward_pre_hook[self._hook_key]
|
||||
if self._hook_type == "_forward_pre_hook" and self._hook_key in hook_cell._forward_pre_hook:
|
||||
del hook_cell._forward_pre_hook[self._hook_key]
|
||||
_pynative_executor.set_hook_changed(hook_cell)
|
||||
elif self._hook_type == "forward_hook" and self._hook_key in hook_cell.forward_hook:
|
||||
del hook_cell.forward_hook[self._hook_key]
|
||||
elif self._hook_type == "_forward_hook" and self._hook_key in hook_cell._forward_hook:
|
||||
del hook_cell._forward_hook[self._hook_key]
|
||||
_pynative_executor.set_hook_changed(hook_cell)
|
||||
elif self._hook_type == "cell_backward_hook":
|
||||
hook_cell.cell_backward_hook.remove_backward_hook(self._hook_key)
|
||||
elif self._hook_type == "_cell_backward_hook":
|
||||
hook_cell._cell_backward_hook.remove_backward_hook(self._hook_key)
|
||||
|
||||
def __del__(self):
|
||||
self._hook_cell = None
|
||||
|
|
|
@ -84,8 +84,8 @@ class Cell(Cell_):
|
|||
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
|
||||
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
|
||||
'_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode',
|
||||
'forward_pre_hook', 'forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
|
||||
'_bprop_debug', 'enable_backward_hook', 'cell_backward_hook', '_is_run', '_param_prefix',
|
||||
'_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
|
||||
'_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
|
||||
'_attr_synced', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type']
|
||||
|
||||
def __init__(self, auto_prefix=True, flags=None):
|
||||
|
@ -127,12 +127,12 @@ class Cell(Cell_):
|
|||
if flags:
|
||||
self.add_flags(**flags)
|
||||
self._bprop_debug = False
|
||||
self.forward_pre_hook = OrderedDict()
|
||||
self.forward_hook = OrderedDict()
|
||||
self._forward_pre_hook = OrderedDict()
|
||||
self._forward_hook = OrderedDict()
|
||||
self._enable_forward_pre_hook = False
|
||||
self._enable_forward_hook = False
|
||||
self.enable_backward_hook = False
|
||||
self.cell_backward_hook = None
|
||||
self._enable_backward_hook = False
|
||||
self._cell_backward_hook = None
|
||||
self.cell_type = None
|
||||
self._auto_parallel_compile_and_run = False
|
||||
self.cast = Cast()
|
||||
|
@ -383,17 +383,36 @@ class Cell(Cell_):
|
|||
self.parameter_broadcast_done = True
|
||||
|
||||
def run_construct(self, cast_inputs, kwargs):
|
||||
"""
|
||||
Run the construct function.
|
||||
|
||||
Note:
|
||||
This function will be removed in a future version. It is not recommended to call this function.
|
||||
|
||||
Args:
|
||||
cast_inputs (tuple): The input objects of Cell.
|
||||
kwargs (dict): Provide keyword arguments.
|
||||
|
||||
Returns:
|
||||
output, the output object of Cell.
|
||||
"""
|
||||
logger.warning(f"The 'run_construct' function of '{self.cls_name}' will be removed in a future version. "
|
||||
f"Calling this function is not recommended.")
|
||||
output = self._run_construct(cast_inputs, kwargs)
|
||||
return output
|
||||
|
||||
def _run_construct(self, cast_inputs, kwargs):
|
||||
"""Run the construct function"""
|
||||
if self._enable_forward_pre_hook:
|
||||
cast_inputs = self.run_forward_pre_hook(cast_inputs)
|
||||
if self.enable_backward_hook:
|
||||
output = self.run_backward_hook(*cast_inputs)
|
||||
cast_inputs = self._run_forward_pre_hook(cast_inputs)
|
||||
if self._enable_backward_hook:
|
||||
output = self._backward_hook_construct(*cast_inputs)
|
||||
elif hasattr(self, "_shard_fn"):
|
||||
output = self._shard_fn(*cast_inputs, **kwargs)
|
||||
else:
|
||||
output = self.construct(*cast_inputs, **kwargs)
|
||||
if self._enable_forward_hook:
|
||||
output = self.run_forward_hook(cast_inputs, output)
|
||||
output = self._run_forward_hook(cast_inputs, output)
|
||||
return output
|
||||
|
||||
def _check_construct_args(self, *inputs, **kwargs):
|
||||
|
@ -536,7 +555,7 @@ class Cell(Cell_):
|
|||
# Run in Graph mode.
|
||||
if context._get_mode() == context.GRAPH_MODE:
|
||||
self._check_construct_args(*args, **kwargs)
|
||||
if self._enable_forward_pre_hook or self._enable_forward_hook or self.enable_backward_hook:
|
||||
if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
|
||||
logger.warning(f"For 'Cell', it's not support hook function in graph mode, please use "
|
||||
f"context.set_context to set pynative mode.")
|
||||
out = self.compile_and_run(*args)
|
||||
|
@ -562,7 +581,7 @@ class Cell(Cell_):
|
|||
|
||||
with self.CellGuard():
|
||||
try:
|
||||
output = self.run_construct(cast_inputs, kwargs)
|
||||
output = self._run_construct(cast_inputs, kwargs)
|
||||
except Exception as err:
|
||||
_pynative_executor.clear_res()
|
||||
raise err
|
||||
|
@ -1525,7 +1544,7 @@ class Cell(Cell_):
|
|||
self.add_flags(auto_parallel=True)
|
||||
self._get_construct_inputs_number_and_name()
|
||||
|
||||
def run_forward_pre_hook(self, inputs):
|
||||
def _run_forward_pre_hook(self, inputs):
|
||||
"""
|
||||
Running forward pre hook function registered on cell object.
|
||||
|
||||
|
@ -1539,7 +1558,7 @@ class Cell(Cell_):
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
cell_id = self.cls_name + "(" + str(id(self)) + ")"
|
||||
for fn in self.forward_pre_hook.values():
|
||||
for fn in self._forward_pre_hook.values():
|
||||
ret = fn(cell_id, inputs)
|
||||
if ret is not None:
|
||||
if not isinstance(ret, tuple):
|
||||
|
@ -1620,11 +1639,11 @@ class Cell(Cell_):
|
|||
if not hasattr(self, '_forward_pre_hook_key'):
|
||||
self._forward_pre_hook_key = -1
|
||||
self._forward_pre_hook_key += 1
|
||||
self.forward_pre_hook[self._forward_pre_hook_key] = hook_fn
|
||||
handle = HookHandle(self, self._forward_pre_hook_key, "forward_pre_hook")
|
||||
self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn
|
||||
handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook")
|
||||
return handle
|
||||
|
||||
def run_forward_hook(self, inputs, output):
|
||||
def _run_forward_hook(self, inputs, output):
|
||||
"""
|
||||
Running forward hook function registered on cell object.
|
||||
|
||||
|
@ -1639,7 +1658,7 @@ class Cell(Cell_):
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
cell_id = self.cls_name + "(" + str(id(self)) + ")"
|
||||
for fn in self.forward_hook.values():
|
||||
for fn in self._forward_hook.values():
|
||||
ret = fn(cell_id, inputs, output)
|
||||
if ret is not None:
|
||||
output = ret
|
||||
|
@ -1719,11 +1738,11 @@ class Cell(Cell_):
|
|||
if not hasattr(self, '_forward_hook_key'):
|
||||
self._forward_hook_key = -1
|
||||
self._forward_hook_key += 1
|
||||
self.forward_hook[self._forward_hook_key] = hook_fn
|
||||
handle = HookHandle(self, self._forward_hook_key, "forward_hook")
|
||||
self._forward_hook[self._forward_hook_key] = hook_fn
|
||||
handle = HookHandle(self, self._forward_hook_key, "_forward_hook")
|
||||
return handle
|
||||
|
||||
def run_backward_hook(self, *inputs):
|
||||
def _backward_hook_construct(self, *inputs):
|
||||
"""
|
||||
Backward hook construct method to replace original construct method.
|
||||
|
||||
|
@ -1736,15 +1755,15 @@ class Cell(Cell_):
|
|||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
inputs = self.cell_backward_hook(*inputs)
|
||||
inputs = self._cell_backward_hook(*inputs)
|
||||
if isinstance(inputs, tuple):
|
||||
inputs = self.construct(*inputs)
|
||||
else:
|
||||
inputs = self.construct(inputs)
|
||||
if isinstance(inputs, tuple):
|
||||
outputs = self.cell_backward_hook(*inputs)
|
||||
outputs = self._cell_backward_hook(*inputs)
|
||||
else:
|
||||
outputs = self.cell_backward_hook(inputs)
|
||||
outputs = self._cell_backward_hook(inputs)
|
||||
return outputs
|
||||
|
||||
def register_backward_hook(self, hook_fn):
|
||||
|
@ -1811,14 +1830,14 @@ class Cell(Cell_):
|
|||
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
||||
raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' should be python "
|
||||
f"function, but got {type(hook_fn)}.")
|
||||
if self.cell_backward_hook is None:
|
||||
self.enable_backward_hook = True
|
||||
self.cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
|
||||
backward_hook_key = self.cell_backward_hook.register_backward_hook(hook_fn)
|
||||
handle = HookHandle(self, backward_hook_key, "cell_backward_hook")
|
||||
if self._cell_backward_hook is None:
|
||||
self._enable_backward_hook = True
|
||||
self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
|
||||
backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
|
||||
handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
|
||||
else:
|
||||
backward_hook_key = self.cell_backward_hook.register_backward_hook(hook_fn)
|
||||
handle = HookHandle(self, backward_hook_key, "cell_backward_hook")
|
||||
backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
|
||||
handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
|
||||
return handle
|
||||
|
||||
def set_param_ps(self, recurse=True, init_in_server=False):
|
||||
|
|
Loading…
Reference in New Issue