forked from mindspore-Ecosystem/mindspore
!29719 fix vocab validation
Merge pull request !29719 from luoyang/fix_vocab
This commit is contained in:
commit
f17b96ac68
|
@ -72,6 +72,8 @@ class Vocab:
|
|||
>>> ids = vocab.tokens_to_ids(["w1", "w3"])
|
||||
"""
|
||||
check_vocab(self.c_vocab)
|
||||
if isinstance(tokens, np.ndarray):
|
||||
tokens = tokens.tolist()
|
||||
if isinstance(tokens, str):
|
||||
tokens = [tokens]
|
||||
return self.c_vocab.tokens_to_ids(tokens)
|
||||
|
@ -93,6 +95,8 @@ class Vocab:
|
|||
>>> token = vocab.ids_to_tokens(0)
|
||||
"""
|
||||
check_vocab(self.c_vocab)
|
||||
if isinstance(ids, np.ndarray):
|
||||
ids = ids.tolist()
|
||||
if isinstance(ids, int):
|
||||
ids = [ids]
|
||||
return self.c_vocab.ids_to_tokens(ids)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
validators for text ops
|
||||
"""
|
||||
from functools import wraps
|
||||
import numpy as np
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -93,10 +94,10 @@ def check_tokens_to_ids(method):
|
|||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[tokens], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(tokens, (str, list), "tokens")
|
||||
type_check(tokens, (str, list, np.ndarray), "tokens")
|
||||
if isinstance(tokens, list):
|
||||
param_names = ["tokens[{0}]".format(i) for i in range(len(tokens))]
|
||||
type_check_list(tokens, (str,), param_names)
|
||||
type_check_list(tokens, (str, np.str_), param_names)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
|
@ -109,12 +110,12 @@ def check_ids_to_tokens(method):
|
|||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[ids], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(ids, (int, list), "ids")
|
||||
type_check(ids, (int, list, np.ndarray), "ids")
|
||||
if isinstance(ids, int):
|
||||
check_value(ids, (0, INT32_MAX), "ids")
|
||||
if isinstance(ids, list):
|
||||
for index, id_ in enumerate(ids):
|
||||
type_check(id_, (int,), "ids[{}]".format(index))
|
||||
type_check(id_, (int, np.int_), "ids[{}]".format(index))
|
||||
check_value(id_, (0, INT32_MAX), "ids[{}]".format(index))
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
|
|
@ -60,6 +60,12 @@ def test_vocab_tokens_to_ids():
|
|||
ids = vocab.tokens_to_ids("hello")
|
||||
assert ids == -1
|
||||
|
||||
ids = vocab.tokens_to_ids(np.array(["w1", "w3"]))
|
||||
assert ids == [1, 3]
|
||||
|
||||
ids = vocab.tokens_to_ids(np.array("w1"))
|
||||
assert ids == 1
|
||||
|
||||
|
||||
def test_vocab_ids_to_tokens():
|
||||
"""
|
||||
|
@ -82,6 +88,12 @@ def test_vocab_ids_to_tokens():
|
|||
tokens = vocab.ids_to_tokens(7)
|
||||
assert tokens == ""
|
||||
|
||||
tokens = vocab.ids_to_tokens(np.array([2, 3]))
|
||||
assert tokens == ["w2", "w3"]
|
||||
|
||||
tokens = vocab.ids_to_tokens(np.array(2))
|
||||
assert tokens == "w2"
|
||||
|
||||
|
||||
def test_vocab_exception():
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue