add param check for to_device and device_que

This commit is contained in:
ms_yan 2020-07-22 18:04:15 +08:00
parent 875bdc2ebc
commit 27190cf2fc
2 changed files with 22 additions and 1 deletions

View File

@ -40,7 +40,7 @@ from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, \
check_rename, check_numpyslicesdataset, check_device_send, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
@ -952,6 +952,7 @@ class Dataset:
raise TypeError("apply_func must return a dataset.")
return dataset
@check_device_send
def device_que(self, prefetch_size=None, send_epoch_end=True):
"""
Return a transferredDataset that transfer data through device.
@ -970,6 +971,7 @@ class Dataset:
"""
return self.to_device(send_epoch_end=send_epoch_end)
@check_device_send
def to_device(self, send_epoch_end=True):
"""
Transfer data through CPU, GPU or Ascend devices.

View File

@ -652,6 +652,25 @@ def check_positive_int32(method):
return new_method
def check_device_send(method):
"""check the input argument for to_device and device_que."""
@wraps(method)
def new_method(self, *args, **kwargs):
param, param_dict = parse_user_args(method, *args, **kwargs)
para_list = list(param_dict.keys())
if "prefetch_size" in para_list:
if param[0] is not None:
check_pos_int32(param[0], "prefetch_size")
type_check(param[1], (bool,), "send_epoch_end")
else:
type_check(param[0], (bool,), "send_epoch_end")
return method(self, *args, **kwargs)
return new_method
def check_zip(method):
"""check the input arguments of zip."""