From e4cd67596f07a9c1e5d75189b81b8002aa80ac0a Mon Sep 17 00:00:00 2001
From: Yi Huaijie
Date: Wed, 16 Sep 2020 14:36:27 +0800
Subject: [PATCH] raise RuntimeError when using full_batch under neither semi_auto_parallel nor auto_parallel

---
 mindspore/common/api.py                     |  3 ++-
 mindspore/parallel/_utils.py                | 12 ++++++++++++
 tests/ut/python/parallel/test_full_batch.py | 20 ++++++++++++++++++++
 3 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 571a54abc75..279a93dc8a8 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -23,7 +23,7 @@ from mindspore import log as logger
 from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
 from .tensor import Tensor as MsTensor
-from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
+from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor
 from ..parallel._ps_context import _is_role_pserver
 # store ms_function class compiled pipeline cache
 ms_compile_cache = {}
@@ -384,6 +384,7 @@ class _Executor:
             Bool, if the graph has been compiled before, return False, else return True.
         """
         obj.check_names()
+        _check_full_batch()
         args_names, args_list = _generate_pip_args(obj, *args)
         dic = dict(zip(args_names, args_list))
         key = generate_key(phase, dic)
diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py
index 1c93ae20029..0cecc692cdd 100644
--- a/mindspore/parallel/_utils.py
+++ b/mindspore/parallel/_utils.py
@@ -32,6 +32,18 @@ def _get_full_batch():
     """Get whether to use full_batch."""
     return auto_parallel_context().get_full_batch()
 
+def _check_full_batch():
+    """
+    Check that full_batch is only used under semi_auto_parallel or auto_parallel.
+
+    Raises:
+        RuntimeError: If full_batch is used under neither semi_auto_parallel nor auto_parallel.
+    """
+    parallel_mode = _get_parallel_mode()
+    full_batch = _get_full_batch()
+    if full_batch and parallel_mode not in ("semi_auto_parallel", "auto_parallel"):
+        raise RuntimeError("full_batch can only be used under semi_auto_parallel or auto_parallel.")
+
 
 def _need_to_full():
     """Check whether to convert input to full shape or tensor."""
diff --git a/tests/ut/python/parallel/test_full_batch.py b/tests/ut/python/parallel/test_full_batch.py
index 9d504f2af26..dc82cb04a25 100644
--- a/tests/ut/python/parallel/test_full_batch.py
+++ b/tests/ut/python/parallel/test_full_batch.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import numpy as np
+import pytest
 
 import mindspore as ms
 import mindspore.nn as nn
@@ -88,3 +89,22 @@ def test_all_to_all():
     strategy1 = ((8, 1),)
     _reset_op_id()
     all_to_all_common(strategy1)
+
+def test_data_parallel_mode():
+    _reset_op_id()
+    learning_rate = 0.1
+    momentum = 0.9
+    epoch_size = 2
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
+    context.reset_auto_parallel_context()
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, full_batch=True)
+    predict = Tensor(np.ones([256, 128]), dtype=ms.float32)
+    label = Tensor(np.ones([256]), dtype=ms.int32)
+    dataset = Dataset(predict, label, 2)
+    net = all_to_all_net(None)
+    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    opt = Momentum(net.trainable_params(), learning_rate, momentum)
+    model = Model(net, loss, opt)
+
+    with pytest.raises(RuntimeError):
+        model.train(epoch_size, dataset, dataset_sink_mode=False)
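
Note (editorial, not part of the patch): a minimal sketch of the behaviour the new check enforces, assuming a MindSpore build that already contains _check_full_batch() from this change. Calling the private helper directly is only for illustration; in practice _Executor.compile() invokes it before graph compilation, e.g. from Model.train() as in the new unit test above. The string parallel-mode names are the same ones the check compares against.

    from mindspore import context
    from mindspore.parallel._utils import _check_full_batch

    # Allowed: full_batch together with semi_auto_parallel (or auto_parallel).
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True)
    _check_full_batch()  # passes silently

    # Rejected: full_batch under data_parallel (or any other non-auto mode).
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode="data_parallel", full_batch=True)
    try:
        _check_full_batch()
    except RuntimeError as err:
        print(err)  # full_batch can only be used under semi_auto_parallel or auto_parallel.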