forked from mindspore-Ecosystem/mindspore
!2317 rework on dataset.text.vocab to support any user special_tokens
Merge pull request !2317 from ZiruiWu/vocab_rework
Commit: 3784220056
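For orientation, the user-visible change in this merge is that Vocab construction now accepts an optional list of special tokens plus a flag that decides whether they are prepended or appended to the vocabulary. A minimal usage sketch of the reworked Python API, assembled from the updated tests further down (the text file is the test data those tests read):

import mindspore.dataset as ds
import mindspore.dataset.text as text

# "<pad>" and "<unk>" are prepended, so they take ids 0 and 1 and the listed
# words are numbered from 2 onward
vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "),
                             ["<pad>", "<unk>"], True)

# map each word in the "text" column to its id
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))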
@@ -1283,18 +1283,18 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<Datas
         py::tuple tp = py::reinterpret_borrow<py::tuple>(value);
         if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow<py::int_>(tp[0]));
         if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow<py::int_>(tp[1]));
-      }
-      if (key == "top_k") {
+      } else if (key == "top_k") {
         builder->SetTopK(py::reinterpret_borrow<py::int_>(value));
-      }
-      if (key == "columns") {
+      } else if (key == "columns") {
         (void)builder->SetColumnNames(ToStringVector(value));
-      }
-      if (key == "vocab") {
+      } else if (key == "vocab") {
         (void)builder->SetVocab(value.cast<std::shared_ptr<Vocab>>());
-      }
-      if (key == "num_parallel_workers") {
+      } else if (key == "num_parallel_workers") {
         (void)builder->SetNumWorkers(ToInt(value));
+      } else if (key == "special_first") {
+        (void)builder->SetSpecialFirst(ToBool(value));
+      } else if (key == "special_tokens") {
+        (void)builder->SetSpecialTokens(ToStringVector(value));
       }
     }
   }

@@ -673,15 +673,16 @@ void bindVocabObjects(py::module *m) {
   (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
     .def(py::init<>())
     .def_static("from_list",
-                [](const py::list &words) {
+                [](const py::list &words, const py::list &special_tokens, bool special_first) {
                   std::shared_ptr<Vocab> v;
-                  THROW_IF_ERROR(Vocab::BuildFromPyList(words, &v));
+                  THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v));
                   return v;
                 })
     .def_static("from_file",
-                [](const std::string &path, const std::string &dlm, int32_t vocab_size) {
+                [](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens,
+                   bool special_first) {
                   std::shared_ptr<Vocab> v;
-                  THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, &v));
+                  THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v));
                   return v;
                 })
     .def_static("from_dict", [](const py::dict &words) {

@@ -27,13 +27,16 @@ namespace mindspore {
 namespace dataset {

 BuildVocabOp::BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names,
-                           std::pair<int64_t, int64_t> freq_r, int64_t top_k, int32_t num_workers, int32_t op_conn_size)
+                           std::pair<int64_t, int64_t> freq_r, int64_t top_k, const std::vector<std::string> &tokens,
+                           bool prepend, int32_t num_workers, int32_t op_conn_size)
     : ParallelOp(num_workers, op_conn_size),
       interval_(op_conn_size * num_workers),
       vocab_(vocab),
       col_names_(col_names),
       freq_range_(freq_r),
-      top_k_(top_k) {
+      top_k_(top_k),
+      special_tokens_(tokens),
+      special_first_(prepend) {
   // init two queues for thread sync
   distributor_queue_ = std::make_unique<Queue<TensorRow>>(num_workers * op_conn_size);
   collector_queue_ =
@@ -129,7 +132,7 @@ Status BuildVocabOp::CollectorThread() {
   } // all frequencies are obtained
   CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty");
   std::vector<std::string> words;
-  // make sure enough is reserved
+  // make sure enough is reserved, this will become a partially sorted list eventually
   words.reserve(wrkr_map->size());

   for (auto it = word_cnt_.begin(); it != word_cnt_.end();) {
@@ -140,6 +143,15 @@ Status BuildVocabOp::CollectorThread() {
       it = word_cnt_.erase(it);
     }
   }
+  std::string err_msg;
+
+  for (const std::string &sp_tk : special_tokens_) {
+    // if a special word exists in dataset, warn user about this
+    err_msg += (word_cnt_.find(sp_tk) != word_cnt_.end() ? sp_tk + "\t" : "");
+  }
+
+  CHECK_FAIL_RETURN_UNEXPECTED(err_msg.empty(), "These specials words are already in the dataset: " + err_msg + ".");
+
   int64_t num_words = std::min(static_cast<int64_t>(words.size()), top_k_);
   if (num_words == 0) {
     MS_LOG(WARNING) << "No word falls in the frequency range: (" << freq_range_.first << "," << freq_range_.second
@@ -152,9 +164,19 @@ Status BuildVocabOp::CollectorThread() {
     int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2];
     return f1 == f2 ? w1 < w2 : f1 > f2;
   });
+
+  if (special_first_) {
+    for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk);
+  }
+
   for (int64_t i = 0; i < num_words; i++) {
     vocab_->append_word(words[i]);
   }
+
+  if (!special_first_) {
+    for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk);
+  }
+
   RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
   RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
   // then use std::nth_element to partial sort
@@ -166,16 +188,17 @@ Status BuildVocabOp::Builder::Build(std::shared_ptr<BuildVocabOp> *op) {
   CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be positive number");
   CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0,
                                "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)");
-  (*op) = std::make_shared<BuildVocabOp>(builder_vocab_, builder_col_names_,
-                                         std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
-                                         builder_num_workers_, builder_connector_size_);
+  (*op) = std::make_shared<BuildVocabOp>(
+    builder_vocab_, builder_col_names_, std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
+    builder_speical_tokens_, builder_special_first_, builder_num_workers_, builder_connector_size_);
   return Status::OK();
 }

 BuildVocabOp::Builder::Builder()
     : builder_top_k_(std::numeric_limits<int64_t>::max()),
       builder_min_freq_(0),
-      builder_max_freq_(std::numeric_limits<int64_t>::max()) {
+      builder_max_freq_(std::numeric_limits<int64_t>::max()),
+      builder_special_first_(true) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_connector_size_ = cfg->op_connector_size();

@@ -88,12 +88,26 @@ class BuildVocabOp : public ParallelOp {
       return *this;
     }

+    // set special tokens
+    // @param const std::vector<std::string> & col_names - name of columns to get words
+    // @return Builder & reference to builder class object
+    Builder &SetSpecialTokens(const std::vector<std::string> &tokens) {
+      builder_speical_tokens_ = tokens;
+      return *this;
+    }
+
     // set vocab object
     Builder &SetVocab(std::shared_ptr<Vocab> vocab) {
      builder_vocab_ = vocab;
      return *this;
     }

+    // set special tokens first (or last)
+    Builder &SetSpecialFirst(bool prepend) {
+      builder_special_first_ = prepend;
+      return *this;
+    }
+
     // The builder "build" method creates the final object.
     // @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
     // @return - The error code return
@@ -104,13 +118,16 @@ class BuildVocabOp : public ParallelOp {
     int32_t builder_connector_size_;
     int64_t builder_min_freq_;
     int64_t builder_max_freq_;
+    bool builder_special_first_;
     std::vector<std::string> builder_col_names_;
+    std::vector<std::string> builder_speical_tokens_;
     std::shared_ptr<Vocab> builder_vocab_;
     int64_t builder_top_k_;
   };

   BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names, std::pair<int64_t, int64_t> freq_range,
-               int64_t top_k, int32_t num_workers, int32_t op_connector_size);
+               int64_t top_k, const std::vector<std::string> &tokens, bool prepend, int32_t num_workers,
+               int32_t op_connector_size);

   ~BuildVocabOp() = default;

@@ -137,9 +154,11 @@ class BuildVocabOp : public ParallelOp {

  private:
   const int32_t interval_;
+  bool special_first_;
   std::shared_ptr<Vocab> vocab_;
   std::vector<std::string> col_names_;
   std::vector<int32_t> col_ids_;
+  std::vector<std::string> special_tokens_;
   // pair = {min_f, max_f}
   // make sure that 0<= min_f < max_f <= int32_max in the builder
   std::pair<int64_t, int64_t> freq_range_;

@@ -33,7 +33,7 @@ class LookupOp : public TensorOp {
   // constructor for lookup, takes in a vocab object
   // @param std::shared_ptr<Vocab> vocab -
   // @param WordIdType default_id, id to lookup if a word is not in vocab
-  explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = Vocab::kSpecialTokens::unk);
+  explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = 1);

   ~LookupOp() = default;

@@ -14,7 +14,8 @@
  * limitations under the License.
  */
 #include <fstream>
-#include <map>
+#include <unordered_set>
+#include <unordered_map>
 #include <utility>

 #include "dataset/text/vocab.h"
@@ -28,41 +29,38 @@ WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const {
   return itr == word2id_.end() ? default_id : itr->second;
 }

-WordType Vocab::Lookup(WordIdType id) {
-  // this operation is most likely only done with since reverse lookup is only needed when training is done
-  // hence, the worst case of inserting while keep looking up isn't likely to happen
-  if (id2word_.size() != word2id_.size() && (id - kSpecialTokens::num_tokens >= id2word_.size())) {
-    id2word_.clear();
-    id2word_.reserve(word2id_.size());
-    for (auto p : word2id_) {
-      id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
-    }
-  }
-  if (id < kSpecialTokens::num_tokens) {
-    return reserved_token_str_[id];
-  } else if (id - kSpecialTokens::num_tokens >= id2word_.size()) {
-    return reserved_token_str_[kSpecialTokens::unk];
-  } else {
-    return id2word_[id - kSpecialTokens::num_tokens];
-  }
-}
-
-Status Vocab::BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab) {
+Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special,
+                              std::shared_ptr<Vocab> *vocab) {
+  // check of duplication on both words and special_tokens will be performed in python
+  // special_tokens and words both need to be unique, and shouldn't overlap
   std::unordered_map<WordType, WordIdType> word2id;
-  WordIdType word_id = kSpecialTokens::num_tokens;
+  // if special is added in front, normal words id will start from number of special tokens
+  WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
+
   for (auto word : words) {
-    const std::string s = py::str(word);
-    CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(s) == word2id.end(), "duplicate word:" + s);
-    word2id[s] = word_id++;
+    word2id[py::str(word)] = word_id++;
   }
+
+  word_id = prepend_special ? 0 : word2id.size();
+
+  for (auto special_token : special_tokens) {
+    word2id[py::str(special_token)] = word_id++;
+  }
+
   *vocab = std::make_shared<Vocab>(std::move(word2id));
   return Status::OK();
 }

 Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
-                            std::shared_ptr<Vocab> *vocab) {
+                            const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab) {
+  // python validator checks special_tokens doesn't contain any duplicate words
+  std::unordered_set<std::string> specials;
+  // used to check that words in file don't contain any special token that already exists
+  for (auto word : special_tokens) {
+    specials.insert(py::str(word));
+  }
+  WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
   std::unordered_map<WordType, WordIdType> word2id;
-  WordIdType word_id = kSpecialTokens::num_tokens;
   std::fstream handle(path, std::ios::in);
   CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open:" + path);
   std::string word;
@@ -71,40 +69,35 @@ Status Vocab::BuildFromFile(const std::string &path, const std::string &delimite
       // if delimiter is not found, find_first_of would return std::string::npos which is -1
       word = word.substr(0, word.find_first_of(delimiter));
     }
-    CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word);
+    CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word + ".");
+    CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(), word + " is already in special_tokens.");
     word2id[word] = word_id++;
     // break if enough row is read, if vocab_size is smaller than 0
-    if (word_id == vocab_size + kSpecialTokens::num_tokens) break;
+    if (word2id.size() == vocab_size) break;
   }
+
+  word_id = prepend_special ? 0 : word2id.size();
+
+  for (auto special_token : special_tokens) {
+    word2id[py::str(special_token)] = word_id++;
+  }
+
   *vocab = std::make_shared<Vocab>(std::move(word2id));
   return Status::OK();
 }

 Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) {
   std::unordered_map<WordType, WordIdType> word2id;
-  std::map<WordIdType, WordType> id2word;
   for (auto p : words) {
-    WordIdType word_id = py::reinterpret_borrow<py::int_>(p.second);
-    if (word_id < kSpecialTokens::num_tokens) continue; // skip id that are reserved
-    std::string word = py::str(p.first);
-    CHECK_FAIL_RETURN_UNEXPECTED(id2word.find(word_id) == id2word.end(), "duplicate id:" + word);
-    id2word[word_id] = word;
+    word2id[py::str(p.first)] = py::reinterpret_borrow<py::int_>(p.second);
   }
-
-  WordIdType cnt = kSpecialTokens::num_tokens;
-  for (auto p : id2word) {
-    CHECK_FAIL_RETURN_UNEXPECTED(p.first == cnt++, "word id needs to be continuous starting from 2");
-    word2id[p.second] = p.first;
-  }
-
   *vocab = std::make_shared<Vocab>(std::move(word2id));
   return Status::OK();
 }
-const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"};

 void Vocab::append_word(const std::string &word) {
   if (word2id_.find(word) == word2id_.end()) {
-    word2id_[word] = word2id_.size() + kSpecialTokens::num_tokens;
+    word2id_[word] = word2id_.size();
   }
 }
 } // namespace dataset

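To make the id-assignment rule above concrete: with prepend_special the special tokens occupy ids 0..n_special-1 and ordinary words continue from n_special; otherwise ordinary words occupy 0..n_words-1 and the special tokens follow. A small stand-alone Python illustration of that arithmetic (not the library API, just a restatement of what BuildFromPyList does):

def assign_ids(words, special_tokens, prepend_special):
    """Illustrative only: reproduce the id layout built by Vocab::BuildFromPyList."""
    word2id = {}
    # ordinary words start after the specials when the specials are prepended
    word_id = len(special_tokens) if prepend_special else 0
    for w in words:
        word2id[w] = word_id
        word_id += 1
    # specials take ids 0..n-1 when prepended, or follow the words otherwise
    word_id = 0 if prepend_special else len(words)
    for s in special_tokens:
        word2id[s] = word_id
        word_id += 1
    return word2id

# prepended:  {'w1': 2, 'w2': 3, 'w3': 4, 's1': 0, 's2': 1}
print(assign_ids(["w1", "w2", "w3"], ["s1", "s2"], True))
# appended:   {'w1': 0, 'w2': 1, 'w3': 2, 's1': 3, 's2': 4}
print(assign_ids(["w1", "w2", "w3"], ["s1", "s2"], False))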
@@ -45,7 +45,8 @@ class Vocab {
   // @param const py::list &words - a list of string, used to build vocab, id starts from 2
   // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
   // @return error code
-  static Status BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab);
+  static Status BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special,
+                                std::shared_ptr<Vocab> *vocab);

   // Build a vocab from reading a vocab file, id are automatically assigned, start from 2
   // @param std::string &path - path to vocab file , each line is assumed to contain 1 word
@@ -54,7 +55,7 @@ class Vocab {
   // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
   // @return error code
   static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
-                              std::shared_ptr<Vocab> *vocab);
+                              const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab);

   // Lookup the id of a word, if word doesn't exist in vocab, return default_id
   // @param const WordType word - word to look up
@@ -80,15 +81,8 @@ class Vocab {
   // destructor
   ~Vocab() = default;

-  // enum type that holds all special tokens, add more if needed
-  enum kSpecialTokens : WordIdType { pad = 0, unk = 1, num_tokens = 2 };
-
-  // reversed lookup table for the reserved tokens
-  static const std::vector<WordType> reserved_token_str_;
-
  private:
   std::unordered_map<WordType, WordIdType> word2id_;
-  std::vector<WordType> id2word_; // reverse lookup
 };

 } // namespace dataset

@@ -894,9 +894,9 @@ class Dataset:

         return ProjectDataset(self, columns)

-    def build_vocab(self, vocab, columns, freq_range, top_k):
+    def build_vocab(self, vocab, columns, freq_range, top_k, special_tokens, special_first):
         """ Internal function for building a vocab"""
-        return BuildVocabDataset(self, vocab, columns, freq_range, top_k)
+        return BuildVocabDataset(self, vocab, columns, freq_range, top_k, special_tokens, special_first)

     def apply(self, apply_func):
         """
@@ -4869,9 +4869,15 @@ class BuildVocabDataset(DatasetOp):
         top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are
             taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken (default=None,
             all words are included).
+        special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
+            (default=None, no special tokens will be added).
+        special_first(bool): whether special_tokens will be prepended/appended to vocab, If special_tokens is
+            specified and special_first is set to None, special_tokens will be prepended. (default=None).
         prefetch_size (int, optional): prefetch number of records ahead of the user's request (default=None).
     """

-    def __init__(self, input_dataset, vocab, columns, freq_range, top_k, prefetch_size=None):
+    def __init__(self, input_dataset, vocab, columns, freq_range, top_k, special_tokens, special_first,
+                 prefetch_size=None):
         super().__init__()
         self.columns = columns
         self.input.append(input_dataset)
@@ -4879,6 +4885,8 @@ class BuildVocabDataset(DatasetOp):
         self.vocab = vocab
         self.freq_range = freq_range
         self.top_k = top_k
+        self.special_tokens = special_tokens
+        self.special_first = special_first
         input_dataset.output.append(self)

     def get_args(self):
@@ -4888,6 +4896,8 @@ class BuildVocabDataset(DatasetOp):
         args["freq_range"] = self.freq_range
         args["prefetch_size"] = self.prefetch_size
         args["top_k"] = self.top_k
+        args["special_tokens"] = self.special_tokens
+        args["special_first"] = self.special_first
         return args

     def __deepcopy__(self, memodict):
@@ -4904,4 +4914,7 @@ class BuildVocabDataset(DatasetOp):
         new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
         new_op.top_k = copy.deepcopy(self.top_k, memodict)
         new_op.vocab = self.vocab
+        new_op.special_tokens = copy.deepcopy(self.special_tokens)
+        new_op.special_first = copy.deepcopy(self.special_first)
+
         return new_op

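The new BuildVocabDataset arguments are normally reached through Vocab.from_dataset rather than constructed directly; a short sketch matching the updated tests (words.txt is the test dataset used there):

import mindspore.dataset as ds
import mindspore.dataset.text as text

data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
# collect the unique words of the "text" column, then prepend "<pad>" and "<unk>"
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None,
                                special_tokens=["<pad>", "<unk>"], special_first=True)
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))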
@@ -28,10 +28,12 @@ from .validators import check_lookup, check_jieba_add_dict, \

 class Lookup(cde.LookupOp):
     """
-    Lookup operator that looks up a word to an id
+    Lookup operator that looks up a word to an id.
     Args:
         vocab(Vocab): a Vocab object.
-        unknown(int): default id to lookup a word that is out of vocab (default is None).
+        unknown(int, optional): default id to lookup a word that is out of vocab. If no argument is passed, 1 will be
+            used to be the default id which is the convention for unknown_token <unk>. Otherwise, user is strongly
+            encouraged to pass in the id for <unk> (default=None).
     """

     @check_lookup

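As the docstring notes, the fallback id defaults to 1 (the <unk> convention) but can be set explicitly. A one-line sketch from the from_dict tutorial test further down, passing the id chosen for "<unk>":

import mindspore.dataset.text as text

vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
# any out-of-vocabulary word is mapped to id 6, the id assigned to "<unk>"
lookup = text.Lookup(vocab, 6)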
@@ -25,12 +25,13 @@ from .validators import check_from_file, check_from_list, check_from_dict, check

 class Vocab(cde.Vocab):
     """
-    Vocab object that is used for lookup word.
+    Vocab object that is used to lookup a word. It contains a map that maps each word(str) to an id (int)
     """

     @classmethod
     @check_from_dataset
-    def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None):
+    def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
+                     special_first=None):
         """
         Build a vocab from a dataset. This would collect all unique words in a dataset and return a vocab within
         the frequency range specified by user in freq_range. User would be warned if no words fall into the frequency.
@@ -49,11 +50,16 @@ class Vocab(cde.Vocab):
             top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are
                 taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. (default=None
                 all words are included).
+            special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
+                (default=None, no special tokens will be added).
+            special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens
+                is specified and special_first is set to None, special_tokens will be prepended. (default=None).
         return:
             text.Vocab: Vocab object built from dataset.
         """

         vocab = Vocab()
-        root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k)
+        root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
         for d in root.create_dict_iterator():
             if d is not None:
                 raise ValueError("from_dataset should receive data other than None.")
@@ -61,17 +67,21 @@ class Vocab(cde.Vocab):

     @classmethod
     @check_from_list
-    def from_list(cls, word_list):
+    def from_list(cls, word_list, special_tokens=None, special_first=None):
         """
         build a vocab object from a list of word.
         Args:
-            word_list(list): a list of string where each element is a word.
+            word_list(list): a list of string where each element is a word of type string.
+            special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
+                (default=None, no special tokens will be added).
+            special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
+                is specified and special_first is set to None, special_tokens will be prepended. (default=None).
         """
-        return super().from_list(word_list)
+        return super().from_list(word_list, special_tokens, special_first)

     @classmethod
     @check_from_file
-    def from_file(cls, file_path, delimiter=None, vocab_size=None):
+    def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None):
         """
         build a vocab object from a list of word.
         Args:
@@ -79,8 +89,12 @@ class Vocab(cde.Vocab):
             delimiter(str, optional): a delimiter to break up each line in file, the first element is taken to be
                 the word (default=None).
             vocab_size(int, optional): number of words to read from file_path (default=None, all words are taken).
+            special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
+                (default=None, no special tokens will be added).
+            special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
+                is specified and special_first is set to None, special_tokens will be prepended. (default=None).
         """
-        return super().from_file(file_path, delimiter, vocab_size)
+        return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first)

     @classmethod
     @check_from_dict
@@ -88,7 +102,8 @@ class Vocab(cde.Vocab):
         """
         build a vocab object from a dict.
         Args:
-            word_dict(dict): dict contains word, id pairs. id should start from 2 and be continuous.
+            word_dict(dict): dict contains word, id pairs where word should be str and id int. id is recommended to
+                start from 0 and be continuous. ValueError will be raised if id is negative.
         """
         return super().from_dict(word_dict)

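A sketch of from_file with the new keyword arguments, mirroring the updated vocab tests (vocab_list.txt is the comma-delimited test vocabulary they read):

import mindspore.dataset.text as text

# the first comma-delimited field of each line is the word; "<pad>" and "<unk>"
# are prepended, so the file's words are numbered from 2 onward
vocab = text.Vocab.from_file("../data/dataset/testVocab/vocab_list.txt", ",",
                             None, ["<pad>", "<unk>"], True)
lookup = text.Lookup(vocab)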
@@ -23,6 +23,21 @@ import mindspore._c_dataengine as cde
 from ..transforms.validators import check_uint32, check_pos_int64

+
+def check_unique_list_of_words(words, arg_name):
+    """Check that words is a list and each element is a str without any duplication"""
+
+    if not isinstance(words, list):
+        raise ValueError(arg_name + " needs to be a list of words of type string.")
+    words_set = set()
+    for word in words:
+        if not isinstance(word, str):
+            raise ValueError("each word in " + arg_name + " needs to be type str.")
+        if word in words_set:
+            raise ValueError(arg_name + " contains duplicate word: " + word + ".")
+        words_set.add(word)
+    return words_set
+

 def check_lookup(method):
     """A wrapper that wrap a parameter checker to the original function."""

@@ -52,13 +67,17 @@ def check_from_file(method):

     @wraps(method)
     def new_method(self, *args, **kwargs):
-        file_path, delimiter, vocab_size = (list(args) + 3 * [None])[:3]
+        file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5]
         if "file_path" in kwargs:
             file_path = kwargs.get("file_path")
         if "delimiter" in kwargs:
             delimiter = kwargs.get("delimiter")
         if "vocab_size" in kwargs:
             vocab_size = kwargs.get("vocab_size")
+        if "special_tokens" in kwargs:
+            special_tokens = kwargs.get("special_tokens")
+        if "special_first" in kwargs:
+            special_first = kwargs.get("special_first")

         if not isinstance(file_path, str):
             raise ValueError("file_path needs to be str.")
@@ -73,9 +92,24 @@ def check_from_file(method):
                 raise ValueError("vocab size needs to be a positive integer.")
         else:
             vocab_size = -1
+
+        if special_first is None:
+            special_first = True
+
+        if not isinstance(special_first, bool):
+            raise ValueError("special_first needs to be a boolean value")
+
+        if special_tokens is None:
+            special_tokens = []
+
+        check_unique_list_of_words(special_tokens, "special_tokens")
+
         kwargs["file_path"] = file_path
         kwargs["delimiter"] = delimiter
         kwargs["vocab_size"] = vocab_size
+        kwargs["special_tokens"] = special_tokens
+        kwargs["special_first"] = special_first

         return method(self, **kwargs)

     return new_method
@@ -86,16 +120,32 @@ def check_from_list(method):

     @wraps(method)
     def new_method(self, *args, **kwargs):
-        word_list, = (list(args) + [None])[:1]
+        word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3]
         if "word_list" in kwargs:
             word_list = kwargs.get("word_list")
-        if not isinstance(word_list, list):
-            raise ValueError("word_list needs to be a list of words.")
-        for word in word_list:
-            if not isinstance(word, str):
-                raise ValueError("each word in word list needs to be type str.")
+        if "special_tokens" in kwargs:
+            special_tokens = kwargs.get("special_tokens")
+        if "special_first" in kwargs:
+            special_first = kwargs.get("special_first")
+        if special_tokens is None:
+            special_tokens = []
+        word_set = check_unique_list_of_words(word_list, "word_list")
+        token_set = check_unique_list_of_words(special_tokens, "special_tokens")
+
+        intersect = word_set.intersection(token_set)
+
+        if intersect != set():
+            raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
+
+        if special_first is None:
+            special_first = True
+
+        if not isinstance(special_first, bool):
+            raise ValueError("special_first needs to be a boolean value.")
+
         kwargs["word_list"] = word_list
+        kwargs["special_tokens"] = special_tokens
+        kwargs["special_first"] = special_first
         return method(self, **kwargs)

     return new_method
@@ -113,9 +163,9 @@ def check_from_dict(method):
             raise ValueError("word_dict needs to be a list of word,id pairs.")
         for word, word_id in word_dict.items():
             if not isinstance(word, str):
-                raise ValueError("each word in word_dict needs to be type str.")
+                raise ValueError("Each word in word_dict needs to be type string.")
             if not (isinstance(word_id, int) and word_id >= 0):
-                raise ValueError("each word id needs to be positive integer.")
+                raise ValueError("Each word id needs to be positive integer.")
         kwargs["word_dict"] = word_dict
         return method(self, **kwargs)

@@ -135,11 +185,11 @@ def check_jieba_init(method):
         mp_path = kwargs.get("mp_path")
         if hmm_path is None:
             raise ValueError(
-                "the dict of HMMSegment in cppjieba is not provided.")
+                "The dict of HMMSegment in cppjieba is not provided.")
         kwargs["hmm_path"] = hmm_path
         if mp_path is None:
             raise ValueError(
-                "the dict of MPSegment in cppjieba is not provided.")
+                "The dict of MPSegment in cppjieba is not provided.")
         kwargs["mp_path"] = mp_path
         if model is not None:
             kwargs["model"] = model
@@ -171,7 +221,7 @@ def check_jieba_add_word(method):


 def check_jieba_add_dict(method):
-    """Wrapper method to check the parameters of add dict"""
+    """Wrapper method to check the parameters of add dict."""

     @wraps(method)
     def new_method(self, *args, **kwargs):
@@ -189,10 +239,10 @@ def check_jieba_add_dict(method):
 def check_from_dataset(method):
     """A wrapper that wrap a parameter checker to the original function."""

-    # def from_dataset(cls, dataset, columns, freq_range=None, top_k=None):
     @wraps(method)
     def new_method(self, *args, **kwargs):
-        dataset, columns, freq_range, top_k = (list(args) + 4 * [None])[:4]
+
+        dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6]
         if "dataset" in kwargs:
             dataset = kwargs.get("dataset")
         if "columns" in kwargs:
@@ -201,6 +251,10 @@ def check_from_dataset(method):
             freq_range = kwargs.get("freq_range")
         if "top_k" in kwargs:
             top_k = kwargs.get("top_k")
+        if "special_tokens" in kwargs:
+            special_tokens = kwargs.get("special_tokens")
+        if "special_first" in kwargs:
+            special_first = kwargs.get("special_first")

         if columns is None:
             columns = []
@@ -232,10 +286,23 @@ def check_from_dataset(method):
         if isinstance(top_k, int) and top_k <= 0:
             raise ValueError("top_k needs to be a positive integer.")

+        if special_first is None:
+            special_first = True
+
+        if special_tokens is None:
+            special_tokens = []
+
+        if not isinstance(special_first, bool):
+            raise ValueError("special_first needs to be a boolean value.")
+
+        check_unique_list_of_words(special_tokens, "special_tokens")
+
         kwargs["dataset"] = dataset
         kwargs["columns"] = columns
         kwargs["freq_range"] = freq_range
         kwargs["top_k"] = top_k
+        kwargs["special_tokens"] = special_tokens
+        kwargs["special_first"] = special_first

         return method(self, **kwargs)

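The practical effect of these validators is that malformed special_tokens are rejected before any C++ op is built. A hedged sketch of the expected failure mode, following the new unit tests:

import mindspore.dataset.text as text

try:
    # "s1" appears twice, so check_from_list raises before the vocab is created
    text.Vocab.from_list(["w1", "w2"], ["s1", "s1"], True)
except ValueError as err:
    print(err)  # e.g. "special_tokens contains duplicate word: s1."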
@@ -0,0 +1,8 @@
+w1
+w2
+w3
+w4
+w5
+w6
+w7
+w8

@@ -23,19 +23,21 @@ import mindspore.dataset.text as text
 def test_demo_basic_from_dataset():
     """ this is a tutorial on how from_dataset should be used in a normal use case"""
     data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
-    vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None)
+    vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
+                                    special_first=True)
     data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
     res = []
     for d in data.create_dict_iterator():
         res.append(d["text"].item())
-    assert res == [4, 5, 3, 6, 7, 2]
+    assert res == [4, 5, 3, 6, 7, 2], res


 def test_demo_basic_from_dataset_with_tokenizer():
     """ this is a tutorial on how from_dataset should be used in a normal use case with tokenizer"""
     data = ds.TextFileDataset("../data/dataset/testTokenizerData/1.txt", shuffle=False)
     data = data.map(input_columns=["text"], operations=text.UnicodeCharTokenizer())
-    vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None)
+    vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
+                                    special_first=True)
     data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
     res = []
     for d in data.create_dict_iterator():
@@ -55,7 +57,8 @@ def test_from_dataset():

     def test_config(freq_range, top_k):
         corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"])
-        vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k)
+        vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k, special_tokens=["<pad>", "<unk>"],
+                                        special_first=True)
         corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab))
         res = []
         for d in corpus_dataset.create_dict_iterator():
@@ -87,6 +90,35 @@ def test_from_dataset():
     assert test6_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [1, 1, 1], [1, 1], [1]], str(test6_res)


+def test_from_dataset_special_token():
+    """ test build vocab with generator dataset """
+
+    def gen_corpus():
+        # key: word, value: number of occurrences, reason for using letters is so their order is apparent
+        corpus = {"D": 1, "C": 1, "B": 1, "A": 1}
+        for k, v in corpus.items():
+            yield (np.array([k] * v, dtype='S'),)
+
+    def gen_input(texts):
+        for word in texts.split(" "):
+            yield (np.array(word, dtype='S'),)
+
+    def test_config(texts, top_k, special_tokens, special_first):
+        corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"])
+        vocab = text.Vocab.from_dataset(corpus_dataset, None, None, top_k, special_tokens, special_first)
+        data = ds.GeneratorDataset(gen_input(texts), column_names=["text"])
+        data = data.map(input_columns="text", operations=text.Lookup(vocab))
+        res = []
+        for d in data.create_dict_iterator():
+            res.append(d["text"].item())
+        return res
+
+    # test special tokens are inserted before
+    assert test_config("A B C D <pad> <unk>", 4, ["<pad>", "<unk>"], True) == [2, 3, 4, 5, 0, 1]
+    # test special tokens are inserted after
+    assert test_config("A B C D <pad> <unk>", 4, ["<pad>", "<unk>"], False) == [0, 1, 2, 3, 4, 5]
+
+
 def test_from_dataset_exceptions():
     """ test various exceptions during that are checked in validator """

@@ -105,8 +137,10 @@ def test_from_dataset_exceptions():
     test_config("text", (2, 3), 0, "top_k needs to be a positive integer")
     test_config([123], (2, 3), 0, "columns need to be a list of strings")


 if __name__ == '__main__':
     test_demo_basic_from_dataset()
     test_from_dataset()
     test_from_dataset_exceptions()
+    test_demo_basic_from_dataset_with_tokenizer()
+    test_from_dataset_special_token()

@@ -33,7 +33,7 @@ def test_on_tokenized_line():
             word = line.split(',')[0]
             jieba_op.add_word(word)
     data = data.map(input_columns=["text"], operations=jieba_op)
-    vocab = text.Vocab.from_file(VOCAB_FILE, ",")
+    vocab = text.Vocab.from_file(VOCAB_FILE, ",", special_tokens=["<pad>", "<unk>"])
     lookup = text.Lookup(vocab)
     data = data.map(input_columns=["text"], operations=lookup)
     res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14],

@@ -1,13 +1,31 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+
 import mindspore.dataset as ds
 import mindspore.dataset.text as text

 # this file contains "home is behind the world head" each word is 1 line
 DATA_FILE = "../data/dataset/testVocab/words.txt"
 VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
+SIMPLE_VOCAB_FILE = "../data/dataset/testVocab/simple_vocab_list.txt"


-def test_from_list():
-    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "))
+def test_from_list_tutorial():
+    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", "<unk>"], True)
     lookup = text.Lookup(vocab)
     data = ds.TextFileDataset(DATA_FILE, shuffle=False)
     data = data.map(input_columns=["text"], operations=lookup)
@@ -18,8 +36,8 @@ def test_from_list():
         ind += 1


-def test_from_file():
-    vocab = text.Vocab.from_file(VOCAB_FILE, ",")
+def test_from_file_tutorial():
+    vocab = text.Vocab.from_file(VOCAB_FILE, ",", None, ["<pad>", "<unk>"], True)
     lookup = text.Lookup(vocab)
     data = ds.TextFileDataset(DATA_FILE, shuffle=False)
     data = data.map(input_columns=["text"], operations=lookup)
@@ -30,7 +48,7 @@ def test_from_file():
         ind += 1


-def test_from_dict():
+def test_from_dict_tutorial():
     vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
     lookup = text.Lookup(vocab, 6) # default value is -1
     data = ds.TextFileDataset(DATA_FILE, shuffle=False)
@@ -41,7 +59,61 @@ def test_from_dict():
         assert d["text"] == res[ind], ind
         ind += 1


+def test_from_list():
+    def gen(texts):
+        for word in texts.split(" "):
+            yield (np.array(word, dtype='S'),)
+
+    def test_config(lookup_str, vocab_input, special_tokens, special_first):
+        try:
+            vocab = text.Vocab.from_list(vocab_input, special_tokens, special_first)
+            data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
+            data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
+            res = []
+            for d in data.create_dict_iterator():
+                res.append(d["text"].item())
+            return res
+        except ValueError as e:
+            return str(e)
+
+    # test normal operations
+    assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], True) == [2, 3, 4, 0, 1]
+    assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False) == [0, 1, 2, 3, 4]
+    assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, True) == [2, 1, 0]
+    assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, False) == [2, 1, 0]
+
+    # test exceptions
+    assert "word_list contains duplicate" in test_config("w1", ["w1", "w1"], [], True)
+    assert "special_tokens contains duplicate" in test_config("w1", ["w1", "w2"], ["s1", "s1"], True)
+    assert "special_tokens and word_list contain duplicate" in test_config("w1", ["w1", "w2"], ["s1", "w1"], True)
+
+
+def test_from_file():
+    def gen(texts):
+        for word in texts.split(" "):
+            yield (np.array(word, dtype='S'),)
+
+    def test_config(lookup_str, special_tokens, special_first):
+        try:
+            vocab = text.Vocab.from_file(SIMPLE_VOCAB_FILE, special_tokens=special_tokens, special_first=special_first)
+            data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
+            data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
+            res = []
+            for d in data.create_dict_iterator():
+                res.append(d["text"].item())
+            return res
+        except ValueError as e:
+            return str(e)
+
+    assert test_config("w1 w2 w3", ["s1", "s2", "s3"], True) == [3, 4, 5]
+    assert test_config("w1 w2 w3", ["s1", "s2", "s3"], False) == [0, 1, 2]
+    assert "special_tokens contains duplicate" in test_config("w1", ["s1", "s1"], True)
+
+
 if __name__ == '__main__':
+    test_from_list_tutorial()
+    test_from_file_tutorial()
+    test_from_dict_tutorial()
     test_from_list()
     test_from_file()
-    test_from_dict()
