add restriction for opt shard

Ziyan 2020-12-19 16:58:09 +08:00
parent 512cd38406
commit c5c905fdf5
3 changed files with 14 additions and 6 deletions

View File

@@ -383,9 +383,9 @@ def set_auto_parallel_context(**kwargs):
        full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
            should be set with True. Default: False.
        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
-            data parallel training in the benefit of time and memory saving. For now, auto parallel mode
-            supports all optimizers. Data parallel mode only supports `Lamb` and `AdamWeightDecay`.
-            Default: False.
+            data parallel training to save time and memory. Currently, auto parallel and semi auto parallel
+            modes support all optimizers on both Ascend and GPU. Data parallel mode only supports
+            `Lamb` and `AdamWeightDecay` on Ascend. Default: False.
        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
            and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed.
        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
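For reference, a minimal sketch of the supported data parallel path this docstring describes; the device target, optimizer, and toy network below are illustrative assumptions, not part of this commit:

from mindspore import context, nn
from mindspore.communication.management import init
from mindspore.nn.optim import AdamWeightDecay

# Assumed setup: Ascend devices, data parallel training (requires a distributed launch).
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()
context.set_auto_parallel_context(parallel_mode="data_parallel",
                                  enable_parallel_optimizer=True)

net = nn.Dense(16, 16)  # placeholder network for illustration
# Only Lamb and AdamWeightDecay are accepted in data parallel mode on Ascend.
optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)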

View File

@@ -164,8 +164,10 @@ class Optimizer(Cell):
        self.param_length = len(self.parameters)
        self.map_ = C.Map()
        if context.get_auto_parallel_context("enable_parallel_optimizer"):
-            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL:
+            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
                self.use_parallel = True
+            elif context.get_context("device_target") != "Ascend":
+                raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
            elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
                raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))
            else:
@@ -174,10 +176,10 @@ class Optimizer(Cell):
            self.use_parallel = False
        if self.use_parallel:
            if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
-                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
+                raise RuntimeError("Parallel optimizer does not support optimizer {}".format(self.cls_name))
            self.dev_num = _get_device_num()
            if self.dev_num > self.param_length:
-                raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
+                raise RuntimeError("Parallel optimizer can not be applied when the number of parameters {} is"
                                   " less than the number of devices {}".format(self.param_length, self.dev_num))
            self.param_rank = self._get_parameter_group_id()
            self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
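As an aside, the filtering pattern behind `param_rank` and `optim_filter` can be illustrated with a hypothetical standalone sketch; `build_optim_filter` below is not MindSpore code, and the real assignment comes from `_get_parameter_group_id`, which is not shown in this diff:

def build_optim_filter(param_length, dev_num, global_rank):
    # Mirror the guard above: every device must own at least one parameter.
    if dev_num > param_length:
        raise RuntimeError("Parallel optimizer can not be applied when the number of parameters "
                           "{} is less than the number of devices {}".format(param_length, dev_num))
    # Assumption: a simple round-robin parameter-to-device assignment.
    param_rank = [i % dev_num for i in range(param_length)]
    # Each rank only runs the weight update for the parameters assigned to it.
    return tuple(r == global_rank for r in param_rank)

# Example: 10 parameters sharded across 4 devices, viewed from rank 1.
print(build_optim_filter(10, 4, 1))
# (False, True, False, False, False, True, False, False, False, True)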

View File

@@ -164,6 +164,12 @@ def test_edge_case():
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        Lamb(net.trainable_params(), learning_rate=0.1)
    with pytest.raises(RuntimeError):
        context.set_context(device_target="GPU")
        context.set_auto_parallel_context(parallel_mode="data_parallel")
        Lamb(net.trainable_params(), learning_rate=0.1)
    with pytest.raises(RuntimeError):
        context.set_context(device_target="Ascend")
        context.set_auto_parallel_context(parallel_mode="data_parallel")
        Adam(net.trainable_params(), learning_rate=0.1)
    with pytest.raises(RuntimeError):
        context.set_auto_parallel_context(device_num=16)
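For contrast with the failure cases above, a sketch of configurations the updated docstring describes as supported; it reuses the test's `net`, `Lamb`, and `Adam` and is not part of this commit:

# Data parallel on Ascend accepts Lamb and AdamWeightDecay.
context.set_context(device_target="Ascend")
context.set_auto_parallel_context(parallel_mode="data_parallel",
                                  enable_parallel_optimizer=True)
Lamb(net.trainable_params(), learning_rate=0.1)

# Auto and semi auto parallel modes place no optimizer restriction in this constructor check.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
Adam(net.trainable_params(), learning_rate=0.1)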