forked from mindspore-Ecosystem/mindspore
md support ps-lite
parent 089623ad19
commit add19a591c
@@ -38,7 +38,7 @@ from mindspore._c_expression import typing
 from mindspore import log as logger
 from . import samplers
-from .iterators import DictIterator, TupleIterator
+from .iterators import DictIterator, TupleIterator, DummyIterator
 from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
     check_rename, check_numpyslicesdataset, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
@@ -146,6 +146,12 @@ class Dataset:
         self._num_classes = None
         self._repeat_count = None
         self._sync = False
+        self.ms_role = os.getenv("MS_ROLE")
+
+    def _noop_mode(self):
+        if self.ms_role in ("MS_PSERVER", "MS_SCHED"):
+            return True
+        return False
 
     def __add__(self, datasets):
         return self.concat(datasets)
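The no-op check depends only on the MS_ROLE environment variable, read once when the dataset object is built. A minimal standalone sketch of the same gating logic (the helper name below is illustrative, not part of the patch):

import os

# Roles that only coordinate parameter-server training and never consume data.
_NOOP_ROLES = ("MS_PSERVER", "MS_SCHED")

def noop_mode():
    """Return True when this process is a PS-lite parameter server or scheduler."""
    return os.getenv("MS_ROLE") in _NOOP_ROLES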
@@ -1062,6 +1068,8 @@ class Dataset:
             >>> # convert the returned tuple to a list and print
             >>> print(list(item))
         """
+        if self._noop_mode():
+            return DummyIterator(self, 'tuple')
         return TupleIterator(self, columns)
 
     def create_dict_iterator(self):
@@ -1085,6 +1093,8 @@ class Dataset:
             >>> print(item["column1"])
 
         """
+        if self._noop_mode():
+            return DummyIterator(self, 'dict')
         return DictIterator(self)
 
     def __iter__(self):
@@ -2318,6 +2328,8 @@ class TransferDataset(DatasetOp):
 
     def send(self):
         # need to keep iterator alive so the executionTree is not destroyed
+        if self._noop_mode():
+            return
         self.iterator = TupleIterator(self)
@@ -17,7 +17,9 @@
 from abc import abstractmethod
 import copy
 import weakref
+import numpy as np
 
+from mindspore.common.tensor import Tensor
 from mindspore._c_dataengine import DEPipeline
 from mindspore._c_dataengine import OpName
 
@@ -287,3 +289,32 @@ class TupleIterator(Iterator):
         """
 
         return [t.as_array() for t in self.depipeline.GetNextAsList()]
+
+
+class DummyIterator():
+    """
+    A DummyIterator only works when the env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED".
+    """
+    def __init__(self, dataset, mode):
+        self.mode = mode
+        self.shapes = dataset.output_shapes()
+        self.types = dataset.output_types()
+        self.fetched_first = False
+
+    def __get_tensor(self):
+        tensor_row = []
+        for np_shape, np_type in zip(self.shapes, self.types):
+            input_np = np.zeros(np_shape, np_type)
+            tensor = Tensor(input_np)
+            tensor_row.append(tensor)
+        return tensor_row
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.mode == "tuple":
+            if not self.fetched_first:
+                self.fetched_first = True
+                return self.__get_tensor()
+        raise StopIteration()
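In tuple mode the dummy iterator yields exactly one row of zero-filled tensors, built from the dataset's declared output shapes and types, and then stops; in dict mode it yields nothing at all. A quick illustrative check, assuming the class lives in its usual place in the dataset engine package and using a made-up FakeDataset that exposes only the two methods DummyIterator relies on:

import numpy as np
from mindspore.dataset.engine.iterators import DummyIterator  # the class defined above

class FakeDataset:
    """Illustrative stand-in; a real Dataset reports its own shapes and types."""
    def output_shapes(self):
        return [(2, 3), (2,)]

    def output_types(self):
        return [np.float32, np.int32]

rows = list(DummyIterator(FakeDataset(), "tuple"))
assert len(rows) == 1      # exactly one zero-filled batch, then StopIteration
assert len(rows[0]) == 2   # one tensor per declared column

assert list(DummyIterator(FakeDataset(), "dict")) == []  # dict mode yields nothing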
@@ -14,6 +14,7 @@
 # ============================================================================
 """Dataset help for minddata dataset"""
 import math
+import os
 
 from mindspore._checkparam import check_bool
 from .. import context
@@ -60,7 +61,11 @@ class DatasetHelper:
         if context.get_context("device_target") == "Ascend":
             iterclass = _DatasetIterMSLoopSink
         elif context.get_context("device_target") == "GPU":
-            iterclass = _DatasetIterMS
+            ms_role = os.getenv("MS_ROLE")
+            if ms_role in ("MS_PSERVER", "MS_SCHED"):
+                iterclass = _DatasetIterPSLite
+            else:
+                iterclass = _DatasetIterMS
         elif context.get_context("device_target") == "CPU":
             raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
         else:
@@ -131,6 +136,9 @@ class _DatasetIterMSLoopSink(_DatasetIter):
     def __init__(self, dataset):
         super(_DatasetIterMSLoopSink, self).__init__(dataset)
         self.loop_count = self.get_loop_count(dataset)
+        ms_role = os.getenv("MS_ROLE")
+        if ms_role in ("MS_PSERVER", "MS_SCHED"):
+            self.loop_count = 1
         # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
         # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
         # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
@@ -154,6 +162,18 @@ class _DatasetIterMS(_DatasetIter):
         self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
 
 
+class _DatasetIterPSLite(_DatasetIter):
+    """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
+    def __init__(self, dataset):
+        super(_DatasetIterPSLite, self).__init__(dataset)
+        self.loop_count = 1
+        self.loop_size = 1
+        self.op = None
+        def op():
+            return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)
+        self.op = op
+
+
 class _DatasetIterGE(_DatasetIter):
     """Iter for ge"""
     def __init__(self, dataset):
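Unlike _DatasetIterMS, which wires up a real GetNextSingleOp, _DatasetIterPSLite fills the op slot with a plain Python callable that fabricates a single batch of zero tensors from the dataset metadata, presumably so the compile step still sees correctly shaped inputs. A rough standalone sketch of what that callable returns (the shapes and types below are made-up example values, and the function merely plays the role of _construct_tensor_list rather than reimplementing it):

import numpy as np
from mindspore import Tensor

# Made-up example metadata; the real values come from the dataset object.
dataset_types = [np.float32, np.int32]
dataset_shapes = [(32, 3, 224, 224), (32,)]

def zero_batch():
    # Stands in for _construct_tensor_list(...) in the patch:
    # one zero tensor per dataset column, batch_expand_num effectively 1.
    return [Tensor(np.zeros(shape, dtype))
            for shape, dtype in zip(dataset_shapes, dataset_types)]

print(len(zero_batch()))  # 2 columns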
@@ -15,6 +15,7 @@
 """Model."""
 from collections.abc import Iterable
 
+import os
 import numpy as np
 
 from mindspore import log as logger
@@ -350,6 +351,9 @@ class Model:
         cb_params.train_dataset = train_dataset
         cb_params.list_callback = self._transform_callbacks(callbacks)
         cb_params.train_dataset_element = None
+        ms_role = os.getenv("MS_ROLE")
+        if ms_role in ("MS_PSERVER", "MS_SCHED"):
+            epoch = 1
 
         # build callback list
         with _CallbackManager(callbacks) as list_callback:
@@ -0,0 +1,45 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Test No-op mode support with Dummy Iterator
+"""
+import os
+import mindspore.dataset as ds
+
+DATA_DIR = "../data/dataset/testVOC2012"
+
+def test_noop_pserver():
+    os.environ['MS_ROLE'] = 'MS_PSERVER'
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    num = 0
+    for _ in data1.create_dict_iterator():
+        num += 1
+    assert num == 0
+    del os.environ['MS_ROLE']
+
+
+def test_noop_sched():
+    os.environ['MS_ROLE'] = 'MS_SCHED'
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    num = 0
+    for _ in data1.create_dict_iterator():
+        num += 1
+    assert num == 0
+    del os.environ['MS_ROLE']
+
+
+if __name__ == '__main__':
+    test_noop_pserver()
+    test_noop_sched()
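Both tests mutate os.environ directly, so a failing assertion would leave MS_ROLE set for whatever runs next in the same process. A small context-manager variant of the same pattern keeps the cleanup unconditional; it is purely illustrative and not part of the patch:

import os
from contextlib import contextmanager

@contextmanager
def ms_role(role):
    """Temporarily set MS_ROLE and always restore the previous value."""
    old = os.environ.get("MS_ROLE")
    os.environ["MS_ROLE"] = role
    try:
        yield
    finally:
        if old is None:
            del os.environ["MS_ROLE"]
        else:
            os.environ["MS_ROLE"] = old

With it, each test body becomes a `with ms_role('MS_PSERVER'):` block around the dataset construction and iteration; the dataset must still be created inside the block, because the role is read when the Dataset object is built.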