Support dict type data in noop mode

This commit is contained in:
ZPaC 2022-07-01 09:31:44 +08:00
parent f98efcf71f
commit 75f11f8e6f
3 changed files with 17 additions and 10 deletions

View File

@ -1431,7 +1431,7 @@ class Dataset:
output_numpy = False
if Dataset._noop_mode():
return DummyIterator(self, 'tuple')
return DummyIterator(self, 'tuple', output_numpy)
return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)
@check_dict_iterator
@ -1463,7 +1463,7 @@ class Dataset:
output_numpy = False
if Dataset._noop_mode():
return DummyIterator(self, 'dict')
return DummyIterator(self, 'dict', output_numpy)
return DictIterator(self, num_epochs, output_numpy)
def __iter__(self):

View File

@ -261,26 +261,33 @@ class DummyIterator:
A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
"""
def __init__(self, dataset, mode):
def __init__(self, dataset, mode, output_numpy=False):
self.mode = mode
self.shapes = dataset.output_shapes()
self.types = dataset.output_types()
self.col_names = dataset.get_col_names()
self.fetched_first = False
self.output_numpy = output_numpy
def __get_tensor(self):
"""Get a next tensor."""
tensor_row = []
for np_shape, np_type in zip(self.shapes, self.types):
input_np = np.zeros(np_shape, np_type)
tensor = Tensor(input_np)
tensor_row.append(tensor)
if self.output_numpy:
tensor_row.append(tensor.asnumpy())
else:
tensor_row.append(tensor)
if self.mode == "dict":
tensor_row = {col_name: tensor for col_name, tensor in zip(self.col_names, tensor_row)}
return tensor_row
def __iter__(self):
return self
def __next__(self):
if self.mode == "tuple":
if not self.fetched_first:
self.fetched_first = True
return self.__get_tensor()
if not self.fetched_first:
self.fetched_first = True
return self.__get_tensor()
raise StopIteration()

View File

@ -33,7 +33,7 @@ def test_noop_pserver():
num = 0
for _ in data1.create_dict_iterator(num_epochs=1):
num += 1
assert num == 0
assert num == 1
del os.environ['MS_ROLE']
context.set_ps_context(enable_ps=False)
@ -50,7 +50,7 @@ def test_noop_sched():
num = 0
for _ in data1.create_dict_iterator(num_epochs=1):
num += 1
assert num == 0
assert num == 1
del os.environ['MS_ROLE']
context.set_ps_context(enable_ps=False)