!6332 check whether full_batch is valid

Merge pull request !6332 from yihuaijie/master
mindspore-ci-bot 2020-09-17 09:09:32 +08:00 committed by Gitee
commit 72d5256d1c
3 changed files with 34 additions and 1 deletion


@@ -23,7 +23,7 @@ from mindspore import log as logger
 from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
 from .tensor import Tensor as MsTensor
-from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
+from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor
 from ..parallel._ps_context import _is_role_pserver
 # store ms_function class compiled pipeline cache
 ms_compile_cache = {}
@@ -384,6 +384,7 @@ class _Executor:
             Bool, if the graph has been compiled before, return False, else return True.
         """
         obj.check_names()
+        _check_full_batch()
         args_names, args_list = _generate_pip_args(obj, *args)
         dic = dict(zip(args_names, args_list))
         key = generate_key(phase, dic)

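With this hook, _check_full_batch() runs at the start of graph compilation in _Executor, so an invalid full_batch setting fails fast instead of surfacing later during parallel execution. A minimal sketch of the user-facing effect, assuming the public context API and a single-device run (the exact error text may differ):

import numpy as np
import mindspore as ms
from mindspore import context, nn, Tensor

context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()
# full_batch only makes sense when a parallel strategy splits the batch,
# so combining it with data_parallel is now rejected at compile time.
context.set_auto_parallel_context(parallel_mode="data_parallel", full_batch=True)

net = nn.Dense(128, 10)
try:
    net(Tensor(np.ones([32, 128]), ms.float32))  # first call compiles the graph
except RuntimeError as err:
    print(err)  # full_batch could only be used under semi_auto_parallel or auto_parallel.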

@@ -32,6 +32,18 @@ def _get_full_batch():
     """Get whether to use full_batch."""
     return auto_parallel_context().get_full_batch()
 
 
+def _check_full_batch():
+    """
+    Check that full_batch is only used under semi_auto_parallel or auto_parallel.
+    Raises:
+        RuntimeError: If full_batch is enabled outside semi_auto_parallel or auto_parallel.
+    """
+    parallel_mode = _get_parallel_mode()
+    full_batch = _get_full_batch()
+    if parallel_mode not in ("semi_auto_parallel", "auto_parallel") and full_batch:
+        raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.")
+
+
 def _need_to_full():
     """Check whether to convert input to full shape or tensor."""

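The rule itself is a pure predicate over two context values: raise only when full_batch is on while the mode does not split the batch by strategy. A self-contained sketch of that rule (the standalone check_full_batch name is illustrative, not part of the module):

def check_full_batch(parallel_mode, full_batch):
    """Mirror of the rule enforced by _check_full_batch."""
    if parallel_mode not in ("semi_auto_parallel", "auto_parallel") and full_batch:
        raise RuntimeError("full_batch could only be used under "
                           "semi_auto_parallel or auto_parallel.")

check_full_batch("semi_auto_parallel", True)  # ok: a strategy splits the batch
check_full_batch("data_parallel", False)      # ok: full_batch is off
check_full_batch("data_parallel", True)       # raises RuntimeError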

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import numpy as np
+import pytest
 import mindspore as ms
 import mindspore.nn as nn
@@ -88,3 +89,22 @@ def test_all_to_all():
     strategy1 = ((8, 1),)
     _reset_op_id()
     all_to_all_common(strategy1)
+
+
+def test_data_parallel_mode():
+    _reset_op_id()
+    learning_rate = 0.1
+    momentum = 0.9
+    epoch_size = 2
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
+    context.reset_auto_parallel_context()
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, full_batch=True)
+    predict = Tensor(np.ones([256, 128]), dtype=ms.float32)
+    label = Tensor(np.ones([256]), dtype=ms.int32)
+    dataset = Dataset(predict, label, 2)
+    net = all_to_all_net(None)
+    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    opt = Momentum(net.trainable_params(), learning_rate, momentum)
+    model = Model(net, loss, opt)
+    with pytest.raises(RuntimeError):
+        model.train(epoch_size, dataset, dataset_sink_mode=False)
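The test exercises the check end to end: Model.train enters _Executor.compile, where the new _check_full_batch call raises because DATA_PARALLEL combined with full_batch=True violates the rule, and with dataset_sink_mode=False the RuntimeError surfaces directly from the train call. To run only this case (the path shown is an assumption about where the test file lives in the parallel unit-test suite):

pytest tests/ut/python/parallel -k test_data_parallel_mode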