forked from mindspore-Ecosystem/mindspore

!3330 [AutoParallel]Support dataset in GPU

Merge pull request !3330 from lichen/autoparallel_support_dataset_in_gpu

Commit 08bafed565
@@ -44,7 +44,10 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
   return inst_context_;
 }
 
-ParallelContext::ParallelContext() { Reset(); }
+ParallelContext::ParallelContext() {
+  communication_backend_ = HCCL_BACKEND;
+  Reset();
+}
 
 void ParallelContext::Reset() {
   mirror_mean_ = false;
@@ -53,7 +56,6 @@ void ParallelContext::Reset() {
   loss_repeated_mean_ = true;
   device_num_ = 1;
   global_rank_ = 0;
-  communication_backend_ = HCCL_BACKEND;
   device_num_is_set_ = false;
   global_rank_is_set_ = false;
   parallel_mode_ = STAND_ALONE;
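The constructor change above presumably exists because GPU jobs use the NCCL backend rather than HCCL: Reset() runs whenever the parallel context is re-initialized, so keeping the HCCL default only in the constructor avoids overriding a backend that was selected at runtime. A rough Python-side sketch of a GPU auto-parallel setup that relies on this (illustrative only, not part of this commit; API details such as init("nccl") follow the MindSpore releases of this period and may differ in other versions):

# Illustrative sketch, not part of this commit.
from mindspore import context
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
init("nccl")  # GPU uses NCCL; Ascend uses HCCL
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
# After this commit, resetting the parallel context presumably no longer
# forces the backend back to HCCL_BACKEND.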
@@ -30,6 +30,8 @@ from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from .parallel_utils import ParallelMode
+from ._utils import _to_full_tensor
+from ..parallel._utils import _need_to_full
 from ..common import dtype as mstype
 from .dataset_helper import DatasetHelper
 from . import amp
@@ -418,6 +420,8 @@ class Model:
 
             # for data sink dataset_helper only iter once, otherwise iter epoch_size times.
             for inputs in dataset_helper:
+                if _need_to_full():
+                    inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
                 list_callback.step_begin(run_context)
                 outputs = self._train_network(*inputs)
                 cb_params.cur_step_num += dataset_helper.sink_size()
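The two added lines above are the substance of the dataset support: when _need_to_full() reports that the current parallel setup expects full-batch inputs, each rank's sliced batch is expanded to the full batch shape before being passed to the network (which presumably re-slices it through _VirtualDatasetCell). A minimal numpy sketch of that idea, assuming _to_full_tensor places the local slice at this rank's offset in a zero-filled full-batch array; the real helper may differ:

import numpy as np

def to_full_array(local_slice, device_num, global_rank):
    # Hypothetical stand-in for _to_full_tensor (illustrative only): embed this
    # rank's slice into a zero-filled array with the full batch dimension.
    slice_batch = local_slice.shape[0]
    full = np.zeros((slice_batch * device_num,) + local_slice.shape[1:], dtype=local_slice.dtype)
    start = global_rank * slice_batch
    full[start:start + slice_batch] = local_slice
    return full

# Example: 4 devices, each sinking a local batch of 2; rank 1 fills rows 2..3.
local = np.ones((2, 3), dtype=np.float32)
print(to_full_array(local, device_num=4, global_rank=1).shape)  # (8, 3)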