forked from mindspore-Ecosystem/mindspore
rework on lookup
add test cases; address review comments; fix CI; fix typo; add 2 more test cases; fix cpplint issues; fix test case error; fix docstring
parent b23fc4e492
commit 7b15e5a742
@@ -609,10 +609,23 @@ void bindTokenizerOps(py::module *m) {
       *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
       .def(py::init<>());
   (void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp",
-                                                                  "Tensor operation to LookUp each word")
-      .def(py::init<std::shared_ptr<Vocab>, WordIdType>(), py::arg("vocab"), py::arg("unknown"))
-      .def(py::init<std::shared_ptr<Vocab>>(), py::arg("vocab"));
-  (void)py::class_<NgramOp, TensorOp, std::shared_ptr<NgramOp>>(*m, "NgramOp", "TensorOp performs ngram mapping")
+                                                                  "Tensor operation to LookUp each word.")
+      .def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word) {
+        if (vocab == nullptr) {
+          THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null."));
+        }
+        if (py_word.is_none()) {
+          return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists);
+        }
+        std::string word = py::reinterpret_borrow<py::str>(py_word);
+        WordIdType default_id = vocab->Lookup(word);
+        if (default_id == Vocab::kNoTokenExists) {
+          THROW_IF_ERROR(
+            Status(StatusCode::kUnexpectedError, "default unknown token:" + word + " doesn't exist in vocab."));
+        }
+        return std::make_shared<LookupOp>(vocab, default_id);
+      }));
+  (void)py::class_<NgramOp, TensorOp, std::shared_ptr<NgramOp>>(*m, "NgramOp", "TensorOp performs ngram mapping.")
       .def(py::init<const std::vector<int32_t> &, int32_t, int32_t, const std::string &, const std::string &,
                     const std::string &>(),
            py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"),

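Usage sketch (illustrative, not taken from the patch): with the new binding, LookupOp is constructed from a vocab plus an optional unknown token, and that token is resolved to an id once at construction time. Assuming a vocab built with text.Vocab.from_list as in the tests further down:

    import mindspore.dataset.text as text

    vocab = text.Vocab.from_list(["w1", "w2", "w3"], ["<pad>", "<unk>"], True)
    lookup = text.Lookup(vocab, "<unk>")   # "<unk>" is mapped to its id when the op is built
    # text.Lookup(vocab, "oov") is expected to raise a RuntimeError
    # ("default unknown token:oov doesn't exist in vocab.")
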
@@ -26,11 +26,15 @@ LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id)
 Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   IO_CHECK(input, output);
   RETURN_UNEXPECTED_IF_NULL(vocab_);
-  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None String Tensor");
+  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None String Tensor.");
   std::vector<WordIdType> word_ids;
   word_ids.reserve(input->Size());
   for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) {
-    word_ids.push_back(vocab_->Lookup(std::string(*itr), default_id_));
+    WordIdType word_id = vocab_->Lookup(std::string(*itr));
+    word_ids.emplace_back(word_id == Vocab::kNoTokenExists ? default_id_ : word_id);
+    CHECK_FAIL_RETURN_UNEXPECTED(
+      word_ids.back() != Vocab::kNoTokenExists,
+      "Lookup Error: token" + std::string(*itr) + "doesn't exist in vocab and no unknown token is specified.");
   }

   RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_,

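Per-token semantics of the reworked Compute, sketched in Python for clarity (this mirrors the C++ logic above; the function and names below are illustrative, not a real API):

    NO_TOKEN_EXISTS = -1  # mirrors Vocab::kNoTokenExists

    def lookup_token(word2id, token, default_id):
        word_id = word2id.get(token, NO_TOKEN_EXISTS)
        if word_id == NO_TOKEN_EXISTS:
            word_id = default_id               # fall back to the unknown-token id
        if word_id == NO_TOKEN_EXISTS:         # no unknown token was specified either
            raise RuntimeError(token + " doesn't exist in vocab and no unknown token is specified.")
        return word_id
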
@@ -43,8 +43,7 @@ Status WordpieceTokenizerOp::LookupWord(const std::string &input_token, const Ru
     if (start > 0) {
       word = suffix_indicator_ + word;
     }
-    WordIdType default_id = -1;
-    if (vocab_->Lookup(word, default_id) != default_id) {
+    if (vocab_->Lookup(word) != Vocab::kNoTokenExists) {
       *out_found = true;
       break;
     }

@@ -24,9 +24,9 @@ namespace mindspore {
 namespace dataset {
 Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) { word2id_ = std::move(word2id); }

-WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const {
+WordIdType Vocab::Lookup(const WordType &word) const {
   auto itr = word2id_.find(word);
-  return itr == word2id_.end() ? default_id : itr->second;
+  return itr == word2id_.end() ? kNoTokenExists : itr->second;
 }

 Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special,

@@ -100,5 +100,8 @@ void Vocab::append_word(const std::string &word) {
   word2id_[word] = word2id_.size();
 }
 }
+
+const WordIdType Vocab::kNoTokenExists = -1;
+
 }  // namespace dataset
 }  // namespace mindspore

@@ -61,12 +61,7 @@ class Vocab {
   // @param const WordType word - word to look up
   // @param WordIdType default_id - word id to return to user when its not in the vocab
   // @return WordIdType, word_id
-  WordIdType Lookup(const WordType &word, WordIdType default_id) const;
-
-  // reverse lookup, lookup the word based on its id
-  // @param WordIdType id - word id to lookup to
-  // @return WordType the word
-  WordType Lookup(WordIdType id);
+  WordIdType Lookup(const WordType &word) const;

   // constructor, shouldn't be called directly, can't be private due to std::make_unique()
   // @param std::unordered_map<WordType, WordIdType> map - sanitized word2id map

@@ -81,6 +76,8 @@ class Vocab {
   // destructor
   ~Vocab() = default;

+  static const WordIdType kNoTokenExists;
+
  private:
   std::unordered_map<WordType, WordIdType> word2id_;
 };

@@ -63,17 +63,13 @@ class Lookup(cde.LookupOp):

     Args:
         vocab(Vocab): a Vocab object.
-        unknown(int, optional): default id to lookup a word that is out of vocab. If no argument is passed, 1 will be
-            used to be the default id which is the convention for unknown_token <unk>. Otherwise, user is strongly
-            encouraged to pass in the id for <unk> (default=None).
+        unknown_token(str, optional): word to use for lookup if the word being looked up is out of Vocabulary (oov).
+            If unknown_token is oov, runtime error will be thrown (default=None).
     """

     @check_lookup
-    def __init__(self, vocab, unknown=None):
-        if unknown is None:
-            super().__init__(vocab)
-        else:
-            super().__init__(vocab, unknown)
+    def __init__(self, vocab, unknown_token=None):
+        super().__init__(vocab, unknown_token)


 class Ngram(cde.NgramOp):

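Resulting user-facing usage, assembled from the tests further down (an illustrative sketch, not from the patch; the generator column setup is borrowed from those tests):

    import numpy as np
    import mindspore.dataset as ds
    import mindspore.dataset.text as text

    def gen():
        for word in "home IS behind the world ahead !".split(" "):
            yield (np.array(word, dtype='S'),)

    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", "<unk>"], True)
    data = ds.GeneratorDataset(gen(), column_names=["text"])
    data = data.map(input_columns=["text"], operations=text.Lookup(vocab, "<unk>"))  # unknown_token is a word, not an id
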
@@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
 import mindspore._c_dataengine as cde
 from mindspore._c_expression import typing

-from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, check_positive, \
+from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
     INT32_MAX, check_value

@@ -44,11 +44,11 @@ def check_lookup(method):

     @wraps(method)
     def new_method(self, *args, **kwargs):
-        [vocab, unknown], _ = parse_user_args(method, *args, **kwargs)
+        [vocab, unknown_token], _ = parse_user_args(method, *args, **kwargs)

-        if unknown is not None:
-            type_check(unknown, (int,), "unknown")
-            check_positive(unknown)
+        if unknown_token is not None:
+            type_check(unknown_token, (str,), "unknown_token")
+
         type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.")

         return method(self, *args, **kwargs)

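With this validator change, unknown_token is checked as a str (the old positive-int check on unknown is gone). A minimal sketch of what it is expected to reject, based on the test expectations below and assuming the vocab and text import from the earlier sketch:

    try:
        text.Lookup(vocab, 123)            # non-str unknown_token
    except TypeError as e:
        assert "is not of type" in str(e)
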
@@ -197,7 +197,7 @@ class PadEnd(cde.PadEndOp):

 class Concatenate(cde.ConcatenateOp):
     """
-    Tensor operation to prepend and append to a tensor.
+    Tensor operation that concatenates all columns into a single tensor.

     Args:
         axis (int, optional): axis to concatenate the tensors along (Default=0).

@@ -26,7 +26,7 @@ def test_demo_basic_from_dataset():
     vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None,
                                     special_tokens=["<pad>", "<unk>"],
                                     special_first=True)
-    data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
+    data = data.map(input_columns=["text"], operations=text.Lookup(vocab, "<unk>"))
     res = []
     for d in data.create_dict_iterator():
         res.append(d["text"].item())

@@ -39,7 +39,7 @@ def test_demo_basic_from_dataset_with_tokenizer():
     data = data.map(input_columns=["text"], operations=text.UnicodeCharTokenizer())
     vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
                                     special_first=True)
-    data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
+    data = data.map(input_columns=["text"], operations=text.Lookup(vocab, "<unk>"))
     res = []
     for d in data.create_dict_iterator():
         res.append(list(d["text"]))

@@ -60,7 +60,7 @@ def test_from_dataset():
     corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"])
     vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k, special_tokens=["<pad>", "<unk>"],
                                     special_first=True)
-    corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab))
+    corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab, "<unk>"))
     res = []
     for d in corpus_dataset.create_dict_iterator():
         res.append(list(d["text"]))

@@ -108,7 +108,7 @@ def test_from_dataset_special_token():
     corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"])
     vocab = text.Vocab.from_dataset(corpus_dataset, None, None, top_k, special_tokens, special_first)
     data = ds.GeneratorDataset(gen_input(texts), column_names=["text"])
-    data = data.map(input_columns="text", operations=text.Lookup(vocab))
+    data = data.map(input_columns="text", operations=text.Lookup(vocab, "<unk>"))
     res = []
     for d in data.create_dict_iterator():
         res.append(d["text"].item())

@@ -34,7 +34,7 @@ def test_on_tokenized_line():
         jieba_op.add_word(word)
     data = data.map(input_columns=["text"], operations=jieba_op)
     vocab = text.Vocab.from_file(VOCAB_FILE, ",", special_tokens=["<pad>", "<unk>"])
-    lookup = text.Lookup(vocab)
+    lookup = text.Lookup(vocab, "<unk>")
     data = data.map(input_columns=["text"], operations=lookup)
     res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14],
                     [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32)

@@ -26,7 +26,7 @@ SIMPLE_VOCAB_FILE = "../data/dataset/testVocab/simple_vocab_list.txt"

 def test_from_list_tutorial():
     vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", "<unk>"], True)
-    lookup = text.Lookup(vocab)
+    lookup = text.Lookup(vocab, "<unk>")
     data = ds.TextFileDataset(DATA_FILE, shuffle=False)
     data = data.map(input_columns=["text"], operations=lookup)
     ind = 0

@@ -50,7 +50,7 @@ def test_from_file_tutorial():

 def test_from_dict_tutorial():
     vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
-    lookup = text.Lookup(vocab, 6)  # default value is -1
+    lookup = text.Lookup(vocab, "<unk>")  # any unknown token will be mapped to the id of <unk>
     data = ds.TextFileDataset(DATA_FILE, shuffle=False)
     data = data.map(input_columns=["text"], operations=lookup)
     res = [3, 6, 2, 4, 5, 6]

@@ -65,28 +65,39 @@ def test_from_list():
         for word in texts.split(" "):
             yield (np.array(word, dtype='S'),)

-    def test_config(lookup_str, vocab_input, special_tokens, special_first):
+    def test_config(lookup_str, vocab_input, special_tokens, special_first, unknown_token):
         try:
             vocab = text.Vocab.from_list(vocab_input, special_tokens, special_first)
             data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
-            data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
+            data = data.map(input_columns=["text"], operations=text.Lookup(vocab, unknown_token))
             res = []
             for d in data.create_dict_iterator():
                 res.append(d["text"].item())
             return res
         except ValueError as e:
             return str(e)
+        except RuntimeError as e:
+            return str(e)
+        except TypeError as e:
+            return str(e)

     # test normal operations
-    assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], True) == [2, 3, 4, 0, 1]
-    assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False) == [0, 1, 2, 3, 4]
-    assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, True) == [2, 1, 0]
-    assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, False) == [2, 1, 0]
+    assert test_config("w1 w2 w3 s1 s2 ephemeral", ["w1", "w2", "w3"], ["s1", "s2"], True, "s2") == [2, 3, 4, 0, 1, 1]
+    assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False, "s2") == [0, 1, 2, 3, 4]
+    assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, True, "w1") == [2, 1, 0]
+    assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, False, "w1") == [2, 1, 0]
+
+    # test unknown token lookup
+    assert test_config("w1 un1 w3 un2", ["w1", "w2", "w3"], ["<pad>", "<unk>"], True, "<unk>") == [2, 1, 4, 1]
+    assert test_config("w1 un1 w3 un2", ["w1", "w2", "w3"], ["<pad>", "<unk>"], False, "<unk>") == [0, 4, 2, 4]

     # test exceptions
-    assert "word_list contains duplicate" in test_config("w1", ["w1", "w1"], [], True)
-    assert "special_tokens contains duplicate" in test_config("w1", ["w1", "w2"], ["s1", "s1"], True)
-    assert "special_tokens and word_list contain duplicate" in test_config("w1", ["w1", "w2"], ["s1", "w1"], True)
+    assert "doesn't exist in vocab." in test_config("un1", ["w1"], [], False, "unk")
+    assert "doesn't exist in vocab and no unknown token is specified." in test_config("un1", ["w1"], [], False, None)
+    assert "doesn't exist in vocab" in test_config("un1", ["w1"], [], False, None)
+    assert "word_list contains duplicate" in test_config("w1", ["w1", "w1"], [], True, "w1")
+    assert "special_tokens contains duplicate" in test_config("w1", ["w1", "w2"], ["s1", "s1"], True, "w1")
+    assert "special_tokens and word_list contain duplicate" in test_config("w1", ["w1", "w2"], ["s1", "w1"], True, "w1")
+    assert "is not of type" in test_config("w1", ["w1", "w2"], ["s1"], True, 123)


 def test_from_file():

@@ -99,7 +110,7 @@ def test_from_file():
     vocab = text.Vocab.from_file(SIMPLE_VOCAB_FILE, vocab_size=vocab_size, special_tokens=special_tokens,
                                  special_first=special_first)
     data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
-    data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
+    data = data.map(input_columns=["text"], operations=text.Lookup(vocab, "s2"))
     res = []
    for d in data.create_dict_iterator():
        res.append(d["text"].item())