diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc index 9b08091fc13..214d86b3c32 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc @@ -46,8 +46,8 @@ Status CacheClient::Builder::SanityCheck() { CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive"); CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); - CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive"); - CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "illegal port number"); + CHECK_FAIL_RETURN_UNEXPECTED(port_ > 1024, "Port must be in range (1025..65535)"); + CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "Port must be in range (1025..65535)"); CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1", "now cache client has to be on the same host with cache server"); return Status::OK(); diff --git a/mindspore/dataset/engine/cache_client.py b/mindspore/dataset/engine/cache_client.py index 1a2f90d79fb..04ebd4d394c 100644 --- a/mindspore/dataset/engine/cache_client.py +++ b/mindspore/dataset/engine/cache_client.py @@ -18,7 +18,7 @@ import copy from mindspore._c_dataengine import CacheClient -from ..core.validator_helpers import type_check, check_uint32, check_uint64 +from ..core.validator_helpers import type_check, check_uint32, check_uint64, check_positive, check_value class DatasetCache: @@ -29,8 +29,20 @@ class DatasetCache: def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None, prefetch_size=None): check_uint32(session_id, "session_id") - check_uint64(size, "size") + type_check(size, (int,), "size") + if size != 0: + check_positive(size, "size") + check_uint64(size, "size") type_check(spilling, (bool,), "spilling") + if hostname is not None: + type_check(hostname, (str,), "hostname") + if port is not None: + type_check(port, (int,), "port") + check_value(port, (1025, 65535), "port") + if num_connections is not None: + check_uint32(num_connections, "num_connections") + if prefetch_size is not None: + check_uint32(prefetch_size, "prefetch_size") self.session_id = session_id self.size = size diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 30c72808fa6..316a35769e2 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -550,7 +550,7 @@ def test_cache_map_parameter_check(): with pytest.raises(ValueError) as info: ds.DatasetCache(session_id=1, size=-1, spilling=True) - assert "Input is not within the required interval" in str(info.value) + assert "Input size must be greater than 0" in str(info.value) with pytest.raises(TypeError) as info: ds.DatasetCache(session_id=1, size="1", spilling=True) @@ -564,6 +564,10 @@ def test_cache_map_parameter_check(): ds.DatasetCache(session_id=1, size=0, spilling="illegal") assert "Argument spilling with value illegal is not of type (,)" in str(info.value) + with pytest.raises(TypeError) as err: + ds.DatasetCache(session_id=1, size=0, spilling=True, hostname=50052) + assert "Argument hostname with value 50052 is not of type (,)" in str(err.value) + with pytest.raises(RuntimeError) as err: ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal") assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value) @@ -574,19 +578,19 @@ def test_cache_map_parameter_check(): with pytest.raises(TypeError) as info: ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal") - assert "incompatible constructor arguments" in str(info.value) + assert "Argument port with value illegal is not of type (,)" in str(info.value) with pytest.raises(TypeError) as info: ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052") - assert "incompatible constructor arguments" in str(info.value) + assert "Argument port with value 50052 is not of type (,)" in str(info.value) - with pytest.raises(RuntimeError) as err: + with pytest.raises(ValueError) as err: ds.DatasetCache(session_id=1, size=0, spilling=True, port=0) - assert "Unexpected error. port must be positive" in str(err.value) + assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value) - with pytest.raises(RuntimeError) as err: + with pytest.raises(ValueError) as err: ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536) - assert "Unexpected error. illegal port number" in str(err.value) + assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value) with pytest.raises(TypeError) as err: ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)