recompute interface modify

This commit is contained in:
yao_yf 2021-09-02 17:03:28 +08:00
parent 94f010aee6
commit a312a8a9eb
1 changed files with 6 additions and 6 deletions

View File

@ -1327,7 +1327,7 @@ class Cell(Cell_):
for cell in self.cells():
cell._recompute(mode, True)
@args_type_check(mp_comm_recompute=bool, optimizer_shard_comm_recompute=bool)
@args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
def recompute(self, **kwargs):
"""
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
@ -1347,18 +1347,18 @@ class Cell(Cell_):
Args:
mp_comm_recompute (bool): Specifies whether the model parallel communication operators
in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
optimizer_shard_comm_recompute (bool): Specifies whether the communication operator allgathers
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False.
in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False.
"""
self._recompute()
if 'mp_comm_recompute' in kwargs.keys():
self._mp_comm_recompute(kwargs['mp_comm_recompute'])
if 'optimizer_shard_comm_recompute' in kwargs.keys():
if 'parallel_optimizer_comm_recompute' in kwargs.keys():
raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard "
"are not support recomputation")
for key, _ in kwargs.items():
if key not in ('mp_comm_recompute', 'optimizer_shard_comm_recompute'):
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute'):
raise ValueError("Recompute keyword %s is not recognized!" % key)