diff --git a/mindspore/python/mindspore/dataset/engine/datasets.py b/mindspore/python/mindspore/dataset/engine/datasets.py index 0d953f5993f..15701286d55 100644 --- a/mindspore/python/mindspore/dataset/engine/datasets.py +++ b/mindspore/python/mindspore/dataset/engine/datasets.py @@ -1431,7 +1431,7 @@ class Dataset: output_numpy = False if Dataset._noop_mode(): - return DummyIterator(self, 'tuple') + return DummyIterator(self, 'tuple', output_numpy) return TupleIterator(self, columns, num_epochs, output_numpy, do_copy) @check_dict_iterator @@ -1463,7 +1463,7 @@ class Dataset: output_numpy = False if Dataset._noop_mode(): - return DummyIterator(self, 'dict') + return DummyIterator(self, 'dict', output_numpy) return DictIterator(self, num_epochs, output_numpy) def __iter__(self): diff --git a/mindspore/python/mindspore/dataset/engine/iterators.py b/mindspore/python/mindspore/dataset/engine/iterators.py index ac2a47c159e..aba171b92c5 100644 --- a/mindspore/python/mindspore/dataset/engine/iterators.py +++ b/mindspore/python/mindspore/dataset/engine/iterators.py @@ -261,26 +261,33 @@ class DummyIterator: A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED" """ - def __init__(self, dataset, mode): + def __init__(self, dataset, mode, output_numpy=False): self.mode = mode self.shapes = dataset.output_shapes() self.types = dataset.output_types() + self.col_names = dataset.get_col_names() self.fetched_first = False + self.output_numpy = output_numpy def __get_tensor(self): + """Get a next tensor.""" tensor_row = [] for np_shape, np_type in zip(self.shapes, self.types): input_np = np.zeros(np_shape, np_type) tensor = Tensor(input_np) - tensor_row.append(tensor) + if self.output_numpy: + tensor_row.append(tensor.asnumpy()) + else: + tensor_row.append(tensor) + if self.mode == "dict": + tensor_row = {col_name: tensor for col_name, tensor in zip(self.col_names, tensor_row)} return tensor_row def __iter__(self): return self def __next__(self): - if self.mode == "tuple": - if not self.fetched_first: - self.fetched_first = True - return self.__get_tensor() + if not self.fetched_first: + self.fetched_first = True + return self.__get_tensor() raise StopIteration() diff --git a/tests/ut/python/dataset/test_noop_mode.py b/tests/ut/python/dataset/test_noop_mode.py index 5dccc24b795..2267698645b 100644 --- a/tests/ut/python/dataset/test_noop_mode.py +++ b/tests/ut/python/dataset/test_noop_mode.py @@ -33,7 +33,7 @@ def test_noop_pserver(): num = 0 for _ in data1.create_dict_iterator(num_epochs=1): num += 1 - assert num == 0 + assert num == 1 del os.environ['MS_ROLE'] context.set_ps_context(enable_ps=False) @@ -50,7 +50,7 @@ def test_noop_sched(): num = 0 for _ in data1.create_dict_iterator(num_epochs=1): num += 1 - assert num == 0 + assert num == 1 del os.environ['MS_ROLE'] context.set_ps_context(enable_ps=False)