forked from mindspore-Ecosystem/mindspore
Fix vocab validation: accept NumPy arrays in tokens_to_ids and ids_to_tokens
This commit is contained in:
parent
e1c2c4c268
commit
c248274d91
|
@ -72,6 +72,8 @@ class Vocab:
|
||||||
>>> ids = vocab.tokens_to_ids(["w1", "w3"])
|
>>> ids = vocab.tokens_to_ids(["w1", "w3"])
|
||||||
"""
|
"""
|
||||||
check_vocab(self.c_vocab)
|
check_vocab(self.c_vocab)
|
||||||
|
if isinstance(tokens, np.ndarray):
|
||||||
|
tokens = tokens.tolist()
|
||||||
if isinstance(tokens, str):
|
if isinstance(tokens, str):
|
||||||
tokens = [tokens]
|
tokens = [tokens]
|
||||||
return self.c_vocab.tokens_to_ids(tokens)
|
return self.c_vocab.tokens_to_ids(tokens)
|
||||||
|
@ -93,6 +95,8 @@ class Vocab:
|
||||||
>>> token = vocab.ids_to_tokens(0)
|
>>> token = vocab.ids_to_tokens(0)
|
||||||
"""
|
"""
|
||||||
check_vocab(self.c_vocab)
|
check_vocab(self.c_vocab)
|
||||||
|
if isinstance(ids, np.ndarray):
|
||||||
|
ids = ids.tolist()
|
||||||
if isinstance(ids, int):
|
if isinstance(ids, int):
|
||||||
ids = [ids]
|
ids = [ids]
|
||||||
return self.c_vocab.ids_to_tokens(ids)
|
return self.c_vocab.ids_to_tokens(ids)
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
validators for text ops
|
validators for text ops
|
||||||
"""
|
"""
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import mindspore._c_dataengine as cde
|
import mindspore._c_dataengine as cde
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
|
@ -93,10 +94,10 @@ def check_tokens_to_ids(method):
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
[tokens], _ = parse_user_args(method, *args, **kwargs)
|
[tokens], _ = parse_user_args(method, *args, **kwargs)
|
||||||
type_check(tokens, (str, list), "tokens")
|
type_check(tokens, (str, list, np.ndarray), "tokens")
|
||||||
if isinstance(tokens, list):
|
if isinstance(tokens, list):
|
||||||
param_names = ["tokens[{0}]".format(i) for i in range(len(tokens))]
|
param_names = ["tokens[{0}]".format(i) for i in range(len(tokens))]
|
||||||
type_check_list(tokens, (str,), param_names)
|
type_check_list(tokens, (str, np.str_), param_names)
|
||||||
|
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
@ -109,12 +110,12 @@ def check_ids_to_tokens(method):
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
[ids], _ = parse_user_args(method, *args, **kwargs)
|
[ids], _ = parse_user_args(method, *args, **kwargs)
|
||||||
type_check(ids, (int, list), "ids")
|
type_check(ids, (int, list, np.ndarray), "ids")
|
||||||
if isinstance(ids, int):
|
if isinstance(ids, int):
|
||||||
check_value(ids, (0, INT32_MAX), "ids")
|
check_value(ids, (0, INT32_MAX), "ids")
|
||||||
if isinstance(ids, list):
|
if isinstance(ids, list):
|
||||||
for index, id_ in enumerate(ids):
|
for index, id_ in enumerate(ids):
|
||||||
type_check(id_, (int,), "ids[{}]".format(index))
|
type_check(id_, (int, np.int_), "ids[{}]".format(index))
|
||||||
check_value(id_, (0, INT32_MAX), "ids[{}]".format(index))
|
check_value(id_, (0, INT32_MAX), "ids[{}]".format(index))
|
||||||
|
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
|
@ -60,6 +60,12 @@ def test_vocab_tokens_to_ids():
|
||||||
ids = vocab.tokens_to_ids("hello")
|
ids = vocab.tokens_to_ids("hello")
|
||||||
assert ids == -1
|
assert ids == -1
|
||||||
|
|
||||||
|
ids = vocab.tokens_to_ids(np.array(["w1", "w3"]))
|
||||||
|
assert ids == [1, 3]
|
||||||
|
|
||||||
|
ids = vocab.tokens_to_ids(np.array("w1"))
|
||||||
|
assert ids == 1
|
||||||
|
|
||||||
|
|
||||||
def test_vocab_ids_to_tokens():
|
def test_vocab_ids_to_tokens():
|
||||||
"""
|
"""
|
||||||
|
@ -82,6 +88,12 @@ def test_vocab_ids_to_tokens():
|
||||||
tokens = vocab.ids_to_tokens(7)
|
tokens = vocab.ids_to_tokens(7)
|
||||||
assert tokens == ""
|
assert tokens == ""
|
||||||
|
|
||||||
|
tokens = vocab.ids_to_tokens(np.array([2, 3]))
|
||||||
|
assert tokens == ["w2", "w3"]
|
||||||
|
|
||||||
|
tokens = vocab.ids_to_tokens(np.array(2))
|
||||||
|
assert tokens == "w2"
|
||||||
|
|
||||||
|
|
||||||
def test_vocab_exception():
|
def test_vocab_exception():
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue