forked from mindspore-Ecosystem/mindspore
!3334 dataset: add param check for device_que and to_device
Merge pull request !3334 from ms_yan/device_que_param
This commit is contained in:
commit
d874150fb3
|
@ -40,7 +40,7 @@ from mindspore import log as logger
|
|||
from . import samplers
|
||||
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
|
||||
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
|
||||
check_rename, check_numpyslicesdataset, \
|
||||
check_rename, check_numpyslicesdataset, check_device_send, \
|
||||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
||||
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
||||
|
@ -953,6 +953,7 @@ class Dataset:
|
|||
raise TypeError("apply_func must return a dataset.")
|
||||
return dataset
|
||||
|
||||
@check_device_send
|
||||
def device_que(self, prefetch_size=None, send_epoch_end=True):
|
||||
"""
|
||||
Return a transferredDataset that transfer data through device.
|
||||
|
@ -971,6 +972,7 @@ class Dataset:
|
|||
"""
|
||||
return self.to_device(send_epoch_end=send_epoch_end)
|
||||
|
||||
@check_device_send
|
||||
def to_device(self, send_epoch_end=True):
|
||||
"""
|
||||
Transfer data through CPU, GPU or Ascend devices.
|
||||
|
|
|
@ -652,6 +652,25 @@ def check_positive_int32(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_device_send(method):
|
||||
"""check the input argument for to_device and device_que."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
param, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
para_list = list(param_dict.keys())
|
||||
if "prefetch_size" in para_list:
|
||||
if param[0] is not None:
|
||||
check_pos_int32(param[0], "prefetch_size")
|
||||
type_check(param[1], (bool,), "send_epoch_end")
|
||||
else:
|
||||
type_check(param[0], (bool,), "send_epoch_end")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_zip(method):
|
||||
"""check the input arguments of zip."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue