forked from mindspore-Ecosystem/mindspore
!27496 modify parallel api note for master
Merge pull request !27496 from lilei/modify_parallel_API_note
This commit is contained in:
commit
ef2f5f6ca2
|
@ -471,6 +471,9 @@ class MicroBatchInterleaved(Cell):
|
|||
network (Cell): The target network to wrap.
|
||||
interleave_num (int): split num of batch size. Default: 2.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> net = MicroBatchInterleaved(net, 4)
|
||||
|
|
|
@ -208,10 +208,10 @@ class _AutoParallelContext:
|
|||
def set_pipeline_stages(self, stages):
|
||||
"""Set the stages of the pipeline"""
|
||||
if isinstance(stages, bool) or not isinstance(stages, int):
|
||||
raise TypeError("For 'set_auto_parallel_context().set_pipeline_stages', the argument 'pipeline_stages' "
|
||||
raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
|
||||
"must be int, but got the type : {}.".format(type(stages)))
|
||||
if stages < 1:
|
||||
raise ValueError("For 'set_auto_parallel_context().set_pipeline_stages', the argument 'pipeline_stages' "
|
||||
raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
|
||||
"should be greater or equal 1, but got the value of stages : {}.".format(stages))
|
||||
self.check_context_handle()
|
||||
self._context_handle.set_pipeline_stage_split_num(stages)
|
||||
|
@ -271,7 +271,7 @@ class _AutoParallelContext:
|
|||
loss_repeated_mean (bool): The loss_repeated_mean flag.
|
||||
"""
|
||||
if not isinstance(loss_repeated_mean, bool):
|
||||
raise TypeError("For 'auto_parallel_context().set_loss_repeated_mean', the argument 'loss_repeated_mean' "
|
||||
raise TypeError("For 'auto_parallel_context', the argument 'loss_repeated_mean' "
|
||||
"must be bool, but got the type : {}.".format(type(loss_repeated_mean)))
|
||||
self.check_context_handle()
|
||||
self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
|
||||
|
@ -436,22 +436,22 @@ class _AutoParallelContext:
|
|||
self.check_context_handle()
|
||||
if isinstance(dataset_strategy, str):
|
||||
if dataset_strategy not in ("full_batch", "data_parallel"):
|
||||
raise ValueError("For 'set_auto_parallel_context().set_dataset_strategy', the argument "
|
||||
raise ValueError("For 'set_auto_parallel_context', the argument "
|
||||
"'dataset_strategy' must be 'full_batch' or 'data_parallel', but got the value : {}."
|
||||
.format(dataset_strategy))
|
||||
self._context_handle.set_full_batch(dataset_strategy == "full_batch")
|
||||
self._dataset_strategy_using_str = True
|
||||
return
|
||||
if not isinstance(dataset_strategy, tuple):
|
||||
raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the argument 'dataset_strategy' "
|
||||
raise TypeError("For 'set_auto_parallel_context', the argument 'dataset_strategy' "
|
||||
"must be str or tuple type, but got the type : {}.".format(type(dataset_strategy)))
|
||||
for ele in dataset_strategy:
|
||||
if not isinstance(ele, tuple):
|
||||
raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the element of argument "
|
||||
raise TypeError("For 'set_auto_parallel_context', the element of argument "
|
||||
"'dataset_strategy' must be tuple, but got the type : {} .".format(type(ele)))
|
||||
for dim in ele:
|
||||
if not isinstance(dim, int):
|
||||
raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the element of argument "
|
||||
raise TypeError("For 'set_auto_parallel_context', the element of argument "
|
||||
"'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
|
||||
self._dataset_strategy_using_str = False
|
||||
self._context_handle.set_dataset_strategy(dataset_strategy)
|
||||
|
@ -650,7 +650,7 @@ class _AutoParallelContext:
|
|||
"""
|
||||
self.check_context_handle()
|
||||
if not isinstance(enable_parallel_optimizer, bool):
|
||||
raise TypeError("For 'set_auto_parallel_context().set_enable_parallel_optimizer', "
|
||||
raise TypeError("For 'set_auto_parallel_context', "
|
||||
"the argument 'enable_parallel_optimizer' must be bool, but got the type : {}."
|
||||
.format(type(enable_parallel_optimizer)))
|
||||
self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
|
||||
|
|
Loading…
Reference in New Issue