forked from mindspore-Ecosystem/mindspore
raise RuntimeError when using full_batch neither under semi_auto_parallel nor auto_parallel
This commit is contained in:
parent
38babd1452
commit
e4cd67596f
|
@ -23,7 +23,7 @@ from mindspore import log as logger
|
|||
from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
|
||||
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
|
||||
from .tensor import Tensor as MsTensor
|
||||
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
|
||||
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor
|
||||
from ..parallel._ps_context import _is_role_pserver
|
||||
# store ms_function class compiled pipeline cache
|
||||
ms_compile_cache = {}
|
||||
|
@ -384,6 +384,7 @@ class _Executor:
|
|||
Bool, if the graph has been compiled before, return False, else return True.
|
||||
"""
|
||||
obj.check_names()
|
||||
_check_full_batch()
|
||||
args_names, args_list = _generate_pip_args(obj, *args)
|
||||
dic = dict(zip(args_names, args_list))
|
||||
key = generate_key(phase, dic)
|
||||
|
|
|
@ -32,6 +32,18 @@ def _get_full_batch():
|
|||
"""Get whether to use full_batch."""
|
||||
return auto_parallel_context().get_full_batch()
|
||||
|
||||
def _check_full_batch():
|
||||
"""
|
||||
full_batch could only be used under semi_auto_parallel or auto_parallel, check it.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Using full_batch under neither semi_auto_parallel nor auto_parallel.
|
||||
"""
|
||||
parallel_mode = _get_parallel_mode()
|
||||
full_batch = _get_full_batch()
|
||||
if ((parallel_mode not in ("semi_auto_parallel", "auto_parallel")) and full_batch):
|
||||
raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.")
|
||||
|
||||
|
||||
def _need_to_full():
|
||||
"""Check whether to convert input to full shape or tensor."""
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
@ -88,3 +89,22 @@ def test_all_to_all():
|
|||
strategy1 = ((8, 1),)
|
||||
_reset_op_id()
|
||||
all_to_all_common(strategy1)
|
||||
|
||||
def test_data_parallel_mode():
|
||||
_reset_op_id()
|
||||
learning_rate = 0.1
|
||||
momentum = 0.9
|
||||
epoch_size = 2
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, full_batch=True)
|
||||
predict = Tensor(np.ones([256, 128]), dtype=ms.float32)
|
||||
label = Tensor(np.ones([256]), dtype=ms.int32)
|
||||
dataset = Dataset(predict, label, 2)
|
||||
net = all_to_all_net(None)
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
opt = Momentum(net.trainable_params(), learning_rate, momentum)
|
||||
model = Model(net, loss, opt)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
model.train(epoch_size, dataset, dataset_sink_mode=False)
|
||||
|
|
Loading…
Reference in New Issue