modify while list

This commit is contained in:
lilei 2021-12-09 21:28:54 +08:00
parent b4176e73a7
commit ba9b30253a
2 changed files with 11 additions and 8 deletions

View File

@ -471,6 +471,9 @@ class MicroBatchInterleaved(Cell):
network (Cell): The target network to wrap.
interleave_num (int): split num of batch size. Default: 2.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> net = Net()
>>> net = MicroBatchInterleaved(net, 4)

View File

@ -203,10 +203,10 @@ class _AutoParallelContext:
def set_pipeline_stages(self, stages):
"""Set the stages of the pipeline"""
if isinstance(stages, bool) or not isinstance(stages, int):
raise TypeError("For 'set_auto_parallel_context().set_pipeline_stages', the argument 'pipeline_stages' "
raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
"must be int, but got the type : {}.".format(type(stages)))
if stages < 1:
raise ValueError("For 'set_auto_parallel_context().set_pipeline_stages', the argument 'pipeline_stages' "
raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
"should be greater or equal 1, but got the value of stages : {}.".format(stages))
self.check_context_handle()
self._context_handle.set_pipeline_stage_split_num(stages)
@ -266,7 +266,7 @@ class _AutoParallelContext:
loss_repeated_mean (bool): The loss_repeated_mean flag.
"""
if not isinstance(loss_repeated_mean, bool):
raise TypeError("For 'auto_parallel_context().set_loss_repeated_mean', the argument 'loss_repeated_mean' "
raise TypeError("For 'auto_parallel_context', the argument 'loss_repeated_mean' "
"must be bool, but got the type : {}.".format(type(loss_repeated_mean)))
self.check_context_handle()
self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
@ -431,22 +431,22 @@ class _AutoParallelContext:
self.check_context_handle()
if isinstance(dataset_strategy, str):
if dataset_strategy not in ("full_batch", "data_parallel"):
raise ValueError("For 'set_auto_parallel_context().set_dataset_strategy', the argument "
raise ValueError("For 'set_auto_parallel_context', the argument "
"'dataset_strategy' must be 'full_batch' or 'data_parallel', but got the value : {}."
.format(dataset_strategy))
self._context_handle.set_full_batch(dataset_strategy == "full_batch")
self._dataset_strategy_using_str = True
return
if not isinstance(dataset_strategy, tuple):
raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the argument 'dataset_strategy' "
raise TypeError("For 'set_auto_parallel_context', the argument 'dataset_strategy' "
"must be str or tuple type, but got the type : {}.".format(type(dataset_strategy)))
for ele in dataset_strategy:
if not isinstance(ele, tuple):
raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the element of argument "
raise TypeError("For 'set_auto_parallel_context', the element of argument "
"'dataset_strategy' must be tuple, but got the type : {} .".format(type(ele)))
for dim in ele:
if not isinstance(dim, int):
raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the element of argument "
raise TypeError("For 'set_auto_parallel_context', the element of argument "
"'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
self._dataset_strategy_using_str = False
self._context_handle.set_dataset_strategy(dataset_strategy)
@ -645,7 +645,7 @@ class _AutoParallelContext:
"""
self.check_context_handle()
if not isinstance(enable_parallel_optimizer, bool):
raise TypeError("For 'set_auto_parallel_context().set_enable_parallel_optimizer', "
raise TypeError("For 'set_auto_parallel_context', "
"the argument 'enable_parallel_optimizer' must be bool, but got the type : {}."
.format(type(enable_parallel_optimizer)))
self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)