forked from mindspore-Ecosystem/mindspore
Minor fix for DatasetCache param validation
This commit is contained in:
parent
39e656096c
commit
ab7427f1a9
|
@ -46,8 +46,8 @@ Status CacheClient::Builder::SanityCheck() {
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "illegal port number");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 1024, "Port must be in range (1025..65535)");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "Port must be in range (1025..65535)");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1",
|
||||
"now cache client has to be on the same host with cache server");
|
||||
return Status::OK();
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
import copy
|
||||
from mindspore._c_dataengine import CacheClient
|
||||
|
||||
from ..core.validator_helpers import type_check, check_uint32, check_uint64
|
||||
from ..core.validator_helpers import type_check, check_uint32, check_uint64, check_positive, check_value
|
||||
|
||||
|
||||
class DatasetCache:
|
||||
|
@ -29,8 +29,20 @@ class DatasetCache:
|
|||
def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None,
|
||||
prefetch_size=None):
|
||||
check_uint32(session_id, "session_id")
|
||||
check_uint64(size, "size")
|
||||
type_check(size, (int,), "size")
|
||||
if size != 0:
|
||||
check_positive(size, "size")
|
||||
check_uint64(size, "size")
|
||||
type_check(spilling, (bool,), "spilling")
|
||||
if hostname is not None:
|
||||
type_check(hostname, (str,), "hostname")
|
||||
if port is not None:
|
||||
type_check(port, (int,), "port")
|
||||
check_value(port, (1025, 65535), "port")
|
||||
if num_connections is not None:
|
||||
check_uint32(num_connections, "num_connections")
|
||||
if prefetch_size is not None:
|
||||
check_uint32(prefetch_size, "prefetch_size")
|
||||
|
||||
self.session_id = session_id
|
||||
self.size = size
|
||||
|
|
|
@ -550,7 +550,7 @@ def test_cache_map_parameter_check():
|
|||
|
||||
with pytest.raises(ValueError) as info:
|
||||
ds.DatasetCache(session_id=1, size=-1, spilling=True)
|
||||
assert "Input is not within the required interval" in str(info.value)
|
||||
assert "Input size must be greater than 0" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
ds.DatasetCache(session_id=1, size="1", spilling=True)
|
||||
|
@ -564,6 +564,10 @@ def test_cache_map_parameter_check():
|
|||
ds.DatasetCache(session_id=1, size=0, spilling="illegal")
|
||||
assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as err:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname=50052)
|
||||
assert "Argument hostname with value 50052 is not of type (<class 'str'>,)" in str(err.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal")
|
||||
assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value)
|
||||
|
@ -574,19 +578,19 @@ def test_cache_map_parameter_check():
|
|||
|
||||
with pytest.raises(TypeError) as info:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal")
|
||||
assert "incompatible constructor arguments" in str(info.value)
|
||||
assert "Argument port with value illegal is not of type (<class 'int'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052")
|
||||
assert "incompatible constructor arguments" in str(info.value)
|
||||
assert "Argument port with value 50052 is not of type (<class 'int'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
with pytest.raises(ValueError) as err:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, port=0)
|
||||
assert "Unexpected error. port must be positive" in str(err.value)
|
||||
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
with pytest.raises(ValueError) as err:
|
||||
ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536)
|
||||
assert "Unexpected error. illegal port number" in str(err.value)
|
||||
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)
|
||||
|
||||
with pytest.raises(TypeError) as err:
|
||||
ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)
|
||||
|
|
Loading…
Reference in New Issue