forked from mindspore-Ecosystem/mindspore
add output_numpy validation to iterator
This commit is contained in:
parent
d26df7cdcb
commit
a6360cb2e4
|
@ -45,7 +45,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
||||||
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
||||||
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
||||||
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
|
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
|
||||||
check_paddeddataset
|
check_paddeddataset, check_iterator
|
||||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||||
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
|
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
|
||||||
|
|
||||||
|
@ -1149,6 +1149,7 @@ class Dataset:
|
||||||
|
|
||||||
return SaveOp(self).save(file_names, file_type)
|
return SaveOp(self).save(file_names, file_type)
|
||||||
|
|
||||||
|
@check_iterator
|
||||||
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
|
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
|
||||||
"""
|
"""
|
||||||
Create an iterator over the dataset. The data retrieved will be a list of ndarrays of data.
|
Create an iterator over the dataset. The data retrieved will be a list of ndarrays of data.
|
||||||
|
@ -1179,10 +1180,14 @@ class Dataset:
|
||||||
>>> # convert the returned tuple to a list and print
|
>>> # convert the returned tuple to a list and print
|
||||||
>>> print(list(item))
|
>>> print(list(item))
|
||||||
"""
|
"""
|
||||||
|
if output_numpy is None:
|
||||||
|
output_numpy = False
|
||||||
|
|
||||||
if self._noop_mode():
|
if self._noop_mode():
|
||||||
return DummyIterator(self, 'tuple')
|
return DummyIterator(self, 'tuple')
|
||||||
return TupleIterator(self, columns, num_epochs, output_numpy)
|
return TupleIterator(self, columns, num_epochs, output_numpy)
|
||||||
|
|
||||||
|
@check_iterator
|
||||||
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
|
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
|
||||||
"""
|
"""
|
||||||
Create an iterator over the dataset. The data retrieved will be a dictionary.
|
Create an iterator over the dataset. The data retrieved will be a dictionary.
|
||||||
|
@ -1210,6 +1215,9 @@ class Dataset:
|
||||||
>>> # print the data in column1
|
>>> # print the data in column1
|
||||||
>>> print(item["column1"])
|
>>> print(item["column1"])
|
||||||
"""
|
"""
|
||||||
|
if output_numpy is None:
|
||||||
|
output_numpy = False
|
||||||
|
|
||||||
if self._noop_mode():
|
if self._noop_mode():
|
||||||
return DummyIterator(self, 'dict')
|
return DummyIterator(self, 'dict')
|
||||||
return DictIterator(self, num_epochs, output_numpy)
|
return DictIterator(self, num_epochs, output_numpy)
|
||||||
|
@ -2583,10 +2591,10 @@ class TransferDataset(DatasetOp):
|
||||||
args["send_epoch_end"] = self._send_epoch_end
|
args["send_epoch_end"] = self._send_epoch_end
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def create_dict_iterator(self, num_epochs=-1):
|
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
|
||||||
raise RuntimeError("TransferDataset is not iterable.")
|
raise RuntimeError("TransferDataset is not iterable.")
|
||||||
|
|
||||||
def create_tuple_iterator(self, columns=None, num_epochs=-1):
|
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
|
||||||
raise RuntimeError("TransferDataset is not iterable.")
|
raise RuntimeError("TransferDataset is not iterable.")
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
|
|
@ -276,6 +276,18 @@ def check_save(method):
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
def check_iterator(method):
|
||||||
|
"""A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(self, *args, **kwargs):
|
||||||
|
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||||
|
nreq_param_bool = ['output_numpy']
|
||||||
|
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
def check_minddataset(method):
|
def check_minddataset(method):
|
||||||
"""A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""
|
"""A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""
|
||||||
|
|
|
@ -125,6 +125,32 @@ def test_iterator_weak_ref():
|
||||||
|
|
||||||
_cleanup()
|
_cleanup()
|
||||||
|
|
||||||
|
def test_iterator_exception():
|
||||||
|
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||||
|
try:
|
||||||
|
_ = data.create_dict_iterator(output_numpy="123")
|
||||||
|
assert False
|
||||||
|
except TypeError as e:
|
||||||
|
assert "Argument output_numpy with value 123 is not of type" in str(e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = data.create_dict_iterator(output_numpy=123)
|
||||||
|
assert False
|
||||||
|
except TypeError as e:
|
||||||
|
assert "Argument output_numpy with value 123 is not of type" in str(e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = data.create_tuple_iterator(output_numpy="123")
|
||||||
|
assert False
|
||||||
|
except TypeError as e:
|
||||||
|
assert "Argument output_numpy with value 123 is not of type" in str(e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = data.create_tuple_iterator(output_numpy=123)
|
||||||
|
assert False
|
||||||
|
except TypeError as e:
|
||||||
|
assert "Argument output_numpy with value 123 is not of type" in str(e)
|
||||||
|
|
||||||
|
|
||||||
class MyDict(dict):
|
class MyDict(dict):
|
||||||
def __getattr__(self, key):
|
def __getattr__(self, key):
|
||||||
|
@ -157,4 +183,5 @@ def test_tree_copy():
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_iterator_create_tuple_numpy()
|
test_iterator_create_tuple_numpy()
|
||||||
test_iterator_weak_ref()
|
test_iterator_weak_ref()
|
||||||
|
test_iterator_exception()
|
||||||
test_tree_copy()
|
test_tree_copy()
|
||||||
|
|
Loading…
Reference in New Issue