forked from mindspore-Ecosystem/mindspore
!19549 fix parallel testcase
Merge pull request !19549 from gziyan/fix_testcase
This commit is contained in:
commit
0d52219090
|
@ -624,25 +624,6 @@ class Model:
|
|||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
|
||||
>>> model.train(2, dataset)
|
||||
"""
|
||||
self._train_check(train_dataset, dataset_sink_mode, sink_size)
|
||||
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
|
||||
self._train(epoch,
|
||||
train_dataset,
|
||||
callbacks=callbacks,
|
||||
dataset_sink_mode=dataset_sink_mode,
|
||||
sink_size=sink_size)
|
||||
|
||||
def _train_check(self, train_dataset, dataset_sink_mode, sink_size):
|
||||
"""
|
||||
Check arguments of training.
|
||||
|
||||
Args:
|
||||
train_dataset (Dataset): A training dataset iterator.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
|
||||
sink_size (int): Control the amount of data in each sink.
|
||||
"""
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
|
||||
raise ValueError("Sink mode is currently not supported when training with a GraphCell.")
|
||||
|
@ -655,6 +636,14 @@ class Model:
|
|||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
|
||||
self._train(epoch,
|
||||
train_dataset,
|
||||
callbacks=callbacks,
|
||||
dataset_sink_mode=dataset_sink_mode,
|
||||
sink_size=sink_size)
|
||||
|
||||
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Evaluation. The data would be passed to network through dataset channel.
|
||||
|
@ -818,9 +807,37 @@ class Model:
|
|||
check_output_data(result)
|
||||
return result
|
||||
|
||||
def _infer_train_check(self, train_dataset, dataset_sink_mode, sink_size):
|
||||
"""
|
||||
Check arguments of training.
|
||||
|
||||
Args:
|
||||
train_dataset (Dataset): A training dataset iterator.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
|
||||
sink_size (int): Control the amount of data in each sink.
|
||||
"""
|
||||
if context.get_context("mode") != context.GRAPH_MODE:
|
||||
raise RuntimeError('Pre-compile process only supports GRAPH MODE and Ascend target currently.')
|
||||
if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
raise RuntimeError('infer train layout only supports semi auto parallel and auto parallel mode.')
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
if not dataset_sink_mode:
|
||||
raise ValueError("Only dataset sink mode is supported for now.")
|
||||
if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
|
||||
raise ValueError("Sink mode is currently not supported when training with a GraphCell.")
|
||||
Validator.check_is_int(sink_size)
|
||||
dataset_size = train_dataset.get_dataset_size()
|
||||
if dataset_size == 0:
|
||||
raise ValueError("There is no valid data in dataset, please check dataset file first.")
|
||||
if sink_size == -1:
|
||||
sink_size = dataset_size
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
||||
def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
|
||||
"""
|
||||
Generate parameter layout for the train network in auto or semi auto parallel mode.
|
||||
Only dataset sink mode is supported for now.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
@ -870,11 +887,7 @@ class Model:
|
|||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
|
||||
>>> model.infer_train_layout(dataset)
|
||||
"""
|
||||
if context.get_context("mode") != context.GRAPH_MODE:
|
||||
raise RuntimeError('Pre-compile process only supports GRAPH MODE and Ascend target currently.')
|
||||
if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
raise RuntimeError('infer train layout only supports semi auto parallel and auto parallel mode.')
|
||||
self._train_check(train_dataset, dataset_sink_mode, sink_size)
|
||||
self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)
|
||||
|
||||
train_dataset.__no_send__ = True
|
||||
train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
|
||||
|
|
|
@ -106,7 +106,7 @@ def _update_param(param, new_param, strict_load):
|
|||
.format(param.name, param.data.dtype, new_param.data.dtype))
|
||||
raise RuntimeError(msg)
|
||||
|
||||
param.set_data(new_param.data)
|
||||
param.set_data(new_param.data, param.sliced)
|
||||
return
|
||||
|
||||
if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
|
||||
|
|
Loading…
Reference in New Issue