forked from mindspore-Ecosystem/mindspore
recompute interface modify
This commit is contained in:
parent
94f010aee6
commit
a312a8a9eb
|
@ -1327,7 +1327,7 @@ class Cell(Cell_):
|
||||||
for cell in self.cells():
|
for cell in self.cells():
|
||||||
cell._recompute(mode, True)
|
cell._recompute(mode, True)
|
||||||
|
|
||||||
@args_type_check(mp_comm_recompute=bool, optimizer_shard_comm_recompute=bool)
|
@args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
|
||||||
def recompute(self, **kwargs):
|
def recompute(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
|
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
|
||||||
|
@ -1347,18 +1347,18 @@ class Cell(Cell_):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mp_comm_recompute (bool): Specifies whether the model parallel communication operators
|
mp_comm_recompute (bool): Specifies whether the model parallel communication operators
|
||||||
in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
|
in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
|
||||||
optimizer_shard_comm_recompute (bool): Specifies whether the communication operator allgathers
|
parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
|
||||||
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False.
|
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode. Default: False.
|
||||||
"""
|
"""
|
||||||
self._recompute()
|
self._recompute()
|
||||||
if 'mp_comm_recompute' in kwargs.keys():
|
if 'mp_comm_recompute' in kwargs.keys():
|
||||||
self._mp_comm_recompute(kwargs['mp_comm_recompute'])
|
self._mp_comm_recompute(kwargs['mp_comm_recompute'])
|
||||||
if 'optimizer_shard_comm_recompute' in kwargs.keys():
|
if 'parallel_optimizer_comm_recompute' in kwargs.keys():
|
||||||
raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard "
|
raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard "
|
||||||
"are not support recomputation")
|
"are not support recomputation")
|
||||||
for key, _ in kwargs.items():
|
for key, _ in kwargs.items():
|
||||||
if key not in ('mp_comm_recompute', 'optimizer_shard_comm_recompute'):
|
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute'):
|
||||||
raise ValueError("Recompute keyword %s is not recognized!" % key)
|
raise ValueError("Recompute keyword %s is not recognized!" % key)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue