forked from mindspore-Ecosystem/mindspore
phase I of Vocab rework
phase II vocab rework added more test cases fix api doc string address review cmts and fix CI address ci complains fix review cmts ci
This commit is contained in:
parent
ef08dc0d21
commit
b6e9504b31
|
@ -1277,18 +1277,18 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<Datas
|
|||
py::tuple tp = py::reinterpret_borrow<py::tuple>(value);
|
||||
if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow<py::int_>(tp[0]));
|
||||
if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow<py::int_>(tp[1]));
|
||||
}
|
||||
if (key == "top_k") {
|
||||
} else if (key == "top_k") {
|
||||
builder->SetTopK(py::reinterpret_borrow<py::int_>(value));
|
||||
}
|
||||
if (key == "columns") {
|
||||
} else if (key == "columns") {
|
||||
(void)builder->SetColumnNames(ToStringVector(value));
|
||||
}
|
||||
if (key == "vocab") {
|
||||
} else if (key == "vocab") {
|
||||
(void)builder->SetVocab(value.cast<std::shared_ptr<Vocab>>());
|
||||
}
|
||||
if (key == "num_parallel_workers") {
|
||||
} else if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "special_first") {
|
||||
(void)builder->SetSpecialFirst(ToBool(value));
|
||||
} else if (key == "special_tokens") {
|
||||
(void)builder->SetSpecialTokens(ToStringVector(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -668,15 +668,16 @@ void bindVocabObjects(py::module *m) {
|
|||
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
|
||||
.def(py::init<>())
|
||||
.def_static("from_list",
|
||||
[](const py::list &words) {
|
||||
[](const py::list &words, const py::list &special_tokens, bool special_first) {
|
||||
std::shared_ptr<Vocab> v;
|
||||
THROW_IF_ERROR(Vocab::BuildFromPyList(words, &v));
|
||||
THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v));
|
||||
return v;
|
||||
})
|
||||
.def_static("from_file",
|
||||
[](const std::string &path, const std::string &dlm, int32_t vocab_size) {
|
||||
[](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens,
|
||||
bool special_first) {
|
||||
std::shared_ptr<Vocab> v;
|
||||
THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, &v));
|
||||
THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v));
|
||||
return v;
|
||||
})
|
||||
.def_static("from_dict", [](const py::dict &words) {
|
||||
|
|
|
@ -27,13 +27,16 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
BuildVocabOp::BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names,
|
||||
std::pair<int64_t, int64_t> freq_r, int64_t top_k, int32_t num_workers, int32_t op_conn_size)
|
||||
std::pair<int64_t, int64_t> freq_r, int64_t top_k, const std::vector<std::string> &tokens,
|
||||
bool prepend, int32_t num_workers, int32_t op_conn_size)
|
||||
: ParallelOp(num_workers, op_conn_size),
|
||||
interval_(op_conn_size * num_workers),
|
||||
vocab_(vocab),
|
||||
col_names_(col_names),
|
||||
freq_range_(freq_r),
|
||||
top_k_(top_k) {
|
||||
top_k_(top_k),
|
||||
special_tokens_(tokens),
|
||||
special_first_(prepend) {
|
||||
// init two queues for thread sync
|
||||
distributor_queue_ = std::make_unique<Queue<TensorRow>>(num_workers * op_conn_size);
|
||||
collector_queue_ =
|
||||
|
@ -129,7 +132,7 @@ Status BuildVocabOp::CollectorThread() {
|
|||
} // all frequencies are obtained
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty");
|
||||
std::vector<std::string> words;
|
||||
// make sure enough is reserved
|
||||
// make sure enough is reserved, this will become a partially sorted list eventually
|
||||
words.reserve(wrkr_map->size());
|
||||
|
||||
for (auto it = word_cnt_.begin(); it != word_cnt_.end();) {
|
||||
|
@ -140,6 +143,15 @@ Status BuildVocabOp::CollectorThread() {
|
|||
it = word_cnt_.erase(it);
|
||||
}
|
||||
}
|
||||
std::string err_msg;
|
||||
|
||||
for (const std::string &sp_tk : special_tokens_) {
|
||||
// if a special word exists in dataset, warn user about this
|
||||
err_msg += (word_cnt_.find(sp_tk) != word_cnt_.end() ? sp_tk + "\t" : "");
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(err_msg.empty(), "These specials words are already in the dataset: " + err_msg + ".");
|
||||
|
||||
int64_t num_words = std::min(static_cast<int64_t>(words.size()), top_k_);
|
||||
if (num_words == 0) {
|
||||
MS_LOG(WARNING) << "No word falls in the frequency range: (" << freq_range_.first << "," << freq_range_.second
|
||||
|
@ -152,9 +164,19 @@ Status BuildVocabOp::CollectorThread() {
|
|||
int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2];
|
||||
return f1 == f2 ? w1 < w2 : f1 > f2;
|
||||
});
|
||||
|
||||
if (special_first_) {
|
||||
for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk);
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < num_words; i++) {
|
||||
vocab_->append_word(words[i]);
|
||||
}
|
||||
|
||||
if (!special_first_) {
|
||||
for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
|
||||
// then use std::nth_element to partial sort
|
||||
|
@ -166,16 +188,17 @@ Status BuildVocabOp::Builder::Build(std::shared_ptr<BuildVocabOp> *op) {
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be positive number");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0,
|
||||
"frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)");
|
||||
(*op) = std::make_shared<BuildVocabOp>(builder_vocab_, builder_col_names_,
|
||||
std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
|
||||
builder_num_workers_, builder_connector_size_);
|
||||
(*op) = std::make_shared<BuildVocabOp>(
|
||||
builder_vocab_, builder_col_names_, std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
|
||||
builder_speical_tokens_, builder_special_first_, builder_num_workers_, builder_connector_size_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
BuildVocabOp::Builder::Builder()
|
||||
: builder_top_k_(std::numeric_limits<int64_t>::max()),
|
||||
builder_min_freq_(0),
|
||||
builder_max_freq_(std::numeric_limits<int64_t>::max()) {
|
||||
builder_max_freq_(std::numeric_limits<int64_t>::max()),
|
||||
builder_special_first_(true) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_connector_size_ = cfg->op_connector_size();
|
||||
|
|
|
@ -88,12 +88,26 @@ class BuildVocabOp : public ParallelOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// set special tokens
|
||||
// @param const std::vector<std::string> & col_names - name of columns to get words
|
||||
// @return Builder & reference to builder class object
|
||||
Builder &SetSpecialTokens(const std::vector<std::string> &tokens) {
|
||||
builder_speical_tokens_ = tokens;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// set vocab object
|
||||
Builder &SetVocab(std::shared_ptr<Vocab> vocab) {
|
||||
builder_vocab_ = vocab;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// set special tokens first (or last)
|
||||
Builder &SetSpecialFirst(bool prepend) {
|
||||
builder_special_first_ = prepend;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
|
||||
// @return - The error code return
|
||||
|
@ -104,13 +118,16 @@ class BuildVocabOp : public ParallelOp {
|
|||
int32_t builder_connector_size_;
|
||||
int64_t builder_min_freq_;
|
||||
int64_t builder_max_freq_;
|
||||
bool builder_special_first_;
|
||||
std::vector<std::string> builder_col_names_;
|
||||
std::vector<std::string> builder_speical_tokens_;
|
||||
std::shared_ptr<Vocab> builder_vocab_;
|
||||
int64_t builder_top_k_;
|
||||
};
|
||||
|
||||
BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names, std::pair<int64_t, int64_t> freq_range,
|
||||
int64_t top_k, int32_t num_workers, int32_t op_connector_size);
|
||||
int64_t top_k, const std::vector<std::string> &tokens, bool prepend, int32_t num_workers,
|
||||
int32_t op_connector_size);
|
||||
|
||||
~BuildVocabOp() = default;
|
||||
|
||||
|
@ -137,9 +154,11 @@ class BuildVocabOp : public ParallelOp {
|
|||
|
||||
private:
|
||||
const int32_t interval_;
|
||||
bool special_first_;
|
||||
std::shared_ptr<Vocab> vocab_;
|
||||
std::vector<std::string> col_names_;
|
||||
std::vector<int32_t> col_ids_;
|
||||
std::vector<std::string> special_tokens_;
|
||||
// pair = {min_f, max_f}
|
||||
// make sure that 0<= min_f < max_f <= int32_max in the builder
|
||||
std::pair<int64_t, int64_t> freq_range_;
|
||||
|
|
|
@ -33,7 +33,7 @@ class LookupOp : public TensorOp {
|
|||
// constructor for lookup, takes in a vocab object
|
||||
// @param std::shared_ptr<Vocab> vocab -
|
||||
// @param WordIdType default_id, id to lookup if a word is not in vocab
|
||||
explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = Vocab::kSpecialTokens::unk);
|
||||
explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = 1);
|
||||
|
||||
~LookupOp() = default;
|
||||
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/text/vocab.h"
|
||||
|
@ -28,41 +29,38 @@ WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const {
|
|||
return itr == word2id_.end() ? default_id : itr->second;
|
||||
}
|
||||
|
||||
WordType Vocab::Lookup(WordIdType id) {
|
||||
// this operation is most likely only done with since reverse lookup is only needed when training is done
|
||||
// hence, the worst case of inserting while keep looking up isn't likely to happen
|
||||
if (id2word_.size() != word2id_.size() && (id - kSpecialTokens::num_tokens >= id2word_.size())) {
|
||||
id2word_.clear();
|
||||
id2word_.reserve(word2id_.size());
|
||||
for (auto p : word2id_) {
|
||||
id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
|
||||
}
|
||||
}
|
||||
if (id < kSpecialTokens::num_tokens) {
|
||||
return reserved_token_str_[id];
|
||||
} else if (id - kSpecialTokens::num_tokens >= id2word_.size()) {
|
||||
return reserved_token_str_[kSpecialTokens::unk];
|
||||
} else {
|
||||
return id2word_[id - kSpecialTokens::num_tokens];
|
||||
}
|
||||
}
|
||||
|
||||
Status Vocab::BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab) {
|
||||
Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special,
|
||||
std::shared_ptr<Vocab> *vocab) {
|
||||
// check of duplication on both words and special_tokens will be performed in python
|
||||
// special_tokens and words both need to be unique, and shouldn't overlap
|
||||
std::unordered_map<WordType, WordIdType> word2id;
|
||||
WordIdType word_id = kSpecialTokens::num_tokens;
|
||||
// if special is added in front, normal words id will start from number of special tokens
|
||||
WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
|
||||
|
||||
for (auto word : words) {
|
||||
const std::string s = py::str(word);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(s) == word2id.end(), "duplicate word:" + s);
|
||||
word2id[s] = word_id++;
|
||||
word2id[py::str(word)] = word_id++;
|
||||
}
|
||||
|
||||
word_id = prepend_special ? 0 : word2id.size();
|
||||
|
||||
for (auto special_token : special_tokens) {
|
||||
word2id[py::str(special_token)] = word_id++;
|
||||
}
|
||||
|
||||
*vocab = std::make_shared<Vocab>(std::move(word2id));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
|
||||
std::shared_ptr<Vocab> *vocab) {
|
||||
const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab) {
|
||||
// python validator checks special_tokens doesn't contain any duplicate words
|
||||
std::unordered_set<std::string> specials;
|
||||
// used to check that words in file don't contain any special token that already exists
|
||||
for (auto word : special_tokens) {
|
||||
specials.insert(py::str(word));
|
||||
}
|
||||
WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
|
||||
std::unordered_map<WordType, WordIdType> word2id;
|
||||
WordIdType word_id = kSpecialTokens::num_tokens;
|
||||
std::fstream handle(path, std::ios::in);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open:" + path);
|
||||
std::string word;
|
||||
|
@ -71,40 +69,35 @@ Status Vocab::BuildFromFile(const std::string &path, const std::string &delimite
|
|||
// if delimiter is not found, find_first_of would return std::string::npos which is -1
|
||||
word = word.substr(0, word.find_first_of(delimiter));
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word + ".");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(), word + " is already in special_tokens.");
|
||||
word2id[word] = word_id++;
|
||||
// break if enough row is read, if vocab_size is smaller than 0
|
||||
if (word_id == vocab_size + kSpecialTokens::num_tokens) break;
|
||||
if (word2id.size() == vocab_size) break;
|
||||
}
|
||||
|
||||
word_id = prepend_special ? 0 : word2id.size();
|
||||
|
||||
for (auto special_token : special_tokens) {
|
||||
word2id[py::str(special_token)] = word_id++;
|
||||
}
|
||||
|
||||
*vocab = std::make_shared<Vocab>(std::move(word2id));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) {
|
||||
std::unordered_map<WordType, WordIdType> word2id;
|
||||
std::map<WordIdType, WordType> id2word;
|
||||
for (auto p : words) {
|
||||
WordIdType word_id = py::reinterpret_borrow<py::int_>(p.second);
|
||||
if (word_id < kSpecialTokens::num_tokens) continue; // skip id that are reserved
|
||||
std::string word = py::str(p.first);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(id2word.find(word_id) == id2word.end(), "duplicate id:" + word);
|
||||
id2word[word_id] = word;
|
||||
word2id[py::str(p.first)] = py::reinterpret_borrow<py::int_>(p.second);
|
||||
}
|
||||
|
||||
WordIdType cnt = kSpecialTokens::num_tokens;
|
||||
for (auto p : id2word) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(p.first == cnt++, "word id needs to be continuous starting from 2");
|
||||
word2id[p.second] = p.first;
|
||||
}
|
||||
|
||||
*vocab = std::make_shared<Vocab>(std::move(word2id));
|
||||
return Status::OK();
|
||||
}
|
||||
const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"};
|
||||
|
||||
void Vocab::append_word(const std::string &word) {
|
||||
if (word2id_.find(word) == word2id_.end()) {
|
||||
word2id_[word] = word2id_.size() + kSpecialTokens::num_tokens;
|
||||
word2id_[word] = word2id_.size();
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -45,7 +45,8 @@ class Vocab {
|
|||
// @param const py::list &words - a list of string, used to build vocab, id starts from 2
|
||||
// @param std::shared_ptr<Vocab> *vocab - return value, vocab object
|
||||
// @return error code
|
||||
static Status BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab);
|
||||
static Status BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special,
|
||||
std::shared_ptr<Vocab> *vocab);
|
||||
|
||||
// Build a vocab from reading a vocab file, id are automatically assigned, start from 2
|
||||
// @param std::string &path - path to vocab file , each line is assumed to contain 1 word
|
||||
|
@ -54,7 +55,7 @@ class Vocab {
|
|||
// @param std::shared_ptr<Vocab> *vocab - return value, vocab object
|
||||
// @return error code
|
||||
static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
|
||||
std::shared_ptr<Vocab> *vocab);
|
||||
const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab);
|
||||
|
||||
// Lookup the id of a word, if word doesn't exist in vocab, return default_id
|
||||
// @param const WordType word - word to look up
|
||||
|
@ -80,15 +81,8 @@ class Vocab {
|
|||
// destructor
|
||||
~Vocab() = default;
|
||||
|
||||
// enum type that holds all special tokens, add more if needed
|
||||
enum kSpecialTokens : WordIdType { pad = 0, unk = 1, num_tokens = 2 };
|
||||
|
||||
// reversed lookup table for the reserved tokens
|
||||
static const std::vector<WordType> reserved_token_str_;
|
||||
|
||||
private:
|
||||
std::unordered_map<WordType, WordIdType> word2id_;
|
||||
std::vector<WordType> id2word_; // reverse lookup
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -894,9 +894,9 @@ class Dataset:
|
|||
|
||||
return ProjectDataset(self, columns)
|
||||
|
||||
def build_vocab(self, vocab, columns, freq_range, top_k):
|
||||
def build_vocab(self, vocab, columns, freq_range, top_k, special_tokens, special_first):
|
||||
""" Internal function for building a vocab"""
|
||||
return BuildVocabDataset(self, vocab, columns, freq_range, top_k)
|
||||
return BuildVocabDataset(self, vocab, columns, freq_range, top_k, special_tokens, special_first)
|
||||
|
||||
def apply(self, apply_func):
|
||||
"""
|
||||
|
@ -4865,9 +4865,15 @@ class BuildVocabDataset(DatasetOp):
|
|||
top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are
|
||||
taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken (default=None,
|
||||
all words are included).
|
||||
special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
|
||||
(default=None, no special tokens will be added).
|
||||
special_first(bool): whether special_tokens will be prepended/appended to vocab, If special_tokens is
|
||||
specified and special_first is set to None, special_tokens will be prepended. (default=None).
|
||||
prefetch_size (int, optional): prefetch number of records ahead of the user's request (default=None).
|
||||
"""
|
||||
|
||||
def __init__(self, input_dataset, vocab, columns, freq_range, top_k, prefetch_size=None):
|
||||
def __init__(self, input_dataset, vocab, columns, freq_range, top_k, special_tokens, special_first,
|
||||
prefetch_size=None):
|
||||
super().__init__()
|
||||
self.columns = columns
|
||||
self.input.append(input_dataset)
|
||||
|
@ -4875,6 +4881,8 @@ class BuildVocabDataset(DatasetOp):
|
|||
self.vocab = vocab
|
||||
self.freq_range = freq_range
|
||||
self.top_k = top_k
|
||||
self.special_tokens = special_tokens
|
||||
self.special_first = special_first
|
||||
input_dataset.output.append(self)
|
||||
|
||||
def get_args(self):
|
||||
|
@ -4884,6 +4892,8 @@ class BuildVocabDataset(DatasetOp):
|
|||
args["freq_range"] = self.freq_range
|
||||
args["prefetch_size"] = self.prefetch_size
|
||||
args["top_k"] = self.top_k
|
||||
args["special_tokens"] = self.special_tokens
|
||||
args["special_first"] = self.special_first
|
||||
return args
|
||||
|
||||
def __deepcopy__(self, memodict):
|
||||
|
@ -4900,4 +4910,7 @@ class BuildVocabDataset(DatasetOp):
|
|||
new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
|
||||
new_op.top_k = copy.deepcopy(self.top_k, memodict)
|
||||
new_op.vocab = self.vocab
|
||||
new_op.special_tokens = copy.deepcopy(self.special_tokens)
|
||||
new_op.special_first = copy.deepcopy(self.special_first)
|
||||
|
||||
return new_op
|
||||
|
|
|
@ -28,10 +28,12 @@ from .validators import check_lookup, check_jieba_add_dict, \
|
|||
|
||||
class Lookup(cde.LookupOp):
|
||||
"""
|
||||
Lookup operator that looks up a word to an id
|
||||
Lookup operator that looks up a word to an id.
|
||||
Args:
|
||||
vocab(Vocab): a Vocab object.
|
||||
unknown(int): default id to lookup a word that is out of vocab (default is None).
|
||||
unknown(int, optional): default id to lookup a word that is out of vocab. If no argument is passed, 1 will be
|
||||
used to be the default id which is the convention for unknown_token <unk>. Otherwise, user is strongly
|
||||
encouraged to pass in the id for <unk> (default=None).
|
||||
"""
|
||||
|
||||
@check_lookup
|
||||
|
|
|
@ -25,12 +25,13 @@ from .validators import check_from_file, check_from_list, check_from_dict, check
|
|||
|
||||
class Vocab(cde.Vocab):
|
||||
"""
|
||||
Vocab object that is used for lookup word.
|
||||
Vocab object that is used to lookup a word. It contains a map that maps each word(str) to an id (int)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@check_from_dataset
|
||||
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None):
|
||||
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
|
||||
special_first=None):
|
||||
"""
|
||||
Build a vocab from a dataset. This would collect all unique words in a dataset and return a vocab within
|
||||
the frequency range specified by user in freq_range. User would be warned if no words fall into the frequency.
|
||||
|
@ -49,11 +50,16 @@ class Vocab(cde.Vocab):
|
|||
top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are
|
||||
taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. (default=None
|
||||
all words are included).
|
||||
special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
|
||||
(default=None, no special tokens will be added).
|
||||
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens
|
||||
is specified and special_first is set to None, special_tokens will be prepended. (default=None).
|
||||
return:
|
||||
text.Vocab: Vocab object built from dataset.
|
||||
"""
|
||||
|
||||
vocab = Vocab()
|
||||
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k)
|
||||
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
|
||||
for d in root.create_dict_iterator():
|
||||
if d is not None:
|
||||
raise ValueError("from_dataset should receive data other than None.")
|
||||
|
@ -61,17 +67,21 @@ class Vocab(cde.Vocab):
|
|||
|
||||
@classmethod
|
||||
@check_from_list
|
||||
def from_list(cls, word_list):
|
||||
def from_list(cls, word_list, special_tokens=None, special_first=None):
|
||||
"""
|
||||
build a vocab object from a list of word.
|
||||
Args:
|
||||
word_list(list): a list of string where each element is a word.
|
||||
word_list(list): a list of string where each element is a word of type string.
|
||||
special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
|
||||
(default=None, no special tokens will be added).
|
||||
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
|
||||
is specified and special_first is set to None, special_tokens will be prepended. (default=None).
|
||||
"""
|
||||
return super().from_list(word_list)
|
||||
return super().from_list(word_list, special_tokens, special_first)
|
||||
|
||||
@classmethod
|
||||
@check_from_file
|
||||
def from_file(cls, file_path, delimiter=None, vocab_size=None):
|
||||
def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None):
|
||||
"""
|
||||
build a vocab object from a list of word.
|
||||
Args:
|
||||
|
@ -79,8 +89,12 @@ class Vocab(cde.Vocab):
|
|||
delimiter(str, optional): a delimiter to break up each line in file, the first element is taken to be
|
||||
the word (default=None).
|
||||
vocab_size(int, optional): number of words to read from file_path (default=None, all words are taken).
|
||||
special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
|
||||
(default=None, no special tokens will be added).
|
||||
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
|
||||
is specified and special_first is set to None, special_tokens will be prepended. (default=None).
|
||||
"""
|
||||
return super().from_file(file_path, delimiter, vocab_size)
|
||||
return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first)
|
||||
|
||||
@classmethod
|
||||
@check_from_dict
|
||||
|
@ -88,7 +102,8 @@ class Vocab(cde.Vocab):
|
|||
"""
|
||||
build a vocab object from a dict.
|
||||
Args:
|
||||
word_dict(dict): dict contains word, id pairs. id should start from 2 and be continuous.
|
||||
word_dict(dict): dict contains word, id pairs where word should be str and id int. id is recommended to
|
||||
start from 0 and be continuous. ValueError will be raised if id is negative.
|
||||
"""
|
||||
return super().from_dict(word_dict)
|
||||
|
||||
|
|
|
@ -23,6 +23,21 @@ import mindspore._c_dataengine as cde
|
|||
from ..transforms.validators import check_uint32, check_pos_int64
|
||||
|
||||
|
||||
def check_unique_list_of_words(words, arg_name):
|
||||
"""Check that words is a list and each element is a str without any duplication"""
|
||||
|
||||
if not isinstance(words, list):
|
||||
raise ValueError(arg_name + " needs to be a list of words of type string.")
|
||||
words_set = set()
|
||||
for word in words:
|
||||
if not isinstance(word, str):
|
||||
raise ValueError("each word in " + arg_name + " needs to be type str.")
|
||||
if word in words_set:
|
||||
raise ValueError(arg_name + " contains duplicate word: " + word + ".")
|
||||
words_set.add(word)
|
||||
return words_set
|
||||
|
||||
|
||||
def check_lookup(method):
|
||||
"""A wrapper that wrap a parameter checker to the original function."""
|
||||
|
||||
|
@ -52,13 +67,17 @@ def check_from_file(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
file_path, delimiter, vocab_size = (list(args) + 3 * [None])[:3]
|
||||
file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5]
|
||||
if "file_path" in kwargs:
|
||||
file_path = kwargs.get("file_path")
|
||||
if "delimiter" in kwargs:
|
||||
delimiter = kwargs.get("delimiter")
|
||||
if "vocab_size" in kwargs:
|
||||
vocab_size = kwargs.get("vocab_size")
|
||||
if "special_tokens" in kwargs:
|
||||
special_tokens = kwargs.get("special_tokens")
|
||||
if "special_first" in kwargs:
|
||||
special_first = kwargs.get("special_first")
|
||||
|
||||
if not isinstance(file_path, str):
|
||||
raise ValueError("file_path needs to be str.")
|
||||
|
@ -73,9 +92,24 @@ def check_from_file(method):
|
|||
raise ValueError("vocab size needs to be a positive integer.")
|
||||
else:
|
||||
vocab_size = -1
|
||||
|
||||
if special_first is None:
|
||||
special_first = True
|
||||
|
||||
if not isinstance(special_first, bool):
|
||||
raise ValueError("special_first needs to be a boolean value")
|
||||
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
|
||||
check_unique_list_of_words(special_tokens, "special_tokens")
|
||||
|
||||
kwargs["file_path"] = file_path
|
||||
kwargs["delimiter"] = delimiter
|
||||
kwargs["vocab_size"] = vocab_size
|
||||
kwargs["special_tokens"] = special_tokens
|
||||
kwargs["special_first"] = special_first
|
||||
|
||||
return method(self, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -86,16 +120,32 @@ def check_from_list(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
word_list, = (list(args) + [None])[:1]
|
||||
word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3]
|
||||
if "word_list" in kwargs:
|
||||
word_list = kwargs.get("word_list")
|
||||
if not isinstance(word_list, list):
|
||||
raise ValueError("word_list needs to be a list of words.")
|
||||
for word in word_list:
|
||||
if not isinstance(word, str):
|
||||
raise ValueError("each word in word list needs to be type str.")
|
||||
if "special_tokens" in kwargs:
|
||||
special_tokens = kwargs.get("special_tokens")
|
||||
if "special_first" in kwargs:
|
||||
special_first = kwargs.get("special_first")
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
word_set = check_unique_list_of_words(word_list, "word_list")
|
||||
token_set = check_unique_list_of_words(special_tokens, "special_tokens")
|
||||
|
||||
intersect = word_set.intersection(token_set)
|
||||
|
||||
if intersect != set():
|
||||
raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
|
||||
|
||||
if special_first is None:
|
||||
special_first = True
|
||||
|
||||
if not isinstance(special_first, bool):
|
||||
raise ValueError("special_first needs to be a boolean value.")
|
||||
|
||||
kwargs["word_list"] = word_list
|
||||
kwargs["special_tokens"] = special_tokens
|
||||
kwargs["special_first"] = special_first
|
||||
return method(self, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
@ -113,9 +163,9 @@ def check_from_dict(method):
|
|||
raise ValueError("word_dict needs to be a list of word,id pairs.")
|
||||
for word, word_id in word_dict.items():
|
||||
if not isinstance(word, str):
|
||||
raise ValueError("each word in word_dict needs to be type str.")
|
||||
raise ValueError("Each word in word_dict needs to be type string.")
|
||||
if not (isinstance(word_id, int) and word_id >= 0):
|
||||
raise ValueError("each word id needs to be positive integer.")
|
||||
raise ValueError("Each word id needs to be positive integer.")
|
||||
kwargs["word_dict"] = word_dict
|
||||
return method(self, **kwargs)
|
||||
|
||||
|
@ -135,11 +185,11 @@ def check_jieba_init(method):
|
|||
mp_path = kwargs.get("mp_path")
|
||||
if hmm_path is None:
|
||||
raise ValueError(
|
||||
"the dict of HMMSegment in cppjieba is not provided.")
|
||||
"The dict of HMMSegment in cppjieba is not provided.")
|
||||
kwargs["hmm_path"] = hmm_path
|
||||
if mp_path is None:
|
||||
raise ValueError(
|
||||
"the dict of MPSegment in cppjieba is not provided.")
|
||||
"The dict of MPSegment in cppjieba is not provided.")
|
||||
kwargs["mp_path"] = mp_path
|
||||
if model is not None:
|
||||
kwargs["model"] = model
|
||||
|
@ -171,7 +221,7 @@ def check_jieba_add_word(method):
|
|||
|
||||
|
||||
def check_jieba_add_dict(method):
|
||||
"""Wrapper method to check the parameters of add dict"""
|
||||
"""Wrapper method to check the parameters of add dict."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -189,10 +239,10 @@ def check_jieba_add_dict(method):
|
|||
def check_from_dataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original function."""
|
||||
|
||||
# def from_dataset(cls, dataset, columns, freq_range=None, top_k=None):
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
dataset, columns, freq_range, top_k = (list(args) + 4 * [None])[:4]
|
||||
|
||||
dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6]
|
||||
if "dataset" in kwargs:
|
||||
dataset = kwargs.get("dataset")
|
||||
if "columns" in kwargs:
|
||||
|
@ -201,6 +251,10 @@ def check_from_dataset(method):
|
|||
freq_range = kwargs.get("freq_range")
|
||||
if "top_k" in kwargs:
|
||||
top_k = kwargs.get("top_k")
|
||||
if "special_tokens" in kwargs:
|
||||
special_tokens = kwargs.get("special_tokens")
|
||||
if "special_first" in kwargs:
|
||||
special_first = kwargs.get("special_first")
|
||||
|
||||
if columns is None:
|
||||
columns = []
|
||||
|
@ -232,10 +286,23 @@ def check_from_dataset(method):
|
|||
if isinstance(top_k, int) and top_k <= 0:
|
||||
raise ValueError("top_k needs to be a positive integer.")
|
||||
|
||||
if special_first is None:
|
||||
special_first = True
|
||||
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
|
||||
if not isinstance(special_first, bool):
|
||||
raise ValueError("special_first needs to be a boolean value.")
|
||||
|
||||
check_unique_list_of_words(special_tokens, "special_tokens")
|
||||
|
||||
kwargs["dataset"] = dataset
|
||||
kwargs["columns"] = columns
|
||||
kwargs["freq_range"] = freq_range
|
||||
kwargs["top_k"] = top_k
|
||||
kwargs["special_tokens"] = special_tokens
|
||||
kwargs["special_first"] = special_first
|
||||
|
||||
return method(self, **kwargs)
|
||||
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
w1
|
||||
w2
|
||||
w3
|
||||
w4
|
||||
w5
|
||||
w6
|
||||
w7
|
||||
w8
|
|
@ -23,19 +23,21 @@ import mindspore.dataset.text as text
|
|||
def test_demo_basic_from_dataset():
|
||||
""" this is a tutorial on how from_dataset should be used in a normal use case"""
|
||||
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
|
||||
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None)
|
||||
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
|
||||
special_first=True)
|
||||
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
|
||||
res = []
|
||||
for d in data.create_dict_iterator():
|
||||
res.append(d["text"].item())
|
||||
assert res == [4, 5, 3, 6, 7, 2]
|
||||
assert res == [4, 5, 3, 6, 7, 2], res
|
||||
|
||||
|
||||
def test_demo_basic_from_dataset_with_tokenizer():
|
||||
""" this is a tutorial on how from_dataset should be used in a normal use case with tokenizer"""
|
||||
data = ds.TextFileDataset("../data/dataset/testTokenizerData/1.txt", shuffle=False)
|
||||
data = data.map(input_columns=["text"], operations=text.UnicodeCharTokenizer())
|
||||
vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None)
|
||||
vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
|
||||
special_first=True)
|
||||
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
|
||||
res = []
|
||||
for d in data.create_dict_iterator():
|
||||
|
@ -55,7 +57,8 @@ def test_from_dataset():
|
|||
|
||||
def test_config(freq_range, top_k):
|
||||
corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"])
|
||||
vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k)
|
||||
vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k, special_tokens=["<pad>", "<unk>"],
|
||||
special_first=True)
|
||||
corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab))
|
||||
res = []
|
||||
for d in corpus_dataset.create_dict_iterator():
|
||||
|
@ -87,6 +90,35 @@ def test_from_dataset():
|
|||
assert test6_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [1, 1, 1], [1, 1], [1]], str(test6_res)
|
||||
|
||||
|
||||
def test_from_dataset_special_token():
|
||||
""" test build vocab with generator dataset """
|
||||
|
||||
def gen_corpus():
|
||||
# key: word, value: number of occurrences, reason for using letters is so their order is apparent
|
||||
corpus = {"D": 1, "C": 1, "B": 1, "A": 1}
|
||||
for k, v in corpus.items():
|
||||
yield (np.array([k] * v, dtype='S'),)
|
||||
|
||||
def gen_input(texts):
|
||||
for word in texts.split(" "):
|
||||
yield (np.array(word, dtype='S'),)
|
||||
|
||||
def test_config(texts, top_k, special_tokens, special_first):
|
||||
corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"])
|
||||
vocab = text.Vocab.from_dataset(corpus_dataset, None, None, top_k, special_tokens, special_first)
|
||||
data = ds.GeneratorDataset(gen_input(texts), column_names=["text"])
|
||||
data = data.map(input_columns="text", operations=text.Lookup(vocab))
|
||||
res = []
|
||||
for d in data.create_dict_iterator():
|
||||
res.append(d["text"].item())
|
||||
return res
|
||||
|
||||
# test special tokens are inserted before
|
||||
assert test_config("A B C D <pad> <unk>", 4, ["<pad>", "<unk>"], True) == [2, 3, 4, 5, 0, 1]
|
||||
# test special tokens are inserted after
|
||||
assert test_config("A B C D <pad> <unk>", 4, ["<pad>", "<unk>"], False) == [0, 1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
def test_from_dataset_exceptions():
|
||||
""" test various exceptions during that are checked in validator """
|
||||
|
||||
|
@ -105,8 +137,10 @@ def test_from_dataset_exceptions():
|
|||
test_config("text", (2, 3), 0, "top_k needs to be a positive integer")
|
||||
test_config([123], (2, 3), 0, "columns need to be a list of strings")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_demo_basic_from_dataset()
|
||||
test_from_dataset()
|
||||
test_from_dataset_exceptions()
|
||||
test_demo_basic_from_dataset_with_tokenizer()
|
||||
test_from_dataset_special_token()
|
||||
|
|
|
@ -33,7 +33,7 @@ def test_on_tokenized_line():
|
|||
word = line.split(',')[0]
|
||||
jieba_op.add_word(word)
|
||||
data = data.map(input_columns=["text"], operations=jieba_op)
|
||||
vocab = text.Vocab.from_file(VOCAB_FILE, ",")
|
||||
vocab = text.Vocab.from_file(VOCAB_FILE, ",", special_tokens=["<pad>", "<unk>"])
|
||||
lookup = text.Lookup(vocab)
|
||||
data = data.map(input_columns=["text"], operations=lookup)
|
||||
res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14],
|
||||
|
|
|
@ -1,13 +1,31 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.text as text
|
||||
|
||||
# this file contains "home is behind the world head" each word is 1 line
|
||||
DATA_FILE = "../data/dataset/testVocab/words.txt"
|
||||
VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
|
||||
SIMPLE_VOCAB_FILE = "../data/dataset/testVocab/simple_vocab_list.txt"
|
||||
|
||||
|
||||
def test_from_list():
|
||||
vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "))
|
||||
def test_from_list_tutorial():
|
||||
vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", "<unk>"], True)
|
||||
lookup = text.Lookup(vocab)
|
||||
data = ds.TextFileDataset(DATA_FILE, shuffle=False)
|
||||
data = data.map(input_columns=["text"], operations=lookup)
|
||||
|
@ -18,8 +36,8 @@ def test_from_list():
|
|||
ind += 1
|
||||
|
||||
|
||||
def test_from_file():
|
||||
vocab = text.Vocab.from_file(VOCAB_FILE, ",")
|
||||
def test_from_file_tutorial():
|
||||
vocab = text.Vocab.from_file(VOCAB_FILE, ",", None, ["<pad>", "<unk>"], True)
|
||||
lookup = text.Lookup(vocab)
|
||||
data = ds.TextFileDataset(DATA_FILE, shuffle=False)
|
||||
data = data.map(input_columns=["text"], operations=lookup)
|
||||
|
@ -30,7 +48,7 @@ def test_from_file():
|
|||
ind += 1
|
||||
|
||||
|
||||
def test_from_dict():
|
||||
def test_from_dict_tutorial():
|
||||
vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
|
||||
lookup = text.Lookup(vocab, 6) # default value is -1
|
||||
data = ds.TextFileDataset(DATA_FILE, shuffle=False)
|
||||
|
@ -41,7 +59,61 @@ def test_from_dict():
|
|||
assert d["text"] == res[ind], ind
|
||||
ind += 1
|
||||
|
||||
|
||||
def test_from_list():
|
||||
def gen(texts):
|
||||
for word in texts.split(" "):
|
||||
yield (np.array(word, dtype='S'),)
|
||||
|
||||
def test_config(lookup_str, vocab_input, special_tokens, special_first):
|
||||
try:
|
||||
vocab = text.Vocab.from_list(vocab_input, special_tokens, special_first)
|
||||
data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
|
||||
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
|
||||
res = []
|
||||
for d in data.create_dict_iterator():
|
||||
res.append(d["text"].item())
|
||||
return res
|
||||
except ValueError as e:
|
||||
return str(e)
|
||||
|
||||
# test normal operations
|
||||
assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], True) == [2, 3, 4, 0, 1]
|
||||
assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False) == [0, 1, 2, 3, 4]
|
||||
assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, True) == [2, 1, 0]
|
||||
assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, False) == [2, 1, 0]
|
||||
|
||||
# test exceptions
|
||||
assert "word_list contains duplicate" in test_config("w1", ["w1", "w1"], [], True)
|
||||
assert "special_tokens contains duplicate" in test_config("w1", ["w1", "w2"], ["s1", "s1"], True)
|
||||
assert "special_tokens and word_list contain duplicate" in test_config("w1", ["w1", "w2"], ["s1", "w1"], True)
|
||||
|
||||
|
||||
def test_from_file():
|
||||
def gen(texts):
|
||||
for word in texts.split(" "):
|
||||
yield (np.array(word, dtype='S'),)
|
||||
|
||||
def test_config(lookup_str, special_tokens, special_first):
|
||||
try:
|
||||
vocab = text.Vocab.from_file(SIMPLE_VOCAB_FILE, special_tokens=special_tokens, special_first=special_first)
|
||||
data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
|
||||
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
|
||||
res = []
|
||||
for d in data.create_dict_iterator():
|
||||
res.append(d["text"].item())
|
||||
return res
|
||||
except ValueError as e:
|
||||
return str(e)
|
||||
|
||||
assert test_config("w1 w2 w3", ["s1", "s2", "s3"], True) == [3, 4, 5]
|
||||
assert test_config("w1 w2 w3", ["s1", "s2", "s3"], False) == [0, 1, 2]
|
||||
assert "special_tokens contains duplicate" in test_config("w1", ["s1", "s1"], True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_from_list_tutorial()
|
||||
test_from_file_tutorial()
|
||||
test_from_dict_tutorial()
|
||||
test_from_list()
|
||||
test_from_file()
|
||||
test_from_dict()
|
||||
|
|
Loading…
Reference in New Issue