Minor fix for DatasetCache param validation

This commit is contained in:
Lixia Chen 2020-11-25 11:55:54 -05:00
parent 39e656096c
commit ab7427f1a9
3 changed files with 27 additions and 11 deletions

View File

@ -46,8 +46,8 @@ Status CacheClient::Builder::SanityCheck() {
CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 0, "port must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "illegal port number");
CHECK_FAIL_RETURN_UNEXPECTED(port_ > 1024, "Port must be in range (1025..65535)");
CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "Port must be in range (1025..65535)");
CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1",
"now cache client has to be on the same host with cache server");
return Status::OK();

View File

@ -18,7 +18,7 @@
import copy
from mindspore._c_dataengine import CacheClient
from ..core.validator_helpers import type_check, check_uint32, check_uint64
from ..core.validator_helpers import type_check, check_uint32, check_uint64, check_positive, check_value
class DatasetCache:
@ -29,8 +29,20 @@ class DatasetCache:
def __init__(self, session_id=None, size=0, spilling=False, hostname=None, port=None, num_connections=None,
prefetch_size=None):
check_uint32(session_id, "session_id")
check_uint64(size, "size")
type_check(size, (int,), "size")
if size != 0:
check_positive(size, "size")
check_uint64(size, "size")
type_check(spilling, (bool,), "spilling")
if hostname is not None:
type_check(hostname, (str,), "hostname")
if port is not None:
type_check(port, (int,), "port")
check_value(port, (1025, 65535), "port")
if num_connections is not None:
check_uint32(num_connections, "num_connections")
if prefetch_size is not None:
check_uint32(prefetch_size, "prefetch_size")
self.session_id = session_id
self.size = size

View File

@ -550,7 +550,7 @@ def test_cache_map_parameter_check():
with pytest.raises(ValueError) as info:
ds.DatasetCache(session_id=1, size=-1, spilling=True)
assert "Input is not within the required interval" in str(info.value)
assert "Input size must be greater than 0" in str(info.value)
with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size="1", spilling=True)
@ -564,6 +564,10 @@ def test_cache_map_parameter_check():
ds.DatasetCache(session_id=1, size=0, spilling="illegal")
assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value)
with pytest.raises(TypeError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname=50052)
assert "Argument hostname with value 50052 is not of type (<class 'str'>,)" in str(err.value)
with pytest.raises(RuntimeError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal")
assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value)
@ -574,19 +578,19 @@ def test_cache_map_parameter_check():
with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal")
assert "incompatible constructor arguments" in str(info.value)
assert "Argument port with value illegal is not of type (<class 'int'>,)" in str(info.value)
with pytest.raises(TypeError) as info:
ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052")
assert "incompatible constructor arguments" in str(info.value)
assert "Argument port with value 50052 is not of type (<class 'int'>,)" in str(info.value)
with pytest.raises(RuntimeError) as err:
with pytest.raises(ValueError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, port=0)
assert "Unexpected error. port must be positive" in str(err.value)
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)
with pytest.raises(RuntimeError) as err:
with pytest.raises(ValueError) as err:
ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536)
assert "Unexpected error. illegal port number" in str(err.value)
assert "Input port is not within the required interval of (1025 to 65535)" in str(err.value)
with pytest.raises(TypeError) as err:
ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)