!19549 fix parallel testcase

Merge pull request !19549 from gziyan/fix_testcase
This commit is contained in:
i-robot 2021-07-07 08:14:02 +00:00 committed by Gitee
commit 0d52219090
2 changed files with 38 additions and 25 deletions

View File

@ -624,25 +624,6 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset)
"""
self._train_check(train_dataset, dataset_sink_mode, sink_size)
_device_number_check(self._parallel_mode, self._device_number)
self._train(epoch,
train_dataset,
callbacks=callbacks,
dataset_sink_mode=dataset_sink_mode,
sink_size=sink_size)
def _train_check(self, train_dataset, dataset_sink_mode, sink_size):
"""
Check arguments of training.
Args:
train_dataset (Dataset): A training dataset iterator.
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
sink_size (int): Control the amount of data in each sink.
"""
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
raise ValueError("Sink mode is currently not supported when training with a GraphCell.")
@ -655,6 +636,14 @@ class Model:
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
_device_number_check(self._parallel_mode, self._device_number)
self._train(epoch,
train_dataset,
callbacks=callbacks,
dataset_sink_mode=dataset_sink_mode,
sink_size=sink_size)
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
"""
Evaluation. The data would be passed to network through dataset channel.
@ -818,9 +807,37 @@ class Model:
check_output_data(result)
return result
def _infer_train_check(self, train_dataset, dataset_sink_mode, sink_size):
    """
    Validate the execution context and the arguments used to infer the train layout.

    The checks run in order and the first one that fails raises; nothing is returned.

    Args:
        train_dataset (Dataset): A training dataset iterator.
        dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
            Only sink mode is accepted by this check.
        sink_size (int): Control the amount of data in each sink. It must be -1 or positive;
            -1 is replaced by the size of `train_dataset` before the range check.

    Raises:
        RuntimeError: If the context mode is not GRAPH_MODE, or if the parallel mode is neither
            semi auto parallel nor auto parallel.
        ValueError: If `dataset_sink_mode` is not enabled, if the train network is an
            `nn.GraphCell`, if the dataset size is 0, or if `sink_size` is 0 or less than -1.
    """
    # Environment checks: execution mode first, then parallel mode.
    if context.get_context("mode") != context.GRAPH_MODE:
        raise RuntimeError('Pre-compile process only supports GRAPH MODE and Ascend target currently.')
    current_parallel_mode = _get_parallel_mode()
    if current_parallel_mode not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
        raise RuntimeError('infer train layout only supports semi auto parallel and auto parallel mode.')

    # Sink-mode checks. NOTE(review): Validator.check_bool presumably raises on non-bool
    # input and returns the value unchanged - confirm against the Validator implementation.
    sink_enabled = Validator.check_bool(dataset_sink_mode)
    if not sink_enabled:
        raise ValueError("Only dataset sink mode is supported for now.")
    if isinstance(self._train_network, nn.GraphCell):
        if sink_enabled is True:
            raise ValueError("Sink mode is currently not supported when training with a GraphCell.")

    # Dataset and sink-size checks.
    Validator.check_is_int(sink_size)
    total_size = train_dataset.get_dataset_size()
    if total_size == 0:
        raise ValueError("There is no valid data in dataset, please check dataset file first.")
    # A sink_size of -1 stands for the dataset size, so validate the substituted value.
    effective_sink_size = total_size if sink_size == -1 else sink_size
    if effective_sink_size < -1 or effective_sink_size == 0:
        raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(effective_sink_size))
def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
"""
Generate parameter layout for the train network in auto or semi auto parallel mode.
Only dataset sink mode is supported for now.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
@ -870,11 +887,7 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.infer_train_layout(dataset)
"""
if context.get_context("mode") != context.GRAPH_MODE:
raise RuntimeError('Pre-compile process only supports GRAPH MODE and Ascend target currently.')
if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
raise RuntimeError('infer train layout only supports semi auto parallel and auto parallel mode.')
self._train_check(train_dataset, dataset_sink_mode, sink_size)
self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)
train_dataset.__no_send__ = True
train_dataset_helper, train_network = self._exec_preprocess(is_train=True,

View File

@ -106,7 +106,7 @@ def _update_param(param, new_param, strict_load):
.format(param.name, param.data.dtype, new_param.data.dtype))
raise RuntimeError(msg)
param.set_data(new_param.data)
param.set_data(new_param.data, param.sliced)
return
if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):