forked from mindspore-Ecosystem/mindspore
!46150 Add the verification of setting dataset_strategy under PyNative mode
Merge pull request !46150 from liuluobin/fix_dataset_strategy
This commit is contained in: commit 14ffcb619f
@@ -48,7 +48,7 @@ mindspore.set_auto_parallel_context
 - **strategy_ckpt_load_file** (str) - The path for loading the parallel strategy checkpoint. Default: ''.
 - **strategy_ckpt_save_file** (str) - The path for saving the parallel strategy checkpoint. Default: ''.
 - **full_batch** (bool) - If the whole batch of the dataset is loaded in auto_parallel mode, this parameter should be set to True. Default: False. This interface is currently not recommended; it is better to use dataset_strategy instead.
-- **dataset_strategy** (Union[str, tuple]) - The dataset sharding strategy. Default: data_parallel. dataset_strategy="data_parallel" is equivalent to full_batch=False, and dataset_strategy="full_batch" is equivalent to full_batch=True. For a dataset loaded into the net by a model parallel strategy such as ds_stra ((1, 8), (1, 8)), set_auto_parallel_context(dataset_strategy=ds_stra) must be used.
+- **dataset_strategy** (Union[str, tuple]) - The dataset sharding strategy. Default: data_parallel. dataset_strategy="data_parallel" is equivalent to full_batch=False, and dataset_strategy="full_batch" is equivalent to full_batch=True. For a dataset that is executed in static graph mode and loaded into the net by a model parallel strategy such as ds_stra ((1, 8), (1, 8)), set_auto_parallel_context(dataset_strategy=ds_stra) must be used.
 - **enable_parallel_optimizer** (bool) - This is a developing feature, which shards the weight update computation for data parallel training to save time and memory. Currently, auto and semi-auto parallel modes support all optimizers on both Ascend and GPU. Data parallel mode only supports `Lamb` and `AdamWeightDecay` on Ascend. Default: False.
 - **enable_alltoall** (bool) - A switch that allows the `AllToAll` communication operator to be generated during communication. If its value is False, a combination of communication operators such as `AllGather`, `Split` and `Concat` will be used instead of `AllToAll`. Default: False.
 - **all_reduce_fusion_config** (list) - Set the AllReduce fusion strategy by parameter indices. Only ReduceOp.SUM and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP are supported. There is no default value. If it is not set, operator fusion is disabled.
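For orientation, here is a minimal sketch of how the two string values of dataset_strategy map onto the legacy full_batch flag described above. It is not part of the commit and assumes only the public mindspore.context API:

    from mindspore import context

    # "data_parallel": each rank loads only its own shard of every batch,
    # equivalent to the legacy full_batch=False.
    context.set_auto_parallel_context(dataset_strategy="data_parallel")

    # "full_batch": every rank loads the complete batch,
    # equivalent to the legacy full_batch=True.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(dataset_strategy="full_batch")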
@@ -549,8 +549,9 @@ def set_auto_parallel_context(**kwargs):
             it is better using 'dataset_strategy' to replace it.
         dataset_strategy (Union[str, tuple]): Dataset sharding strategy. Default: "data_parallel".
             dataset_strategy="data_parallel" is equal to full_batch=False, dataset_strategy="full_batch" is
-            equal to full_batch=True. For dataset load into net by model parallel strategy likes
-            ds_stra ((1, 8), (1, 8)), it requires using set_auto_parallel_context(dataset_strategy=ds_stra).
+            equal to full_batch=True. When the execution mode is 'GRAPH_MODE' and the dataset is loaded into
+            the net by a model parallel strategy such as ds_stra ((1, 8), (1, 8)), it requires using
+            set_auto_parallel_context(dataset_strategy=ds_stra).
         enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
             data parallel training in the benefit of time and memory saving. Currently, auto and semi auto
             parallel mode support all optimizers in both Ascend and GPU. Data parallel mode only supports
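The tuple form mentioned in the docstring applies only in graph mode. A hedged sketch, with the 8-device layout ((1, 8), (1, 8)) taken from the example above; the surrounding training script and device setup are assumed:

    from mindspore import context

    # Tuple strategies describe how each dataset output is sharded; after this
    # commit they are rejected in PyNative mode, so switch to graph mode first.
    context.set_context(mode=context.GRAPH_MODE)

    # One (1, 8) layout per dataset output: keep the first dimension whole and
    # split the second dimension across 8 devices.
    ds_stra = ((1, 8), (1, 8))
    context.set_auto_parallel_context(dataset_strategy=ds_stra)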
@@ -521,6 +521,9 @@ class _AutoParallelContext:
                 if not isinstance(dim, int):
                     raise TypeError("For 'set_auto_parallel_context', the element of argument "
                                     "'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
+        if context.get_context('mode') == context.PYNATIVE_MODE:
+            raise ValueError("In PyNative mode, the setting value of 'dataset_strategy' must be either 'full_batch' "
+                             f"or 'data_parallel', but got {dataset_strategy}.")
         self._dataset_strategy_using_str = False
         self._context_handle.set_dataset_strategy(dataset_strategy)
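The check added above makes tuple strategies fail fast under PyNative execution. A small illustration (not a test from this commit) of the expected behaviour:

    from mindspore import context

    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        # A tuple strategy is no longer accepted in PyNative mode ...
        context.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8)))
    except ValueError as err:
        print(err)  # In PyNative mode, the setting value of 'dataset_strategy' must be ...

    # ... while the two string values remain valid.
    context.set_auto_parallel_context(dataset_strategy="data_parallel")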
@@ -531,7 +534,11 @@ class _AutoParallelContext:
             if self._context_handle.get_full_batch():
                 return "full_batch"
             return "data_parallel"
-        return self._context_handle.get_dataset_strategy()
+        dataset_strategy = self._context_handle.get_dataset_strategy()
+        if context.get_context('mode') == context.PYNATIVE_MODE:
+            raise ValueError("In PyNative mode, the value of 'dataset_strategy' must be either 'full_batch' "
+                             f"or 'data_parallel', but got the setting value is {dataset_strategy}.")
+        return dataset_strategy

     def set_grad_accumulation_step(self, grad_accumulation_step):
         """
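The matching check in the getter fires when a tuple strategy is already stored (for example, set while the process was in graph mode) and the context is later queried under PyNative mode. A sketch of that scenario, assuming dataset_strategy can be read back through context.get_auto_parallel_context, which routes to the getter shown above:

    from mindspore import context

    # Store a tuple strategy while still in graph mode ...
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8)))

    # ... then switch to PyNative mode: reading the strategy back now raises,
    # because only 'full_batch' / 'data_parallel' are valid there.
    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        context.get_auto_parallel_context("dataset_strategy")
    except ValueError as err:
        print(err)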
@@ -52,9 +52,8 @@ def _is_in_hybrid_parallel_mode():


 def _is_pynative_parallel():
-    run_mode = context.get_context('mode')
     parallel_mode = context.get_auto_parallel_context('parallel_mode')
-    return run_mode == context.PYNATIVE_MODE and parallel_mode in (
+    return context.get_context('mode') == context.PYNATIVE_MODE and parallel_mode in (
         context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)
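For completeness, the condition tested by the refactored helper can be reproduced outside the source tree with the public context API; the standalone name below is hypothetical and only mirrors the logic above:

    from mindspore import context

    def is_pynative_parallel():
        # PyNative execution combined with (semi-)auto parallel mode,
        # mirroring the internal _is_pynative_parallel helper.
        parallel_mode = context.get_auto_parallel_context('parallel_mode')
        return (context.get_context('mode') == context.PYNATIVE_MODE
                and parallel_mode in (context.ParallelMode.SEMI_AUTO_PARALLEL,
                                      context.ParallelMode.AUTO_PARALLEL))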