forked from mindspore-Ecosystem/mindspore
!27496 modify parallel api note for master
Merge pull request !27496 from lilei/modify_parallel_API_note
This commit is contained in:
commit
ef2f5f6ca2
|
@ -471,6 +471,9 @@ class MicroBatchInterleaved(Cell):
|
||||||
network (Cell): The target network to wrap.
|
network (Cell): The target network to wrap.
|
||||||
interleave_num (int): split num of batch size. Default: 2.
|
interleave_num (int): split num of batch size. Default: 2.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> net = Net()
|
>>> net = Net()
|
||||||
>>> net = MicroBatchInterleaved(net, 4)
|
>>> net = MicroBatchInterleaved(net, 4)
|
||||||
|
|
|
@ -208,10 +208,10 @@ class _AutoParallelContext:
|
||||||
def set_pipeline_stages(self, stages):
|
def set_pipeline_stages(self, stages):
|
||||||
"""Set the stages of the pipeline"""
|
"""Set the stages of the pipeline"""
|
||||||
if isinstance(stages, bool) or not isinstance(stages, int):
|
if isinstance(stages, bool) or not isinstance(stages, int):
|
||||||
raise TypeError("For 'set_auto_parallel_context().set_pipeline_stages', the argument 'pipeline_stages' "
|
raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
|
||||||
"must be int, but got the type : {}.".format(type(stages)))
|
"must be int, but got the type : {}.".format(type(stages)))
|
||||||
if stages < 1:
|
if stages < 1:
|
||||||
raise ValueError("For 'set_auto_parallel_context().set_pipeline_stages', the argument 'pipeline_stages' "
|
raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
|
||||||
"should be greater or equal 1, but got the value of stages : {}.".format(stages))
|
"should be greater or equal 1, but got the value of stages : {}.".format(stages))
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
self._context_handle.set_pipeline_stage_split_num(stages)
|
self._context_handle.set_pipeline_stage_split_num(stages)
|
||||||
|
@ -271,7 +271,7 @@ class _AutoParallelContext:
|
||||||
loss_repeated_mean (bool): The loss_repeated_mean flag.
|
loss_repeated_mean (bool): The loss_repeated_mean flag.
|
||||||
"""
|
"""
|
||||||
if not isinstance(loss_repeated_mean, bool):
|
if not isinstance(loss_repeated_mean, bool):
|
||||||
raise TypeError("For 'auto_parallel_context().set_loss_repeated_mean', the argument 'loss_repeated_mean' "
|
raise TypeError("For 'auto_parallel_context', the argument 'loss_repeated_mean' "
|
||||||
"must be bool, but got the type : {}.".format(type(loss_repeated_mean)))
|
"must be bool, but got the type : {}.".format(type(loss_repeated_mean)))
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
|
self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
|
||||||
|
@ -436,22 +436,22 @@ class _AutoParallelContext:
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
if isinstance(dataset_strategy, str):
|
if isinstance(dataset_strategy, str):
|
||||||
if dataset_strategy not in ("full_batch", "data_parallel"):
|
if dataset_strategy not in ("full_batch", "data_parallel"):
|
||||||
raise ValueError("For 'set_auto_parallel_context().set_dataset_strategy', the argument "
|
raise ValueError("For 'set_auto_parallel_context', the argument "
|
||||||
"'dataset_strategy' must be 'full_batch' or 'data_parallel', but got the value : {}."
|
"'dataset_strategy' must be 'full_batch' or 'data_parallel', but got the value : {}."
|
||||||
.format(dataset_strategy))
|
.format(dataset_strategy))
|
||||||
self._context_handle.set_full_batch(dataset_strategy == "full_batch")
|
self._context_handle.set_full_batch(dataset_strategy == "full_batch")
|
||||||
self._dataset_strategy_using_str = True
|
self._dataset_strategy_using_str = True
|
||||||
return
|
return
|
||||||
if not isinstance(dataset_strategy, tuple):
|
if not isinstance(dataset_strategy, tuple):
|
||||||
raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the argument 'dataset_strategy' "
|
raise TypeError("For 'set_auto_parallel_context', the argument 'dataset_strategy' "
|
||||||
"must be str or tuple type, but got the type : {}.".format(type(dataset_strategy)))
|
"must be str or tuple type, but got the type : {}.".format(type(dataset_strategy)))
|
||||||
for ele in dataset_strategy:
|
for ele in dataset_strategy:
|
||||||
if not isinstance(ele, tuple):
|
if not isinstance(ele, tuple):
|
||||||
raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the element of argument "
|
raise TypeError("For 'set_auto_parallel_context', the element of argument "
|
||||||
"'dataset_strategy' must be tuple, but got the type : {} .".format(type(ele)))
|
"'dataset_strategy' must be tuple, but got the type : {} .".format(type(ele)))
|
||||||
for dim in ele:
|
for dim in ele:
|
||||||
if not isinstance(dim, int):
|
if not isinstance(dim, int):
|
||||||
raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the element of argument "
|
raise TypeError("For 'set_auto_parallel_context', the element of argument "
|
||||||
"'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
|
"'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
|
||||||
self._dataset_strategy_using_str = False
|
self._dataset_strategy_using_str = False
|
||||||
self._context_handle.set_dataset_strategy(dataset_strategy)
|
self._context_handle.set_dataset_strategy(dataset_strategy)
|
||||||
|
@ -650,7 +650,7 @@ class _AutoParallelContext:
|
||||||
"""
|
"""
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
if not isinstance(enable_parallel_optimizer, bool):
|
if not isinstance(enable_parallel_optimizer, bool):
|
||||||
raise TypeError("For 'set_auto_parallel_context().set_enable_parallel_optimizer', "
|
raise TypeError("For 'set_auto_parallel_context', "
|
||||||
"the argument 'enable_parallel_optimizer' must be bool, but got the type : {}."
|
"the argument 'enable_parallel_optimizer' must be bool, but got the type : {}."
|
||||||
.format(type(enable_parallel_optimizer)))
|
.format(type(enable_parallel_optimizer)))
|
||||||
self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
|
self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
|
||||||
|
|
Loading…
Reference in New Issue