diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index a392a8c1ed..37a17f110a 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -45,7 +45,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \ - check_paddeddataset + check_paddeddataset, check_iterator from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE @@ -1149,6 +1149,7 @@ class Dataset: return SaveOp(self).save(file_names, file_type) + @check_iterator def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False): """ Create an iterator over the dataset. The data retrieved will be a list of ndarrays of data. @@ -1179,10 +1180,14 @@ class Dataset: >>> # convert the returned tuple to a list and print >>> print(list(item)) """ + if output_numpy is None: + output_numpy = False + if self._noop_mode(): return DummyIterator(self, 'tuple') return TupleIterator(self, columns, num_epochs, output_numpy) + @check_iterator def create_dict_iterator(self, num_epochs=-1, output_numpy=False): """ Create an iterator over the dataset. The data retrieved will be a dictionary. @@ -1210,6 +1215,9 @@ class Dataset: >>> # print the data in column1 >>> print(item["column1"]) """ + if output_numpy is None: + output_numpy = False + if self._noop_mode(): return DummyIterator(self, 'dict') return DictIterator(self, num_epochs, output_numpy) @@ -2583,10 +2591,10 @@ class TransferDataset(DatasetOp): args["send_epoch_end"] = self._send_epoch_end return args - def create_dict_iterator(self, num_epochs=-1): + def create_dict_iterator(self, num_epochs=-1, output_numpy=False): raise RuntimeError("TransferDataset is not iterable.") - def create_tuple_iterator(self, columns=None, num_epochs=-1): + def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False): raise RuntimeError("TransferDataset is not iterable.") def __iter__(self): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b7c1060869..816eeeca4d 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -276,6 +276,18 @@ def check_save(method): return new_method +def check_iterator(method): + """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) + nreq_param_bool = ['output_numpy'] + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + return method(self, *args, **kwargs) + + return new_method + def check_minddataset(method): """A wrapper that wraps a parameter checker around the original Dataset(MindDataset).""" diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py index 32272ea9fe..3b27fbaf2e 100644 --- a/tests/ut/python/dataset/test_iterator.py +++ b/tests/ut/python/dataset/test_iterator.py @@ -125,6 +125,32 @@ def test_iterator_weak_ref(): _cleanup() +def test_iterator_exception(): + data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + try: + _ = data.create_dict_iterator(output_numpy="123") + assert False + except TypeError as e: + assert "Argument output_numpy with value 123 is not of type" in str(e) + + try: + _ = data.create_dict_iterator(output_numpy=123) + assert False + except TypeError as e: + assert "Argument output_numpy with value 123 is not of type" in str(e) + + try: + _ = data.create_tuple_iterator(output_numpy="123") + assert False + except TypeError as e: + assert "Argument output_numpy with value 123 is not of type" in str(e) + + try: + _ = data.create_tuple_iterator(output_numpy=123) + assert False + except TypeError as e: + assert "Argument output_numpy with value 123 is not of type" in str(e) + class MyDict(dict): def __getattr__(self, key): @@ -157,4 +183,5 @@ def test_tree_copy(): if __name__ == '__main__': test_iterator_create_tuple_numpy() test_iterator_weak_ref() + test_iterator_exception() test_tree_copy()