forked from mindspore-Ecosystem/mindspore
recompute interface modify
This commit is contained in:
parent
94f010aee6
commit
a312a8a9eb
|
@ -1327,7 +1327,7 @@ class Cell(Cell_):
|
|||
for cell in self.cells():
|
||||
cell._recompute(mode, True)
|
||||
|
||||
@args_type_check(mp_comm_recompute=bool, optimizer_shard_comm_recompute=bool)
|
||||
@args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
|
||||
def recompute(self, **kwargs):
|
||||
"""
|
||||
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
|
||||
|
@ -1348,17 +1348,17 @@ class Cell(Cell_):
|
|||
Args:
|
||||
mp_comm_recompute (bool): Specifies whether the model parallel communication operators
|
||||
in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
|
||||
optimizer_shard_comm_recompute (bool): Specifies whether the communication operator allgathers
|
||||
parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
|
||||
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False.
|
||||
"""
|
||||
self._recompute()
|
||||
if 'mp_comm_recompute' in kwargs.keys():
|
||||
self._mp_comm_recompute(kwargs['mp_comm_recompute'])
|
||||
if 'optimizer_shard_comm_recompute' in kwargs.keys():
|
||||
if 'parallel_optimizer_comm_recompute' in kwargs.keys():
|
||||
raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard "
|
||||
"are not support recomputation")
|
||||
for key, _ in kwargs.items():
|
||||
if key not in ('mp_comm_recompute', 'optimizer_shard_comm_recompute'):
|
||||
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute'):
|
||||
raise ValueError("Recompute keyword %s is not recognized!" % key)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue