diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index e0db34626cd..7b1995467d6 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -12,6 +12,7 @@ "mindspore/mindspore/python/mindspore/_check_version.py" "unused-import" "mindspore/mindspore/python/mindspore/_check_version.py" "broad-except" "mindspore/mindspore/python/mindspore/common/parameter.py" "protected-access" +"mindspore/mindspore/python/mindspore/common/hook_handle.py" "protected-access" "mindspore/mindspore/python/mindspore/common/dtype.py" "undefined-all-variable" "mindspore/mindspore/python/mindspore/context.py" "protected-access" "mindspore/mindspore/python/mindspore/ops/operations" "super-init-not-called" @@ -28,6 +29,7 @@ "mindspore/mindspore/python/mindspore/nn/cell.py" "assignment-from-none" "mindspore/mindspore/python/mindspore/_extends/parse/resources.py" "bad-whitespace" "mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "broad-except" +"mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "protected-access" "mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "eval-used" "mindspore/mindspore/python/mindspore/nn/cell.py" "protected-access" "mindspore/mindspore/python/mindspore/nn/optim/ftrl.py" "unused-import" diff --git a/mindspore/python/mindspore/_extends/parse/parser.py b/mindspore/python/mindspore/_extends/parse/parser.py index e70bd4adff5..63bf8804db0 100644 --- a/mindspore/python/mindspore/_extends/parse/parser.py +++ b/mindspore/python/mindspore/_extends/parse/parser.py @@ -124,8 +124,8 @@ def get_parse_method_of_class(obj, parse_method=None): if parse_method is not None: method_name = parse_method elif isinstance(obj, nn.Cell): - if obj.enable_backward_hook: - method_name = "run_backward_hook" + if obj._enable_backward_hook: + method_name = "_backward_hook_construct" else: method_name = "construct" if method_name is not None: diff --git a/mindspore/python/mindspore/common/hook_handle.py b/mindspore/python/mindspore/common/hook_handle.py index a43533611b2..9c50e52d9d5 100644 --- a/mindspore/python/mindspore/common/hook_handle.py +++ b/mindspore/python/mindspore/common/hook_handle.py @@ -19,8 +19,8 @@ from .api import _pynative_executor class HookHandle: r""" - It is the return object of Cell forward pre hook function, forward hook function and backward hook function. - It corresponds to the cell hook function and is used to remove the cell hook function by calling 'remove()'. + It is the return object of forward pre hook function, forward hook function and backward hook function of Cell + object. It corresponds to the cell hook function and is used to remove the cell hook function by calling 'remove()'. Note: It is only supported in pynative mode and works when registering or removing hook function for Cell object. @@ -29,12 +29,9 @@ class HookHandle: hook_cell (Cell): The Cell object with hook function registered on. Default value: None. hook_key (int): The key of cell hook function in dict. It is generated during cell hook function registration. Default value: -1. - hook_type (str): The type of cell hook function: 'forward_pre_hook', 'forward_hook' or 'cell_backward_hook'. + hook_type (str): The type of cell hook function: '_forward_pre_hook', '_forward_hook' or '_cell_backward_hook'. Default value: "". - Returns: - None. - Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` """ @@ -49,8 +46,8 @@ class HookHandle: def remove(self): """ Remove the cell hook function, which corresponds to this 'HookHandle' object. - In order to prevent running failed when switching to graph mode, it is not recommended to write in the - construct. + In order to prevent running failed when switching to graph mode, it is not recommended to call the `remove()` + function in the construct function of Cell object. Args: None. @@ -95,14 +92,14 @@ class HookHandle: """ if self._hook_cell is not None: hook_cell = self._hook_cell() - if self._hook_type == "forward_pre_hook" and self._hook_key in hook_cell.forward_pre_hook: - del hook_cell.forward_pre_hook[self._hook_key] + if self._hook_type == "_forward_pre_hook" and self._hook_key in hook_cell._forward_pre_hook: + del hook_cell._forward_pre_hook[self._hook_key] _pynative_executor.set_hook_changed(hook_cell) - elif self._hook_type == "forward_hook" and self._hook_key in hook_cell.forward_hook: - del hook_cell.forward_hook[self._hook_key] + elif self._hook_type == "_forward_hook" and self._hook_key in hook_cell._forward_hook: + del hook_cell._forward_hook[self._hook_key] _pynative_executor.set_hook_changed(hook_cell) - elif self._hook_type == "cell_backward_hook": - hook_cell.cell_backward_hook.remove_backward_hook(self._hook_key) + elif self._hook_type == "_cell_backward_hook": + hook_cell._cell_backward_hook.remove_backward_hook(self._hook_key) def __del__(self): self._hook_cell = None diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index 7c29ec16910..e886856fb3e 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -84,8 +84,8 @@ class Cell(Cell_): IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names', '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run', '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode', - 'forward_pre_hook', 'forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook', - '_bprop_debug', 'enable_backward_hook', 'cell_backward_hook', '_is_run', '_param_prefix', + '_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook', + '_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix', '_attr_synced', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type'] def __init__(self, auto_prefix=True, flags=None): @@ -127,12 +127,12 @@ class Cell(Cell_): if flags: self.add_flags(**flags) self._bprop_debug = False - self.forward_pre_hook = OrderedDict() - self.forward_hook = OrderedDict() + self._forward_pre_hook = OrderedDict() + self._forward_hook = OrderedDict() self._enable_forward_pre_hook = False self._enable_forward_hook = False - self.enable_backward_hook = False - self.cell_backward_hook = None + self._enable_backward_hook = False + self._cell_backward_hook = None self.cell_type = None self._auto_parallel_compile_and_run = False self.cast = Cast() @@ -383,17 +383,36 @@ class Cell(Cell_): self.parameter_broadcast_done = True def run_construct(self, cast_inputs, kwargs): + """ + Run the construct function. + + Note: + This function will be removed in a future version. It is not recommended to call this function. + + Args: + cast_inputs (tuple): The input objects of Cell. + kwargs (dict): Provide keyword arguments. + + Returns: + output, the output object of Cell. + """ + logger.warning(f"The 'run_construct' function of '{self.cls_name}' will be removed in a future version. " + f"Calling this function is not recommended.") + output = self._run_construct(cast_inputs, kwargs) + return output + + def _run_construct(self, cast_inputs, kwargs): """Run the construct function""" if self._enable_forward_pre_hook: - cast_inputs = self.run_forward_pre_hook(cast_inputs) - if self.enable_backward_hook: - output = self.run_backward_hook(*cast_inputs) + cast_inputs = self._run_forward_pre_hook(cast_inputs) + if self._enable_backward_hook: + output = self._backward_hook_construct(*cast_inputs) elif hasattr(self, "_shard_fn"): output = self._shard_fn(*cast_inputs, **kwargs) else: output = self.construct(*cast_inputs, **kwargs) if self._enable_forward_hook: - output = self.run_forward_hook(cast_inputs, output) + output = self._run_forward_hook(cast_inputs, output) return output def _check_construct_args(self, *inputs, **kwargs): @@ -536,7 +555,7 @@ class Cell(Cell_): # Run in Graph mode. if context._get_mode() == context.GRAPH_MODE: self._check_construct_args(*args, **kwargs) - if self._enable_forward_pre_hook or self._enable_forward_hook or self.enable_backward_hook: + if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook: logger.warning(f"For 'Cell', it's not support hook function in graph mode, please use " f"context.set_context to set pynative mode.") out = self.compile_and_run(*args) @@ -562,7 +581,7 @@ class Cell(Cell_): with self.CellGuard(): try: - output = self.run_construct(cast_inputs, kwargs) + output = self._run_construct(cast_inputs, kwargs) except Exception as err: _pynative_executor.clear_res() raise err @@ -1525,7 +1544,7 @@ class Cell(Cell_): self.add_flags(auto_parallel=True) self._get_construct_inputs_number_and_name() - def run_forward_pre_hook(self, inputs): + def _run_forward_pre_hook(self, inputs): """ Running forward pre hook function registered on cell object. @@ -1539,7 +1558,7 @@ class Cell(Cell_): ``Ascend`` ``GPU`` ``CPU`` """ cell_id = self.cls_name + "(" + str(id(self)) + ")" - for fn in self.forward_pre_hook.values(): + for fn in self._forward_pre_hook.values(): ret = fn(cell_id, inputs) if ret is not None: if not isinstance(ret, tuple): @@ -1620,11 +1639,11 @@ class Cell(Cell_): if not hasattr(self, '_forward_pre_hook_key'): self._forward_pre_hook_key = -1 self._forward_pre_hook_key += 1 - self.forward_pre_hook[self._forward_pre_hook_key] = hook_fn - handle = HookHandle(self, self._forward_pre_hook_key, "forward_pre_hook") + self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn + handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook") return handle - def run_forward_hook(self, inputs, output): + def _run_forward_hook(self, inputs, output): """ Running forward hook function registered on cell object. @@ -1639,7 +1658,7 @@ class Cell(Cell_): ``Ascend`` ``GPU`` ``CPU`` """ cell_id = self.cls_name + "(" + str(id(self)) + ")" - for fn in self.forward_hook.values(): + for fn in self._forward_hook.values(): ret = fn(cell_id, inputs, output) if ret is not None: output = ret @@ -1719,11 +1738,11 @@ class Cell(Cell_): if not hasattr(self, '_forward_hook_key'): self._forward_hook_key = -1 self._forward_hook_key += 1 - self.forward_hook[self._forward_hook_key] = hook_fn - handle = HookHandle(self, self._forward_hook_key, "forward_hook") + self._forward_hook[self._forward_hook_key] = hook_fn + handle = HookHandle(self, self._forward_hook_key, "_forward_hook") return handle - def run_backward_hook(self, *inputs): + def _backward_hook_construct(self, *inputs): """ Backward hook construct method to replace original construct method. @@ -1736,15 +1755,15 @@ class Cell(Cell_): Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` """ - inputs = self.cell_backward_hook(*inputs) + inputs = self._cell_backward_hook(*inputs) if isinstance(inputs, tuple): inputs = self.construct(*inputs) else: inputs = self.construct(inputs) if isinstance(inputs, tuple): - outputs = self.cell_backward_hook(*inputs) + outputs = self._cell_backward_hook(*inputs) else: - outputs = self.cell_backward_hook(inputs) + outputs = self._cell_backward_hook(inputs) return outputs def register_backward_hook(self, hook_fn): @@ -1811,14 +1830,14 @@ class Cell(Cell_): if not isinstance(hook_fn, (FunctionType, MethodType)): raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' should be python " f"function, but got {type(hook_fn)}.") - if self.cell_backward_hook is None: - self.enable_backward_hook = True - self.cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")") - backward_hook_key = self.cell_backward_hook.register_backward_hook(hook_fn) - handle = HookHandle(self, backward_hook_key, "cell_backward_hook") + if self._cell_backward_hook is None: + self._enable_backward_hook = True + self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")") + backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn) + handle = HookHandle(self, backward_hook_key, "_cell_backward_hook") else: - backward_hook_key = self.cell_backward_hook.register_backward_hook(hook_fn) - handle = HookHandle(self, backward_hook_key, "cell_backward_hook") + backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn) + handle = HookHandle(self, backward_hook_key, "_cell_backward_hook") return handle def set_param_ps(self, recurse=True, init_in_server=False):