forked from mindspore-Ecosystem/mindspore
add output_numpy validation to iterator
This commit is contained in:
parent
d26df7cdcb
commit
a6360cb2e4
|
@ -45,7 +45,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
||||
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
||||
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
|
||||
check_paddeddataset
|
||||
check_paddeddataset, check_iterator
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
|
||||
|
||||
|
@ -1149,6 +1149,7 @@ class Dataset:
|
|||
|
||||
return SaveOp(self).save(file_names, file_type)
|
||||
|
||||
@check_iterator
|
||||
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
|
||||
"""
|
||||
Create an iterator over the dataset. The data retrieved will be a list of ndarrays of data.
|
||||
|
@ -1179,10 +1180,14 @@ class Dataset:
|
|||
>>> # convert the returned tuple to a list and print
|
||||
>>> print(list(item))
|
||||
"""
|
||||
if output_numpy is None:
|
||||
output_numpy = False
|
||||
|
||||
if self._noop_mode():
|
||||
return DummyIterator(self, 'tuple')
|
||||
return TupleIterator(self, columns, num_epochs, output_numpy)
|
||||
|
||||
@check_iterator
|
||||
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
|
||||
"""
|
||||
Create an iterator over the dataset. The data retrieved will be a dictionary.
|
||||
|
@ -1210,6 +1215,9 @@ class Dataset:
|
|||
>>> # print the data in column1
|
||||
>>> print(item["column1"])
|
||||
"""
|
||||
if output_numpy is None:
|
||||
output_numpy = False
|
||||
|
||||
if self._noop_mode():
|
||||
return DummyIterator(self, 'dict')
|
||||
return DictIterator(self, num_epochs, output_numpy)
|
||||
|
@ -2583,10 +2591,10 @@ class TransferDataset(DatasetOp):
|
|||
args["send_epoch_end"] = self._send_epoch_end
|
||||
return args
|
||||
|
||||
def create_dict_iterator(self, num_epochs=-1):
|
||||
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
|
||||
raise RuntimeError("TransferDataset is not iterable.")
|
||||
|
||||
def create_tuple_iterator(self, columns=None, num_epochs=-1):
|
||||
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
|
||||
raise RuntimeError("TransferDataset is not iterable.")
|
||||
|
||||
def __iter__(self):
|
||||
|
|
|
@ -276,6 +276,18 @@ def check_save(method):
|
|||
|
||||
return new_method
|
||||
|
||||
def check_iterator(method):
|
||||
"""A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
nreq_param_bool = ['output_numpy']
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_minddataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""
|
||||
|
|
|
@ -125,6 +125,32 @@ def test_iterator_weak_ref():
|
|||
|
||||
_cleanup()
|
||||
|
||||
def test_iterator_exception():
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
try:
|
||||
_ = data.create_dict_iterator(output_numpy="123")
|
||||
assert False
|
||||
except TypeError as e:
|
||||
assert "Argument output_numpy with value 123 is not of type" in str(e)
|
||||
|
||||
try:
|
||||
_ = data.create_dict_iterator(output_numpy=123)
|
||||
assert False
|
||||
except TypeError as e:
|
||||
assert "Argument output_numpy with value 123 is not of type" in str(e)
|
||||
|
||||
try:
|
||||
_ = data.create_tuple_iterator(output_numpy="123")
|
||||
assert False
|
||||
except TypeError as e:
|
||||
assert "Argument output_numpy with value 123 is not of type" in str(e)
|
||||
|
||||
try:
|
||||
_ = data.create_tuple_iterator(output_numpy=123)
|
||||
assert False
|
||||
except TypeError as e:
|
||||
assert "Argument output_numpy with value 123 is not of type" in str(e)
|
||||
|
||||
|
||||
class MyDict(dict):
|
||||
def __getattr__(self, key):
|
||||
|
@ -157,4 +183,5 @@ def test_tree_copy():
|
|||
if __name__ == '__main__':
|
||||
test_iterator_create_tuple_numpy()
|
||||
test_iterator_weak_ref()
|
||||
test_iterator_exception()
|
||||
test_tree_copy()
|
||||
|
|
Loading…
Reference in New Issue