!27496 Modify parallel API note for master

Merge pull request !27496 from lilei/modify_parallel_API_note
i-robot 2021-12-14 01:23:18 +00:00 committed by Gitee
commit ef2f5f6ca2
2 changed files with 11 additions and 8 deletions


@@ -471,6 +471,9 @@ class MicroBatchInterleaved(Cell):
         network (Cell): The target network to wrap.
         interleave_num (int): split num of batch size. Default: 2.
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = Net()
         >>> net = MicroBatchInterleaved(net, 4)
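For orientation, here is a minimal usage sketch of the wrapper this hunk documents, expanding the docstring example above; it assumes MindSpore's public API as of this change, and the `Net` definition is a hypothetical stand-in for the user's model:

import mindspore.nn as nn

class Net(nn.Cell):
    """Toy model standing in for the user's network (hypothetical)."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(16, 4)

    def construct(self, x):
        return self.dense(x)

# Split each input batch into 4 interleaved micro-batches,
# matching the docstring example above.
net = nn.MicroBatchInterleaved(Net(), interleave_num=4)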


@@ -208,10 +208,10 @@ class _AutoParallelContext:
     def set_pipeline_stages(self, stages):
         """Set the stages of the pipeline"""
         if isinstance(stages, bool) or not isinstance(stages, int):
-            raise TypeError("For 'set_auto_parallel_context().set_pipeline_stages', the argument 'pipeline_stages' "
+            raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
                             "must be int, but got the type : {}.".format(type(stages)))
         if stages < 1:
-            raise ValueError("For 'set_auto_parallel_context().set_pipeline_stages', the argument 'pipeline_stages' "
+            raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
                             "should be greater or equal 1, but got the value of stages : {}.".format(stages))
         self.check_context_handle()
         self._context_handle.set_pipeline_stage_split_num(stages)
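This private setter is reached through the public mindspore.context.set_auto_parallel_context entry point, which is why the reworded messages now name what the user actually called. An illustrative sketch of how the two checks surface (assuming the public context API):

import mindspore.context as context

# bool is rejected before the int check, so True raises TypeError.
try:
    context.set_auto_parallel_context(pipeline_stages=True)
except TypeError as err:
    print(err)  # ... 'pipeline_stages' must be int ...

# Values below 1 raise ValueError.
try:
    context.set_auto_parallel_context(pipeline_stages=0)
except ValueError as err:
    print(err)  # ... should be greater or equal 1 ...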
@@ -271,7 +271,7 @@ class _AutoParallelContext:
             loss_repeated_mean (bool): The loss_repeated_mean flag.
         """
         if not isinstance(loss_repeated_mean, bool):
-            raise TypeError("For 'auto_parallel_context().set_loss_repeated_mean', the argument 'loss_repeated_mean' "
+            raise TypeError("For 'auto_parallel_context', the argument 'loss_repeated_mean' "
                             "must be bool, but got the type : {}.".format(type(loss_repeated_mean)))
         self.check_context_handle()
         self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
@@ -436,22 +436,22 @@ class _AutoParallelContext:
         self.check_context_handle()
         if isinstance(dataset_strategy, str):
             if dataset_strategy not in ("full_batch", "data_parallel"):
-                raise ValueError("For 'set_auto_parallel_context().set_dataset_strategy', the argument "
+                raise ValueError("For 'set_auto_parallel_context', the argument "
                                  "'dataset_strategy' must be 'full_batch' or 'data_parallel', but got the value : {}."
                                  .format(dataset_strategy))
             self._context_handle.set_full_batch(dataset_strategy == "full_batch")
             self._dataset_strategy_using_str = True
             return
         if not isinstance(dataset_strategy, tuple):
-            raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the argument 'dataset_strategy' "
+            raise TypeError("For 'set_auto_parallel_context', the argument 'dataset_strategy' "
                             "must be str or tuple type, but got the type : {}.".format(type(dataset_strategy)))
         for ele in dataset_strategy:
             if not isinstance(ele, tuple):
-                raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the element of argument "
+                raise TypeError("For 'set_auto_parallel_context', the element of argument "
                                 "'dataset_strategy' must be tuple, but got the type : {} .".format(type(ele)))
             for dim in ele:
                 if not isinstance(dim, int):
-                    raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the element of argument "
+                    raise TypeError("For 'set_auto_parallel_context', the element of argument "
                                     "'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
         self._dataset_strategy_using_str = False
         self._context_handle.set_dataset_strategy(dataset_strategy)
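Taken together, these checks admit exactly two shapes for dataset_strategy: one of the strings "full_batch" or "data_parallel", or a tuple of per-input layout tuples whose elements are all ints. A hedged sketch of both accepted forms (the device numbers are illustrative, not from this change):

import mindspore.context as context

# String form: shorthand for the two built-in dataset layouts.
context.set_auto_parallel_context(dataset_strategy="data_parallel")

# Tuple form: one layout tuple per dataset input, ints only.
# (8, 1) here hypothetically splits axis 0 across 8 devices.
context.set_auto_parallel_context(dataset_strategy=((8, 1),))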
@@ -650,7 +650,7 @@ class _AutoParallelContext:
         """
         self.check_context_handle()
         if not isinstance(enable_parallel_optimizer, bool):
-            raise TypeError("For 'set_auto_parallel_context().set_enable_parallel_optimizer', "
+            raise TypeError("For 'set_auto_parallel_context', "
                             "the argument 'enable_parallel_optimizer' must be bool, but got the type : {}."
                             .format(type(enable_parallel_optimizer)))
         self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
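The optimizer-sharding flag gets the same rewording; only a genuine bool passes the check, as this short sketch against the public context API shows:

import mindspore.context as context

# Enable optimizer weight sharding across data-parallel devices.
context.set_auto_parallel_context(enable_parallel_optimizer=True)

# An int such as 1 is not bool, so it raises the reworded TypeError.
try:
    context.set_auto_parallel_context(enable_parallel_optimizer=1)
except TypeError as err:
    print(err)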