!46150 Add the verification of setting dataset_strategy under PyNative mode

Merge pull request !46150 from liuluobin/fix_dataset_strategy
i-robot 2022-11-30 01:37:05 +00:00 committed by Gitee
commit 14ffcb619f
4 changed files with 13 additions and 6 deletions


@@ -48,7 +48,7 @@ mindspore.set_auto_parallel_context
- **strategy_ckpt_load_file** (str) - The path for loading the parallel strategy checkpoint. Default: ''.
- **strategy_ckpt_save_file** (str) - The path for saving the parallel strategy checkpoint. Default: ''.
- **full_batch** (bool) - If the whole batch dataset is loaded in auto_parallel mode, this parameter should be set to True. Default: False. This interface is no longer recommended; use dataset_strategy instead.
- **dataset_strategy** (Union[str, tuple]) - The dataset sharding strategy. Default: "data_parallel". dataset_strategy="data_parallel" is equivalent to full_batch=False, and dataset_strategy="full_batch" is equivalent to full_batch=True. For a dataset loaded into the network by a model parallel strategy such as ds_stra ((1, 8), (1, 8)), set_auto_parallel_context(dataset_strategy=ds_stra) must be used.
- **dataset_strategy** (Union[str, tuple]) - The dataset sharding strategy. Default: "data_parallel". dataset_strategy="data_parallel" is equivalent to full_batch=False, and dataset_strategy="full_batch" is equivalent to full_batch=True. For a dataset executed in static graph mode and loaded into the network by a model parallel strategy such as ds_stra ((1, 8), (1, 8)), set_auto_parallel_context(dataset_strategy=ds_stra) must be used (see the sketch after this list).
- **enable_parallel_optimizer** (bool) - This is a developing feature, which shards the weight update computation for data parallel training to save time and memory. Currently, auto and semi-auto parallel modes support all optimizers on both Ascend and GPU. Data parallel mode only supports `Lamb` and `AdamWeightDecay` on Ascend. Default: False.
- **enable_alltoall** (bool) - A switch that allows the `AllToAll` communication operator to be generated during communication. If set to False, a combination of communication operators such as `AllGather`, `Split` and `Concat` will be used instead of `AllToAll`. Default: False.
- **all_reduce_fusion_config** (list) - Set the AllReduce fusion strategy by parameter indices. Only ReduceOp.SUM and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP are supported. There is no default value. If not set, operator fusion is disabled.
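A minimal sketch of the string form and its equivalence with the legacy full_batch switch, assuming only the public mindspore.context API (device and cluster setup omitted):

from mindspore import context

# dataset_strategy="data_parallel" corresponds to the legacy full_batch=False,
# dataset_strategy="full_batch" corresponds to full_batch=True.
context.set_auto_parallel_context(dataset_strategy="data_parallel")
print(context.get_auto_parallel_context("dataset_strategy"))  # "data_parallel"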


@@ -549,8 +549,9 @@ def set_auto_parallel_context(**kwargs):
it is better using 'dataset_strategy' to replace it.
dataset_strategy (Union[str, tuple]): Dataset sharding strategy. Default: "data_parallel".
dataset_strategy="data_parallel" is equal to full_batch=False, dataset_strategy="full_batch" is
equal to full_batch=True. For dataset load into net by model parallel strategy likes
ds_stra ((1, 8), (1, 8)), it requires using set_auto_parallel_context(dataset_strategy=ds_stra).
equal to full_batch=True. When the execution mode is 'GRAPH_MODE' and the dataset is loaded into the net
by a model parallel strategy like ds_stra ((1, 8), (1, 8)), it requires using
set_auto_parallel_context(dataset_strategy=ds_stra), as sketched below.
enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
data parallel training in the benefit of time and memory saving. Currently, auto and semi auto
parallel mode support all optimizers in both Ascend and GPU. Data parallel mode only supports
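A hedged sketch of the tuple form under graph mode, as the docstring above describes; network and cluster setup are omitted, and ds_stra is the example strategy from the text:

from mindspore import context

# The tuple form is only supported when the execution mode is GRAPH_MODE.
context.set_context(mode=context.GRAPH_MODE)

# One sharding tuple per dataset input: here two inputs, each left whole on
# the first axis and split 8 ways on the second.
ds_stra = ((1, 8), (1, 8))
context.set_auto_parallel_context(dataset_strategy=ds_stra)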


@@ -521,6 +521,9 @@ class _AutoParallelContext:
if not isinstance(dim, int):
raise TypeError("For 'set_auto_parallel_context', the element of argument "
"'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
if context.get_context('mode') == context.PYNATIVE_MODE:
raise ValueError("In PyNative mode, the setting value of 'dataset_strategy' must be either 'full_batch' "
f"or 'data_parallel', but got {dataset_strategy}.")
self._dataset_strategy_using_str = False
self._context_handle.set_dataset_strategy(dataset_strategy)
@@ -531,7 +534,11 @@
if self._context_handle.get_full_batch():
return "full_batch"
return "data_parallel"
return self._context_handle.get_dataset_strategy()
dataset_strategy = self._context_handle.get_dataset_strategy()
if context.get_context('mode') == context.PYNATIVE_MODE:
raise ValueError("In PyNative mode, the value of 'dataset_strategy' must be either 'full_batch' "
f"or 'data_parallel', but the current setting is {dataset_strategy}.")
return dataset_strategy
def set_grad_accumulation_step(self, grad_accumulation_step):
"""

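Taken together, the setter and getter checks above reject a tuple strategy under PyNative mode. A small sketch of the behavior this change introduces, going through the public context wrapper rather than the private class:

from mindspore import context

context.set_context(mode=context.PYNATIVE_MODE)
try:
    # A model parallel tuple strategy is expected to be rejected in PyNative mode.
    context.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8)))
except ValueError as err:
    print(err)

# The string forms remain valid in PyNative mode.
context.set_auto_parallel_context(dataset_strategy="data_parallel")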

@@ -52,9 +52,8 @@ def _is_in_hybrid_parallel_mode():
def _is_pynative_parallel():
run_mode = context.get_context('mode')
parallel_mode = context.get_auto_parallel_context('parallel_mode')
return run_mode == context.PYNATIVE_MODE and parallel_mode in (
return context.get_context('mode') == context.PYNATIVE_MODE and parallel_mode in (
context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)
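For reference, a minimal sketch of the configuration under which the helper above evaluates to True (device and communication initialization omitted):

from mindspore import context

context.set_context(mode=context.PYNATIVE_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
# With this configuration _is_pynative_parallel() returns True; under graph
# mode, or with data_parallel / stand_alone, it returns False.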