Support dict type data in noop mode
This commit is contained in:
parent
f98efcf71f
commit
75f11f8e6f
|
@ -1431,7 +1431,7 @@ class Dataset:
|
|||
output_numpy = False
|
||||
|
||||
if Dataset._noop_mode():
|
||||
return DummyIterator(self, 'tuple')
|
||||
return DummyIterator(self, 'tuple', output_numpy)
|
||||
return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)
|
||||
|
||||
@check_dict_iterator
|
||||
|
@ -1463,7 +1463,7 @@ class Dataset:
|
|||
output_numpy = False
|
||||
|
||||
if Dataset._noop_mode():
|
||||
return DummyIterator(self, 'dict')
|
||||
return DummyIterator(self, 'dict', output_numpy)
|
||||
return DictIterator(self, num_epochs, output_numpy)
|
||||
|
||||
def __iter__(self):
|
||||
|
|
|
@ -261,26 +261,33 @@ class DummyIterator:
|
|||
A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, mode):
|
||||
def __init__(self, dataset, mode, output_numpy=False):
|
||||
self.mode = mode
|
||||
self.shapes = dataset.output_shapes()
|
||||
self.types = dataset.output_types()
|
||||
self.col_names = dataset.get_col_names()
|
||||
self.fetched_first = False
|
||||
self.output_numpy = output_numpy
|
||||
|
||||
def __get_tensor(self):
|
||||
"""Get a next tensor."""
|
||||
tensor_row = []
|
||||
for np_shape, np_type in zip(self.shapes, self.types):
|
||||
input_np = np.zeros(np_shape, np_type)
|
||||
tensor = Tensor(input_np)
|
||||
tensor_row.append(tensor)
|
||||
if self.output_numpy:
|
||||
tensor_row.append(tensor.asnumpy())
|
||||
else:
|
||||
tensor_row.append(tensor)
|
||||
if self.mode == "dict":
|
||||
tensor_row = {col_name: tensor for col_name, tensor in zip(self.col_names, tensor_row)}
|
||||
return tensor_row
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.mode == "tuple":
|
||||
if not self.fetched_first:
|
||||
self.fetched_first = True
|
||||
return self.__get_tensor()
|
||||
if not self.fetched_first:
|
||||
self.fetched_first = True
|
||||
return self.__get_tensor()
|
||||
raise StopIteration()
|
||||
|
|
|
@ -33,7 +33,7 @@ def test_noop_pserver():
|
|||
num = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||
num += 1
|
||||
assert num == 0
|
||||
assert num == 1
|
||||
del os.environ['MS_ROLE']
|
||||
context.set_ps_context(enable_ps=False)
|
||||
|
||||
|
@ -50,7 +50,7 @@ def test_noop_sched():
|
|||
num = 0
|
||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||
num += 1
|
||||
assert num == 0
|
||||
assert num == 1
|
||||
del os.environ['MS_ROLE']
|
||||
context.set_ps_context(enable_ps=False)
|
||||
|
||||
|
|
Loading…
Reference in New Issue