forked from mindspore-Ecosystem/mindspore
fix split erroer message
This commit is contained in:
parent
d6d93f16b1
commit
68e2097897
|
@ -1065,12 +1065,12 @@ def check_split(method):
|
|||
if all_int:
|
||||
all_positive = all(item > 0 for item in sizes)
|
||||
if not all_positive:
|
||||
raise ValueError("sizes is a list of int, but there should be no negative numbers.")
|
||||
raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.")
|
||||
|
||||
if all_float:
|
||||
all_valid_percentages = all(0 < item <= 1 for item in sizes)
|
||||
if not all_valid_percentages:
|
||||
raise ValueError("sizes is a list of float, but there should be no numbers outside the range [0, 1].")
|
||||
raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")
|
||||
|
||||
epsilon = 0.00001
|
||||
if not abs(sum(sizes) - 1) < epsilon:
|
||||
|
|
|
@ -38,7 +38,7 @@ def split_with_invalid_inputs(d):
|
|||
|
||||
with pytest.raises(ValueError) as info:
|
||||
_, _ = d.split([-1, 6])
|
||||
assert "there should be no negative numbers" in str(info.value)
|
||||
assert "there should be no negative or zero numbers" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
_, _ = d.split([3, 1])
|
||||
|
@ -54,11 +54,11 @@ def split_with_invalid_inputs(d):
|
|||
|
||||
with pytest.raises(ValueError) as info:
|
||||
_, _ = d.split([-0.5, 0.5])
|
||||
assert "there should be no numbers outside the range [0, 1]" in str(info.value)
|
||||
assert "there should be no numbers outside the range (0, 1]" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
_, _ = d.split([1.5, 0.5])
|
||||
assert "there should be no numbers outside the range [0, 1]" in str(info.value)
|
||||
assert "there should be no numbers outside the range (0, 1]" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
_, _ = d.split([0.5, 0.6])
|
||||
|
|
Loading…
Reference in New Issue