!29719 fix vocab validation

Merge pull request !29719 from luoyang/fix_vocab
This commit is contained in:
i-robot 2022-02-08 01:10:56 +00:00 committed by Gitee
commit f17b96ac68
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 21 additions and 4 deletions

View File

@ -72,6 +72,8 @@ class Vocab:
>>> ids = vocab.tokens_to_ids(["w1", "w3"])
"""
check_vocab(self.c_vocab)
if isinstance(tokens, np.ndarray):
tokens = tokens.tolist()
if isinstance(tokens, str):
tokens = [tokens]
return self.c_vocab.tokens_to_ids(tokens)
@ -93,6 +95,8 @@ class Vocab:
>>> token = vocab.ids_to_tokens(0)
"""
check_vocab(self.c_vocab)
if isinstance(ids, np.ndarray):
ids = ids.tolist()
if isinstance(ids, int):
ids = [ids]
return self.c_vocab.ids_to_tokens(ids)

View File

@ -16,6 +16,7 @@
validators for text ops
"""
from functools import wraps
import numpy as np
import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype
@ -93,10 +94,10 @@ def check_tokens_to_ids(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[tokens], _ = parse_user_args(method, *args, **kwargs)
type_check(tokens, (str, list), "tokens")
type_check(tokens, (str, list, np.ndarray), "tokens")
if isinstance(tokens, list):
param_names = ["tokens[{0}]".format(i) for i in range(len(tokens))]
type_check_list(tokens, (str,), param_names)
type_check_list(tokens, (str, np.str_), param_names)
return method(self, *args, **kwargs)
@ -109,12 +110,12 @@ def check_ids_to_tokens(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[ids], _ = parse_user_args(method, *args, **kwargs)
type_check(ids, (int, list), "ids")
type_check(ids, (int, list, np.ndarray), "ids")
if isinstance(ids, int):
check_value(ids, (0, INT32_MAX), "ids")
if isinstance(ids, list):
for index, id_ in enumerate(ids):
type_check(id_, (int,), "ids[{}]".format(index))
type_check(id_, (int, np.int_), "ids[{}]".format(index))
check_value(id_, (0, INT32_MAX), "ids[{}]".format(index))
return method(self, *args, **kwargs)

View File

@ -60,6 +60,12 @@ def test_vocab_tokens_to_ids():
ids = vocab.tokens_to_ids("hello")
assert ids == -1
ids = vocab.tokens_to_ids(np.array(["w1", "w3"]))
assert ids == [1, 3]
ids = vocab.tokens_to_ids(np.array("w1"))
assert ids == 1
def test_vocab_ids_to_tokens():
"""
@ -82,6 +88,12 @@ def test_vocab_ids_to_tokens():
tokens = vocab.ids_to_tokens(7)
assert tokens == ""
tokens = vocab.ids_to_tokens(np.array([2, 3]))
assert tokens == ["w2", "w3"]
tokens = vocab.ids_to_tokens(np.array(2))
assert tokens == "w2"
def test_vocab_exception():
"""