auto-parallel: support dataset on GPU

This commit is contained in:
lichenever 2020-07-22 16:50:31 +08:00
parent 8f4bab4e75
commit e712c6cfe5
2 changed files with 8 additions and 2 deletions

View File

@ -44,7 +44,10 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
return inst_context_;
}
ParallelContext::ParallelContext() { Reset(); }
// Default-construct the parallel context.
// The communication backend is initialized here (not in Reset()) so that a
// backend chosen by the user is not clobbered when Reset() is invoked later.
ParallelContext::ParallelContext() : communication_backend_(HCCL_BACKEND) { Reset(); }
void ParallelContext::Reset() {
mirror_mean_ = false;
@ -53,7 +56,6 @@ void ParallelContext::Reset() {
loss_repeated_mean_ = true;
device_num_ = 1;
global_rank_ = 0;
communication_backend_ = HCCL_BACKEND;
device_num_is_set_ = false;
global_rank_is_set_ = false;
parallel_mode_ = STAND_ALONE;

View File

@ -30,6 +30,8 @@ from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode
from ._utils import _to_full_tensor
from ..parallel._utils import _need_to_full
from ..common import dtype as mstype
from .dataset_helper import DatasetHelper
from . import amp
@ -418,6 +420,8 @@ class Model:
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper:
if _need_to_full():
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context)
outputs = self._train_network(*inputs)
cb_params.cur_step_num += dataset_helper.sink_size()