fix vocab validation

This commit is contained in:
luoyang 2022-02-07 15:38:17 +08:00
parent e1c2c4c268
commit c248274d91
3 changed files with 21 additions and 4 deletions

View File

@@ -72,6 +72,8 @@ class Vocab:
>>> ids = vocab.tokens_to_ids(["w1", "w3"]) >>> ids = vocab.tokens_to_ids(["w1", "w3"])
""" """
check_vocab(self.c_vocab) check_vocab(self.c_vocab)
if isinstance(tokens, np.ndarray):
tokens = tokens.tolist()
if isinstance(tokens, str): if isinstance(tokens, str):
tokens = [tokens] tokens = [tokens]
return self.c_vocab.tokens_to_ids(tokens) return self.c_vocab.tokens_to_ids(tokens)
@@ -93,6 +95,8 @@ class Vocab:
>>> token = vocab.ids_to_tokens(0) >>> token = vocab.ids_to_tokens(0)
""" """
check_vocab(self.c_vocab) check_vocab(self.c_vocab)
if isinstance(ids, np.ndarray):
ids = ids.tolist()
if isinstance(ids, int): if isinstance(ids, int):
ids = [ids] ids = [ids]
return self.c_vocab.ids_to_tokens(ids) return self.c_vocab.ids_to_tokens(ids)

View File

@@ -16,6 +16,7 @@
validators for text ops validators for text ops
""" """
from functools import wraps from functools import wraps
import numpy as np
import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
@@ -93,10 +94,10 @@ def check_tokens_to_ids(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
[tokens], _ = parse_user_args(method, *args, **kwargs) [tokens], _ = parse_user_args(method, *args, **kwargs)
type_check(tokens, (str, list), "tokens") type_check(tokens, (str, list, np.ndarray), "tokens")
if isinstance(tokens, list): if isinstance(tokens, list):
param_names = ["tokens[{0}]".format(i) for i in range(len(tokens))] param_names = ["tokens[{0}]".format(i) for i in range(len(tokens))]
type_check_list(tokens, (str,), param_names) type_check_list(tokens, (str, np.str_), param_names)
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
@@ -109,12 +110,12 @@ def check_ids_to_tokens(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
[ids], _ = parse_user_args(method, *args, **kwargs) [ids], _ = parse_user_args(method, *args, **kwargs)
type_check(ids, (int, list), "ids") type_check(ids, (int, list, np.ndarray), "ids")
if isinstance(ids, int): if isinstance(ids, int):
check_value(ids, (0, INT32_MAX), "ids") check_value(ids, (0, INT32_MAX), "ids")
if isinstance(ids, list): if isinstance(ids, list):
for index, id_ in enumerate(ids): for index, id_ in enumerate(ids):
type_check(id_, (int,), "ids[{}]".format(index)) type_check(id_, (int, np.int_), "ids[{}]".format(index))
check_value(id_, (0, INT32_MAX), "ids[{}]".format(index)) check_value(id_, (0, INT32_MAX), "ids[{}]".format(index))
return method(self, *args, **kwargs) return method(self, *args, **kwargs)

View File

@@ -60,6 +60,12 @@ def test_vocab_tokens_to_ids():
ids = vocab.tokens_to_ids("hello") ids = vocab.tokens_to_ids("hello")
assert ids == -1 assert ids == -1
ids = vocab.tokens_to_ids(np.array(["w1", "w3"]))
assert ids == [1, 3]
ids = vocab.tokens_to_ids(np.array("w1"))
assert ids == 1
def test_vocab_ids_to_tokens(): def test_vocab_ids_to_tokens():
""" """
@@ -82,6 +88,12 @@ def test_vocab_ids_to_tokens():
tokens = vocab.ids_to_tokens(7) tokens = vocab.ids_to_tokens(7)
assert tokens == "" assert tokens == ""
tokens = vocab.ids_to_tokens(np.array([2, 3]))
assert tokens == ["w2", "w3"]
tokens = vocab.ids_to_tokens(np.array(2))
assert tokens == "w2"
def test_vocab_exception(): def test_vocab_exception():
""" """