diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 6f11a230916..74d744f21e6 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -40,7 +40,7 @@ from mindspore import log as logger from . import samplers from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ - check_rename, check_numpyslicesdataset, \ + check_rename, check_numpyslicesdataset, check_device_send, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ @@ -952,6 +952,7 @@ class Dataset: raise TypeError("apply_func must return a dataset.") return dataset + @check_device_send def device_que(self, prefetch_size=None, send_epoch_end=True): """ Return a transferredDataset that transfer data through device. @@ -970,6 +971,7 @@ class Dataset: """ return self.to_device(send_epoch_end=send_epoch_end) + @check_device_send def to_device(self, send_epoch_end=True): """ Transfer data through CPU, GPU or Ascend devices. diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index ff1ddab9b0c..a9a61c113cd 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -652,6 +652,25 @@ def check_positive_int32(method): return new_method +def check_device_send(method): + """check the input argument for to_device and device_que.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + param, param_dict = parse_user_args(method, *args, **kwargs) + para_list = list(param_dict.keys()) + if "prefetch_size" in para_list: + if param[0] is not None: + check_pos_int32(param[0], "prefetch_size") + type_check(param[1], (bool,), "send_epoch_end") + else: + type_check(param[0], (bool,), "send_epoch_end") + + return method(self, *args, **kwargs) + + return new_method + + def check_zip(method): """check the input arguments of zip."""