forked from mindspore-Ecosystem/mindspore
!3066 fix some batch's get_dataset_size and some text validator inconsistency
Merge pull request !3066 from ZiruiWu/fix_validator
This commit is contained in:
commit
219a716eae
|
@ -1563,7 +1563,7 @@ class BatchDataset(DatasetOp):
|
|||
Number, number of batches.
|
||||
"""
|
||||
child_size = self.children[0].get_dataset_size()
|
||||
if child_size is not None:
|
||||
if child_size is not None and isinstance(self.batch_size, int):
|
||||
if self.drop_remainder:
|
||||
return math.floor(child_size / self.batch_size)
|
||||
return math.ceil(child_size / self.batch_size)
|
||||
|
@ -3915,7 +3915,6 @@ class RandomDataset(SourceDataset):
|
|||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
|
||||
class Schema:
|
||||
"""
|
||||
Class to represent a schema of dataset.
|
||||
|
|
|
@ -23,7 +23,8 @@ import mindspore._c_dataengine as cde
|
|||
from mindspore._c_expression import typing
|
||||
|
||||
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
|
||||
INT32_MAX, check_value
|
||||
INT32_MAX, check_value, check_positive
|
||||
|
||||
|
||||
def check_unique_list_of_words(words, arg_name):
|
||||
"""Check that words is a list and each element is a str without any duplication"""
|
||||
|
@ -109,7 +110,7 @@ def check_from_dict(method):
|
|||
for word, word_id in word_dict.items():
|
||||
type_check(word, (str,), "word")
|
||||
type_check(word_id, (int,), "word_id")
|
||||
check_value(word_id, (-1, INT32_MAX), "word_id")
|
||||
check_value(word_id, (0, INT32_MAX), "word_id")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -196,7 +197,7 @@ def check_wordpiece_tokenizer(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ =\
|
||||
[vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \
|
||||
parse_user_args(method, *args, **kwargs)
|
||||
if vocab is None:
|
||||
raise ValueError("vocab is not provided.")
|
||||
|
@ -238,7 +239,7 @@ def check_basic_tokenizer(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ =\
|
||||
[lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \
|
||||
parse_user_args(method, *args, **kwargs)
|
||||
if not isinstance(lower_case, bool):
|
||||
raise TypeError("Wrong input type for lower_case, should be boolean.")
|
||||
|
@ -317,7 +318,7 @@ def check_from_dataset(method):
|
|||
type_check(top_k, (int, type(None)), "top_k")
|
||||
|
||||
if isinstance(top_k, int):
|
||||
check_value(top_k, (0, INT32_MAX), "top_k")
|
||||
check_positive(top_k, "top_k")
|
||||
type_check(special_first, (bool,), "special_first")
|
||||
|
||||
if special_tokens is not None:
|
||||
|
@ -343,7 +344,7 @@ def check_ngram(method):
|
|||
|
||||
for i, gram in enumerate(n):
|
||||
type_check(gram, (int,), "gram[{0}]".format(i))
|
||||
check_value(gram, (0, INT32_MAX), "gram_{}".format(i))
|
||||
check_positive(gram, "gram_{}".format(i))
|
||||
|
||||
if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
|
||||
left_pad[1], int)):
|
||||
|
|
|
@ -128,7 +128,7 @@ def test_from_dataset_exceptions():
|
|||
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
|
||||
vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k)
|
||||
assert isinstance(vocab.text.Vocab)
|
||||
except (TypeError, ValueError, RuntimeError) as e:
|
||||
except (TypeError, ValueError) as e:
|
||||
assert s in str(e), str(e)
|
||||
|
||||
test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.")
|
||||
|
@ -136,8 +136,8 @@ def test_from_dataset_exceptions():
|
|||
"Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)")
|
||||
test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)")
|
||||
test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)")
|
||||
test_config("text", (2, 3), 0, "top_k needs to be positive number")
|
||||
test_config([123], (2, 3), 0, "top_k needs to be positive number")
|
||||
test_config("text", (2, 3), 0, "top_k must be greater than 0")
|
||||
test_config([123], (2, 3), -1, "top_k must be greater than 0")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -72,43 +72,36 @@ def test_simple_ngram():
|
|||
def test_corner_cases():
|
||||
""" testing various corner cases and exceptions"""
|
||||
|
||||
def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
|
||||
def test_config(input_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
|
||||
def gen(texts):
|
||||
yield (np.array(texts.split(" "), dtype='S'),)
|
||||
|
||||
dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
|
||||
dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep))
|
||||
for data in dataset.create_dict_iterator():
|
||||
assert [d.decode("utf8") for d in data["text"]] == output_line, output_line
|
||||
try:
|
||||
dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
|
||||
dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep))
|
||||
for data in dataset.create_dict_iterator():
|
||||
return [d.decode("utf8") for d in data["text"]]
|
||||
except (ValueError, TypeError) as e:
|
||||
return str(e)
|
||||
|
||||
# test tensor length smaller than n
|
||||
test_config("Lone Star", ["Lone Star", "", "", ""], [2, 3, 4, 5])
|
||||
assert test_config("Lone Star", [2, 3, 4, 5]) == ["Lone Star", "", "", ""]
|
||||
# test empty separator
|
||||
test_config("Beautiful British Columbia", ['BeautifulBritish', 'BritishColumbia'], 2, sep="")
|
||||
assert test_config("Beautiful British Columbia", 2, sep="") == ['BeautifulBritish', 'BritishColumbia']
|
||||
# test separator with longer length
|
||||
test_config("Beautiful British Columbia", ['Beautiful^-^British^-^Columbia'], 3, sep="^-^")
|
||||
assert test_config("Beautiful British Columbia", 3, sep="^-^") == ['Beautiful^-^British^-^Columbia']
|
||||
# test left pad != right pad
|
||||
test_config("Lone Star", ['The Lone Star State'], 4, ("The", 1), ("State", 1))
|
||||
assert test_config("Lone Star", 4, ("The", 1), ("State", 1)) == ['The Lone Star State']
|
||||
# test invalid n
|
||||
try:
|
||||
test_config("Yours to Discover", "", [0, [1]])
|
||||
except Exception as e:
|
||||
assert "Argument gram[1] with value [1] is not of type (<class 'int'>,)" in str(e)
|
||||
# test empty n
|
||||
try:
|
||||
test_config("Yours to Discover", "", [])
|
||||
except Exception as e:
|
||||
assert "n needs to be a non-empty list" in str(e)
|
||||
assert "gram[1] with value [1] is not of type (<class 'int'>,)" in test_config("Yours to Discover", [1, [1]])
|
||||
assert "n needs to be a non-empty list" in test_config("Yours to Discover", [])
|
||||
# test invalid pad
|
||||
try:
|
||||
test_config("Yours to Discover", "", [1], ("str", -1))
|
||||
except Exception as e:
|
||||
assert "padding width need to be positive numbers" in str(e)
|
||||
# test invalid pad
|
||||
try:
|
||||
test_config("Yours to Discover", "", [1], ("str", "rts"))
|
||||
except Exception as e:
|
||||
assert "pad needs to be a tuple of (str, int)" in str(e)
|
||||
assert "padding width need to be positive numbers" in test_config("Yours to Discover", [1], ("str", -1))
|
||||
assert "pad needs to be a tuple of (str, int)" in test_config("Yours to Discover", [1], ("str", "rts"))
|
||||
# test 0 as invalid input
|
||||
assert "gram_0 must be greater than 0" in test_config("Yours to Discover", 0)
|
||||
assert "gram_0 must be greater than 0" in test_config("Yours to Discover", [0])
|
||||
assert "gram_1 must be greater than 0" in test_config("Yours to Discover", [1, 0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -60,6 +60,15 @@ def test_from_dict_tutorial():
|
|||
ind += 1
|
||||
|
||||
|
||||
def test_from_dict_exception():
    """Check that Vocab.from_dict rejects a negative word id.

    The dictionary maps "home" to -1; the test expects a ValueError whose
    message contains "is not within the required interval".

    The test fails explicitly when no exception is raised, so a validator
    that silently accepts the negative id cannot make it pass vacuously.
    """
    try:
        text.Vocab.from_dict({"home": -1, "behind": 0})
    except ValueError as e:
        assert "is not within the required interval" in str(e), str(e)
    else:
        # Reaching this branch means the out-of-range id was accepted.
        raise AssertionError("Vocab.from_dict accepted a negative word id without raising ValueError")
|
||||
|
||||
|
||||
def test_from_list():
|
||||
def gen(texts):
|
||||
for word in texts.split(" "):
|
||||
|
@ -74,13 +83,11 @@ def test_from_list():
|
|||
for d in data.create_dict_iterator():
|
||||
res.append(d["text"].item())
|
||||
return res
|
||||
except ValueError as e:
|
||||
return str(e)
|
||||
except RuntimeError as e:
|
||||
return str(e)
|
||||
except TypeError as e:
|
||||
except (ValueError, RuntimeError, TypeError) as e:
|
||||
return str(e)
|
||||
|
||||
# test basic default config, special_token=None, unknown_token=None
|
||||
assert test_config("w1 w2 w3", ["w1", "w2", "w3"], None, True, None) == [0, 1, 2]
|
||||
# test normal operations
|
||||
assert test_config("w1 w2 w3 s1 s2 ephemeral", ["w1", "w2", "w3"], ["s1", "s2"], True, "s2") == [2, 3, 4, 0, 1, 1]
|
||||
assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False, "s2") == [0, 1, 2, 3, 4]
|
||||
|
@ -129,6 +136,7 @@ def test_from_file():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_from_dict_exception()
|
||||
test_from_list_tutorial()
|
||||
test_from_file_tutorial()
|
||||
test_from_dict_tutorial()
|
||||
|
|
Loading…
Reference in New Issue