diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index e0761df83e6..4fd756b3ddf 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -1327,7 +1327,7 @@ class Cell(Cell_): for cell in self.cells(): cell._recompute(mode, True) - @args_type_check(mp_comm_recompute=bool, optimizer_shard_comm_recompute=bool) + @args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool) def recompute(self, **kwargs): """ Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive @@ -1347,18 +1347,18 @@ class Cell(Cell_): Args: mp_comm_recompute (bool): Specifies whether the model parallel communication operators - in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True. - optimizer_shard_comm_recompute (bool): Specifies whether the communication operator allgathers - introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False. + in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True. + parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers + introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False. """ self._recompute() if 'mp_comm_recompute' in kwargs.keys(): self._mp_comm_recompute(kwargs['mp_comm_recompute']) - if 'optimizer_shard_comm_recompute' in kwargs.keys(): + if 'parallel_optimizer_comm_recompute' in kwargs.keys(): raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard " "are not support recomputation") for key, _ in kwargs.items(): - if key not in ('mp_comm_recompute', 'optimizer_shard_comm_recompute'): + if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute'): raise ValueError("Recompute keyword %s is not recognized!" % key)