From ba9b30253a7f52e4bf0fae36cef6365604036b22 Mon Sep 17 00:00:00 2001
From: lilei
Date: Thu, 9 Dec 2021 21:28:54 +0800
Subject: [PATCH] modify white list

---
 mindspore/nn/wrap/cell_wrapper.py            |  3 +++
 mindspore/parallel/_auto_parallel_context.py | 16 ++++++++--------
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index 92a84d2096a..123baaea942 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -471,6 +471,9 @@ class MicroBatchInterleaved(Cell):
         network (Cell): The target network to wrap.
         interleave_num (int): split num of batch size. Default: 2.
 
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = Net()
         >>> net = MicroBatchInterleaved(net, 4)
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index 921c98c6da1..ad0be70ea44 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -203,10 +203,10 @@ class _AutoParallelContext:
     def set_pipeline_stages(self, stages):
         """Set the stages of the pipeline"""
         if isinstance(stages, bool) or not isinstance(stages, int):
-            raise TypeError("For 'set_auto_parallel_context().set_pipeline_stages', the argument 'pipeline_stages' "
+            raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
                             "must be int, but got the type : {}.".format(type(stages)))
         if stages < 1:
-            raise ValueError("For 'set_auto_parallel_context().set_pipeline_stages', the argument 'pipeline_stages' "
+            raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
                              "should be greater or equal 1, but got the value of stages : {}.".format(stages))
         self.check_context_handle()
         self._context_handle.set_pipeline_stage_split_num(stages)
@@ -266,7 +266,7 @@ class _AutoParallelContext:
             loss_repeated_mean (bool): The loss_repeated_mean flag.
         """
         if not isinstance(loss_repeated_mean, bool):
-            raise TypeError("For 'auto_parallel_context().set_loss_repeated_mean', the argument 'loss_repeated_mean' "
+            raise TypeError("For 'auto_parallel_context', the argument 'loss_repeated_mean' "
                             "must be bool, but got the type : {}.".format(type(loss_repeated_mean)))
         self.check_context_handle()
         self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
@@ -431,22 +431,22 @@ class _AutoParallelContext:
         self.check_context_handle()
         if isinstance(dataset_strategy, str):
             if dataset_strategy not in ("full_batch", "data_parallel"):
-                raise ValueError("For 'set_auto_parallel_context().set_dataset_strategy', the argument "
+                raise ValueError("For 'set_auto_parallel_context', the argument "
                                  "'dataset_strategy' must be 'full_batch' or 'data_parallel', but got the value : {}."
                                  .format(dataset_strategy))
             self._context_handle.set_full_batch(dataset_strategy == "full_batch")
             self._dataset_strategy_using_str = True
             return
         if not isinstance(dataset_strategy, tuple):
-            raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the argument 'dataset_strategy' "
+            raise TypeError("For 'set_auto_parallel_context', the argument 'dataset_strategy' "
                             "must be str or tuple type, but got the type : {}.".format(type(dataset_strategy)))
         for ele in dataset_strategy:
             if not isinstance(ele, tuple):
-                raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the element of argument "
+                raise TypeError("For 'set_auto_parallel_context', the element of argument "
                                 "'dataset_strategy' must be tuple, but got the type : {} .".format(type(ele)))
             for dim in ele:
                 if not isinstance(dim, int):
-                    raise TypeError("For 'set_auto_parallel_context().set_dataset_strategy', the element of argument "
+                    raise TypeError("For 'set_auto_parallel_context', the element of argument "
                                     "'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
         self._dataset_strategy_using_str = False
         self._context_handle.set_dataset_strategy(dataset_strategy)
@@ -645,7 +645,7 @@ class _AutoParallelContext:
         """
         self.check_context_handle()
         if not isinstance(enable_parallel_optimizer, bool):
-            raise TypeError("For 'set_auto_parallel_context().set_enable_parallel_optimizer', "
+            raise TypeError("For 'set_auto_parallel_context', "
                             "the argument 'enable_parallel_optimizer' must be bool, but got the type : {}."
                             .format(type(enable_parallel_optimizer)))
         self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
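
Note (not part of the patch): below is a minimal standalone sketch of the validation pattern used in set_pipeline_stages above, with the messages as they read after this change; the helper name check_pipeline_stages is hypothetical. The explicit bool test runs before the int test because isinstance(True, int) is True in Python, so a bare int check alone would accept booleans.

def check_pipeline_stages(stages):
    # Hypothetical standalone helper, not MindSpore code: mirrors the
    # validation in _AutoParallelContext.set_pipeline_stages.
    # Reject bool first, since isinstance(True, int) is True in Python.
    if isinstance(stages, bool) or not isinstance(stages, int):
        raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
                        "must be int, but got the type : {}.".format(type(stages)))
    if stages < 1:
        raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
                         "should be greater or equal 1, but got the value of stages : {}.".format(stages))
    return stages

check_pipeline_stages(4)        # OK
for bad in (True, 2.0, 0):      # TypeError, TypeError, ValueError
    try:
        check_pipeline_stages(bad)
    except (TypeError, ValueError) as err:
        print(err)

The message prefix is shortened from 'set_auto_parallel_context().set_pipeline_stages' to just 'set_auto_parallel_context' throughout, presumably because users call the public entry point set_auto_parallel_context(pipeline_stages=...) rather than the internal _AutoParallelContext methods the old messages named.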