diff --git a/mindspore/context.py b/mindspore/context.py
index d28fa919838..4279d1b76df 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -383,9 +383,9 @@ def set_auto_parallel_context(**kwargs):
         full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
             should be set with True. Default: False.
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
-            data parallel training in the benefit of time and memory saving. For now, auto parallel mode
-            supports all optimizers. Data parallel mode only supports `Lamb` and `AdamWeightDecay`.
-            Default: False.
+            data parallel training for the benefit of time and memory saving. Currently, auto and semi auto
+            parallel modes support all optimizers on both Ascend and GPU. Data parallel mode only supports
+            `Lamb` and `AdamWeightDecay` on Ascend. Default: False.
         all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
             and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed.
         pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index c0b3331a32b..02c2962faa7 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -164,8 +164,10 @@ class Optimizer(Cell):
         self.param_length = len(self.parameters)
         self.map_ = C.Map()
         if context.get_auto_parallel_context("enable_parallel_optimizer"):
-            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL:
+            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
                 self.use_parallel = True
+            elif _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") != "Ascend":
+                raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
             elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
                 raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))
             else:
@@ -174,10 +176,10 @@ class Optimizer(Cell):
             self.use_parallel = False
         if self.use_parallel:
             if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
-                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
+                raise RuntimeError("Parallel optimizer does not support optimizer {}".format(self.cls_name))
             self.dev_num = _get_device_num()
             if self.dev_num > self.param_length:
-                raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
+                raise RuntimeError("Parallel optimizer can not be applied when the number of parameters {} is"
                                    " less than the number of devices {}".format(self.param_length, self.dev_num))
             self.param_rank = self._get_parameter_group_id()
             self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
diff --git a/tests/ut/python/parallel/test_parallel_optimizer.py b/tests/ut/python/parallel/test_parallel_optimizer.py
index c8506437451..df649af4b34 100644
--- a/tests/ut/python/parallel/test_parallel_optimizer.py
+++ b/tests/ut/python/parallel/test_parallel_optimizer.py
@@ -164,6 +164,12 @@ def test_edge_case():
         context.set_auto_parallel_context(parallel_mode="stand_alone")
         Lamb(net.trainable_params(), learning_rate=0.1)
     with pytest.raises(RuntimeError):
+        context.set_context(device_target="GPU")
+        context.set_auto_parallel_context(parallel_mode="data_parallel")
+        Lamb(net.trainable_params(), learning_rate=0.1)
+    with pytest.raises(RuntimeError):
+        context.set_context(device_target="Ascend")
+        context.set_auto_parallel_context(parallel_mode="data_parallel")
         Adam(net.trainable_params(), learning_rate=0.1)
     with pytest.raises(RuntimeError):
         context.set_auto_parallel_context(device_num=16)
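Usage note (not part of the patch): a minimal sketch of how the new device check is expected to behave from user code, assuming a working MindSpore install on an Ascend host; the toy Dense layer and the learning rate are illustrative only.

    # Sketch: enable the parallel optimizer in data_parallel mode on Ascend.
    import mindspore.nn as nn
    from mindspore import context

    context.set_context(device_target="Ascend")  # data_parallel + parallel optimizer requires Ascend
    context.set_auto_parallel_context(parallel_mode="data_parallel",
                                      enable_parallel_optimizer=True)

    net = nn.Dense(16, 8)  # stand-in for a real model
    # In data_parallel mode only Lamb and AdamWeightDecay are accepted.
    optimizer = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)

    # With device_target="GPU" and parallel_mode="data_parallel", constructing the
    # optimizer is expected to raise RuntimeError after this change; auto and semi
    # auto parallel modes remain available on both backends.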