Support dict type data in noop mode
This commit is contained in:
parent
f98efcf71f
commit
75f11f8e6f
|
@ -1431,7 +1431,7 @@ class Dataset:
|
||||||
output_numpy = False
|
output_numpy = False
|
||||||
|
|
||||||
if Dataset._noop_mode():
|
if Dataset._noop_mode():
|
||||||
return DummyIterator(self, 'tuple')
|
return DummyIterator(self, 'tuple', output_numpy)
|
||||||
return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)
|
return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)
|
||||||
|
|
||||||
@check_dict_iterator
|
@check_dict_iterator
|
||||||
|
@ -1463,7 +1463,7 @@ class Dataset:
|
||||||
output_numpy = False
|
output_numpy = False
|
||||||
|
|
||||||
if Dataset._noop_mode():
|
if Dataset._noop_mode():
|
||||||
return DummyIterator(self, 'dict')
|
return DummyIterator(self, 'dict', output_numpy)
|
||||||
return DictIterator(self, num_epochs, output_numpy)
|
return DictIterator(self, num_epochs, output_numpy)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
|
|
@ -261,26 +261,33 @@ class DummyIterator:
|
||||||
A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
|
A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset, mode):
|
def __init__(self, dataset, mode, output_numpy=False):
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.shapes = dataset.output_shapes()
|
self.shapes = dataset.output_shapes()
|
||||||
self.types = dataset.output_types()
|
self.types = dataset.output_types()
|
||||||
|
self.col_names = dataset.get_col_names()
|
||||||
self.fetched_first = False
|
self.fetched_first = False
|
||||||
|
self.output_numpy = output_numpy
|
||||||
|
|
||||||
def __get_tensor(self):
|
def __get_tensor(self):
|
||||||
|
"""Get a next tensor."""
|
||||||
tensor_row = []
|
tensor_row = []
|
||||||
for np_shape, np_type in zip(self.shapes, self.types):
|
for np_shape, np_type in zip(self.shapes, self.types):
|
||||||
input_np = np.zeros(np_shape, np_type)
|
input_np = np.zeros(np_shape, np_type)
|
||||||
tensor = Tensor(input_np)
|
tensor = Tensor(input_np)
|
||||||
tensor_row.append(tensor)
|
if self.output_numpy:
|
||||||
|
tensor_row.append(tensor.asnumpy())
|
||||||
|
else:
|
||||||
|
tensor_row.append(tensor)
|
||||||
|
if self.mode == "dict":
|
||||||
|
tensor_row = {col_name: tensor for col_name, tensor in zip(self.col_names, tensor_row)}
|
||||||
return tensor_row
|
return tensor_row
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
if self.mode == "tuple":
|
if not self.fetched_first:
|
||||||
if not self.fetched_first:
|
self.fetched_first = True
|
||||||
self.fetched_first = True
|
return self.__get_tensor()
|
||||||
return self.__get_tensor()
|
|
||||||
raise StopIteration()
|
raise StopIteration()
|
||||||
|
|
|
@ -33,7 +33,7 @@ def test_noop_pserver():
|
||||||
num = 0
|
num = 0
|
||||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||||
num += 1
|
num += 1
|
||||||
assert num == 0
|
assert num == 1
|
||||||
del os.environ['MS_ROLE']
|
del os.environ['MS_ROLE']
|
||||||
context.set_ps_context(enable_ps=False)
|
context.set_ps_context(enable_ps=False)
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ def test_noop_sched():
|
||||||
num = 0
|
num = 0
|
||||||
for _ in data1.create_dict_iterator(num_epochs=1):
|
for _ in data1.create_dict_iterator(num_epochs=1):
|
||||||
num += 1
|
num += 1
|
||||||
assert num == 0
|
assert num == 1
|
||||||
del os.environ['MS_ROLE']
|
del os.environ['MS_ROLE']
|
||||||
context.set_ps_context(enable_ps=False)
|
context.set_ps_context(enable_ps=False)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue