!3330 [AutoParallel]Support dataset in GPU

Merge pull request !3330 from lichen/autoparallel_support_dataset_in_gpu
This commit is contained in:
mindspore-ci-bot 2020-07-22 20:06:22 +08:00 committed by Gitee
commit 08bafed565
2 changed files with 8 additions and 2 deletions

View File

@ -44,7 +44,10 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
return inst_context_;
}
ParallelContext::ParallelContext() { Reset(); }
ParallelContext::ParallelContext() {
communication_backend_ = HCCL_BACKEND;
Reset();
}
void ParallelContext::Reset() {
mirror_mean_ = false;
@ -53,7 +56,6 @@ void ParallelContext::Reset() {
loss_repeated_mean_ = true;
device_num_ = 1;
global_rank_ = 0;
communication_backend_ = HCCL_BACKEND;
device_num_is_set_ = false;
global_rank_is_set_ = false;
parallel_mode_ = STAND_ALONE;

View File

@ -30,6 +30,8 @@ from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode
from ._utils import _to_full_tensor
from ..parallel._utils import _need_to_full
from ..common import dtype as mstype
from .dataset_helper import DatasetHelper
from . import amp
@ -418,6 +420,8 @@ class Model:
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper:
if _need_to_full():
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context)
outputs = self._train_network(*inputs)
cb_params.cur_step_num += dataset_helper.sink_size()