Add output_numpy type validation to the dataset iterators

This commit is contained in:
xiefangqi 2020-09-19 11:28:16 +08:00
parent d26df7cdcb
commit a6360cb2e4
3 changed files with 50 additions and 3 deletions

View File

@ -45,7 +45,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
check_paddeddataset
check_paddeddataset, check_iterator
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
@ -1149,6 +1149,7 @@ class Dataset:
return SaveOp(self).save(file_names, file_type)
@check_iterator
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
"""
Create an iterator over the dataset. The data retrieved will be a list of ndarrays of data.
@ -1179,10 +1180,14 @@ class Dataset:
>>> # convert the returned tuple to a list and print
>>> print(list(item))
"""
if output_numpy is None:
output_numpy = False
if self._noop_mode():
return DummyIterator(self, 'tuple')
return TupleIterator(self, columns, num_epochs, output_numpy)
@check_iterator
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
"""
Create an iterator over the dataset. The data retrieved will be a dictionary.
@ -1210,6 +1215,9 @@ class Dataset:
>>> # print the data in column1
>>> print(item["column1"])
"""
if output_numpy is None:
output_numpy = False
if self._noop_mode():
return DummyIterator(self, 'dict')
return DictIterator(self, num_epochs, output_numpy)
@ -2583,10 +2591,10 @@ class TransferDataset(DatasetOp):
args["send_epoch_end"] = self._send_epoch_end
return args
def create_dict_iterator(self, num_epochs=-1):
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
raise RuntimeError("TransferDataset is not iterable.")
def create_tuple_iterator(self, columns=None, num_epochs=-1):
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
raise RuntimeError("TransferDataset is not iterable.")
def __iter__(self):

View File

@ -276,6 +276,18 @@ def check_save(method):
return new_method
def check_iterator(method):
    """
    Decorator that validates arguments of create_tuple_iterator and create_dict_iterator.

    Args:
        method: The iterator-creating method to wrap.

    Returns:
        The wrapped method. It checks the `output_numpy` argument against `bool`
        and then calls `method` with the caller's original arguments.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # Presumably resolves the caller's args/kwargs to a name -> value mapping;
        # only that mapping (second element of the result) is used here.
        _, user_params = parse_user_args(method, *args, **kwargs)
        validate_dataset_param_value(['output_numpy'], user_params, bool)
        return method(self, *args, **kwargs)

    return new_method
def check_minddataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""

View File

@ -125,6 +125,32 @@ def test_iterator_weak_ref():
_cleanup()
def test_iterator_exception():
    """Check that both iterator factories raise TypeError for a non-bool output_numpy."""
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    expected_msg = "Argument output_numpy with value 123 is not of type"

    # Order of checks: dict iterator, then tuple iterator; str value, then int value.
    for create_iterator in (data.create_dict_iterator, data.create_tuple_iterator):
        for bad_value in ("123", 123):
            try:
                _ = create_iterator(output_numpy=bad_value)
            except TypeError as err:
                assert expected_msg in str(err)
            else:
                # Reaching here means no TypeError was raised for the invalid value.
                assert False
class MyDict(dict):
def __getattr__(self, key):
@ -157,4 +183,5 @@ def test_tree_copy():
if __name__ == '__main__':
test_iterator_create_tuple_numpy()
test_iterator_weak_ref()
test_iterator_exception()
test_tree_copy()