diff --git a/mindspore/dataset/engine/cache_client.py b/mindspore/dataset/engine/cache_client.py index 800c0dab1de..d140a0cb554 100644 --- a/mindspore/dataset/engine/cache_client.py +++ b/mindspore/dataset/engine/cache_client.py @@ -18,21 +18,20 @@ import copy from mindspore._c_dataengine import CacheClient +from ..core.validator_helpers import type_check, check_uint32, check_uint64 + class DatasetCache: """ A client to interface with tensor caching service """ - def __init__(self, session_id=None, size=None, spilling=False): - if session_id is None: - raise RuntimeError("Session generation is not implemented yet. session id required") - self.size = size if size is not None else 0 - if size < 0: - raise ValueError("cache size should be 0 or positive integer value but got: size={}".format(size)) - if not isinstance(spilling, bool): - raise ValueError( - "spilling argument for cache should be a boolean value but got: spilling={}".format(spilling)) + def __init__(self, session_id=None, size=0, spilling=False): + check_uint32(session_id, "session_id") + check_uint64(size, "size") + type_check(spilling, (bool,), "spilling") + self.session_id = session_id + self.size = size self.spilling = spilling self.cache_client = CacheClient(session_id, size, spilling)