!2317 rework on dataset.text.vocab to support any user special_tokens

Merge pull request !2317 from ZiruiWu/vocab_rework
This commit is contained in:
mindspore-ci-bot 2020-06-19 05:25:31 +08:00 committed by Gitee
commit 3784220056
15 changed files with 354 additions and 113 deletions

View File

@ -1283,18 +1283,18 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<Datas
py::tuple tp = py::reinterpret_borrow<py::tuple>(value);
if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow<py::int_>(tp[0]));
if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow<py::int_>(tp[1]));
}
if (key == "top_k") {
} else if (key == "top_k") {
builder->SetTopK(py::reinterpret_borrow<py::int_>(value));
}
if (key == "columns") {
} else if (key == "columns") {
(void)builder->SetColumnNames(ToStringVector(value));
}
if (key == "vocab") {
} else if (key == "vocab") {
(void)builder->SetVocab(value.cast<std::shared_ptr<Vocab>>());
}
if (key == "num_parallel_workers") {
} else if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "special_first") {
(void)builder->SetSpecialFirst(ToBool(value));
} else if (key == "special_tokens") {
(void)builder->SetSpecialTokens(ToStringVector(value));
}
}
}

View File

@ -673,15 +673,16 @@ void bindVocabObjects(py::module *m) {
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
.def(py::init<>())
.def_static("from_list",
[](const py::list &words) {
[](const py::list &words, const py::list &special_tokens, bool special_first) {
std::shared_ptr<Vocab> v;
THROW_IF_ERROR(Vocab::BuildFromPyList(words, &v));
THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v));
return v;
})
.def_static("from_file",
[](const std::string &path, const std::string &dlm, int32_t vocab_size) {
[](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens,
bool special_first) {
std::shared_ptr<Vocab> v;
THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, &v));
THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v));
return v;
})
.def_static("from_dict", [](const py::dict &words) {

View File

@ -27,13 +27,16 @@ namespace mindspore {
namespace dataset {
BuildVocabOp::BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names,
std::pair<int64_t, int64_t> freq_r, int64_t top_k, int32_t num_workers, int32_t op_conn_size)
std::pair<int64_t, int64_t> freq_r, int64_t top_k, const std::vector<std::string> &tokens,
bool prepend, int32_t num_workers, int32_t op_conn_size)
: ParallelOp(num_workers, op_conn_size),
interval_(op_conn_size * num_workers),
vocab_(vocab),
col_names_(col_names),
freq_range_(freq_r),
top_k_(top_k) {
top_k_(top_k),
special_tokens_(tokens),
special_first_(prepend) {
// init two queues for thread sync
distributor_queue_ = std::make_unique<Queue<TensorRow>>(num_workers * op_conn_size);
collector_queue_ =
@ -129,7 +132,7 @@ Status BuildVocabOp::CollectorThread() {
} // all frequencies are obtained
CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty");
std::vector<std::string> words;
// make sure enough is reserved
// make sure enough is reserved, this will become a partially sorted list eventually
words.reserve(wrkr_map->size());
for (auto it = word_cnt_.begin(); it != word_cnt_.end();) {
@ -140,6 +143,15 @@ Status BuildVocabOp::CollectorThread() {
it = word_cnt_.erase(it);
}
}
std::string err_msg;
for (const std::string &sp_tk : special_tokens_) {
// if a special word exists in dataset, warn user about this
err_msg += (word_cnt_.find(sp_tk) != word_cnt_.end() ? sp_tk + "\t" : "");
}
CHECK_FAIL_RETURN_UNEXPECTED(err_msg.empty(), "These specials words are already in the dataset: " + err_msg + ".");
int64_t num_words = std::min(static_cast<int64_t>(words.size()), top_k_);
if (num_words == 0) {
MS_LOG(WARNING) << "No word falls in the frequency range: (" << freq_range_.first << "," << freq_range_.second
@ -152,9 +164,19 @@ Status BuildVocabOp::CollectorThread() {
int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2];
return f1 == f2 ? w1 < w2 : f1 > f2;
});
if (special_first_) {
for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk);
}
for (int64_t i = 0; i < num_words; i++) {
vocab_->append_word(words[i]);
}
if (!special_first_) {
for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk);
}
RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
// then use std::nth_element to partial sort
@ -166,16 +188,17 @@ Status BuildVocabOp::Builder::Build(std::shared_ptr<BuildVocabOp> *op) {
CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be positive number");
CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0,
"frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)");
(*op) = std::make_shared<BuildVocabOp>(builder_vocab_, builder_col_names_,
std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
builder_num_workers_, builder_connector_size_);
(*op) = std::make_shared<BuildVocabOp>(
builder_vocab_, builder_col_names_, std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
builder_speical_tokens_, builder_special_first_, builder_num_workers_, builder_connector_size_);
return Status::OK();
}
BuildVocabOp::Builder::Builder()
: builder_top_k_(std::numeric_limits<int64_t>::max()),
builder_min_freq_(0),
builder_max_freq_(std::numeric_limits<int64_t>::max()) {
builder_max_freq_(std::numeric_limits<int64_t>::max()),
builder_special_first_(true) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_connector_size_ = cfg->op_connector_size();

View File

@ -88,12 +88,26 @@ class BuildVocabOp : public ParallelOp {
return *this;
}
// set special tokens
// @param const std::vector<std::string> & col_names - name of columns to get words
// @return Builder & reference to builder class object
Builder &SetSpecialTokens(const std::vector<std::string> &tokens) {
builder_speical_tokens_ = tokens;
return *this;
}
// set vocab object
Builder &SetVocab(std::shared_ptr<Vocab> vocab) {
builder_vocab_ = vocab;
return *this;
}
// set special tokens first (or last)
Builder &SetSpecialFirst(bool prepend) {
builder_special_first_ = prepend;
return *this;
}
// The builder "build" method creates the final object.
// @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
// @return - The error code return
@ -104,13 +118,16 @@ class BuildVocabOp : public ParallelOp {
int32_t builder_connector_size_;
int64_t builder_min_freq_;
int64_t builder_max_freq_;
bool builder_special_first_;
std::vector<std::string> builder_col_names_;
std::vector<std::string> builder_speical_tokens_;
std::shared_ptr<Vocab> builder_vocab_;
int64_t builder_top_k_;
};
BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names, std::pair<int64_t, int64_t> freq_range,
int64_t top_k, int32_t num_workers, int32_t op_connector_size);
int64_t top_k, const std::vector<std::string> &tokens, bool prepend, int32_t num_workers,
int32_t op_connector_size);
~BuildVocabOp() = default;
@ -137,9 +154,11 @@ class BuildVocabOp : public ParallelOp {
private:
const int32_t interval_;
bool special_first_;
std::shared_ptr<Vocab> vocab_;
std::vector<std::string> col_names_;
std::vector<int32_t> col_ids_;
std::vector<std::string> special_tokens_;
// pair = {min_f, max_f}
// make sure that 0<= min_f < max_f <= int32_max in the builder
std::pair<int64_t, int64_t> freq_range_;

View File

@ -33,7 +33,7 @@ class LookupOp : public TensorOp {
// constructor for lookup, takes in a vocab object
// @param std::shared_ptr<Vocab> vocab -
// @param WordIdType default_id, id to lookup if a word is not in vocab
explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = Vocab::kSpecialTokens::unk);
explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = 1);
~LookupOp() = default;

View File

@ -14,7 +14,8 @@
* limitations under the License.
*/
#include <fstream>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <utility>
#include "dataset/text/vocab.h"
@ -28,41 +29,38 @@ WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const {
return itr == word2id_.end() ? default_id : itr->second;
}
WordType Vocab::Lookup(WordIdType id) {
// this operation is most likely only done with since reverse lookup is only needed when training is done
// hence, the worst case of inserting while keep looking up isn't likely to happen
if (id2word_.size() != word2id_.size() && (id - kSpecialTokens::num_tokens >= id2word_.size())) {
id2word_.clear();
id2word_.reserve(word2id_.size());
for (auto p : word2id_) {
id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
}
}
if (id < kSpecialTokens::num_tokens) {
return reserved_token_str_[id];
} else if (id - kSpecialTokens::num_tokens >= id2word_.size()) {
return reserved_token_str_[kSpecialTokens::unk];
} else {
return id2word_[id - kSpecialTokens::num_tokens];
}
}
Status Vocab::BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab) {
Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special,
std::shared_ptr<Vocab> *vocab) {
// check of duplication on both words and special_tokens will be performed in python
// special_tokens and words both need to be unique, and shouldn't overlap
std::unordered_map<WordType, WordIdType> word2id;
WordIdType word_id = kSpecialTokens::num_tokens;
// if special is added in front, normal words id will start from number of special tokens
WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
for (auto word : words) {
const std::string s = py::str(word);
CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(s) == word2id.end(), "duplicate word:" + s);
word2id[s] = word_id++;
word2id[py::str(word)] = word_id++;
}
word_id = prepend_special ? 0 : word2id.size();
for (auto special_token : special_tokens) {
word2id[py::str(special_token)] = word_id++;
}
*vocab = std::make_shared<Vocab>(std::move(word2id));
return Status::OK();
}
Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
std::shared_ptr<Vocab> *vocab) {
const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab) {
// python validator checks special_tokens doesn't contain any duplicate words
std::unordered_set<std::string> specials;
// used to check that words in file don't contain any special token that already exists
for (auto word : special_tokens) {
specials.insert(py::str(word));
}
WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
std::unordered_map<WordType, WordIdType> word2id;
WordIdType word_id = kSpecialTokens::num_tokens;
std::fstream handle(path, std::ios::in);
CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open:" + path);
std::string word;
@ -71,40 +69,35 @@ Status Vocab::BuildFromFile(const std::string &path, const std::string &delimite
// if delimiter is not found, find_first_of would return std::string::npos which is -1
word = word.substr(0, word.find_first_of(delimiter));
}
CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word);
CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word + ".");
CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(), word + " is already in special_tokens.");
word2id[word] = word_id++;
// break if enough row is read, if vocab_size is smaller than 0
if (word_id == vocab_size + kSpecialTokens::num_tokens) break;
if (word2id.size() == vocab_size) break;
}
word_id = prepend_special ? 0 : word2id.size();
for (auto special_token : special_tokens) {
word2id[py::str(special_token)] = word_id++;
}
*vocab = std::make_shared<Vocab>(std::move(word2id));
return Status::OK();
}
Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) {
std::unordered_map<WordType, WordIdType> word2id;
std::map<WordIdType, WordType> id2word;
for (auto p : words) {
WordIdType word_id = py::reinterpret_borrow<py::int_>(p.second);
if (word_id < kSpecialTokens::num_tokens) continue; // skip id that are reserved
std::string word = py::str(p.first);
CHECK_FAIL_RETURN_UNEXPECTED(id2word.find(word_id) == id2word.end(), "duplicate id:" + word);
id2word[word_id] = word;
word2id[py::str(p.first)] = py::reinterpret_borrow<py::int_>(p.second);
}
WordIdType cnt = kSpecialTokens::num_tokens;
for (auto p : id2word) {
CHECK_FAIL_RETURN_UNEXPECTED(p.first == cnt++, "word id needs to be continuous starting from 2");
word2id[p.second] = p.first;
}
*vocab = std::make_shared<Vocab>(std::move(word2id));
return Status::OK();
}
const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"};
void Vocab::append_word(const std::string &word) {
if (word2id_.find(word) == word2id_.end()) {
word2id_[word] = word2id_.size() + kSpecialTokens::num_tokens;
word2id_[word] = word2id_.size();
}
}
} // namespace dataset

View File

@ -45,7 +45,8 @@ class Vocab {
// @param const py::list &words - a list of string, used to build vocab, id starts from 2
// @param std::shared_ptr<Vocab> *vocab - return value, vocab object
// @return error code
static Status BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab);
static Status BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special,
std::shared_ptr<Vocab> *vocab);
// Build a vocab from reading a vocab file, id are automatically assigned, start from 2
// @param std::string &path - path to vocab file , each line is assumed to contain 1 word
@ -54,7 +55,7 @@ class Vocab {
// @param std::shared_ptr<Vocab> *vocab - return value, vocab object
// @return error code
static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
std::shared_ptr<Vocab> *vocab);
const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab);
// Lookup the id of a word, if word doesn't exist in vocab, return default_id
// @param const WordType word - word to look up
@ -80,15 +81,8 @@ class Vocab {
// destructor
~Vocab() = default;
// enum type that holds all special tokens, add more if needed
enum kSpecialTokens : WordIdType { pad = 0, unk = 1, num_tokens = 2 };
// reversed lookup table for the reserved tokens
static const std::vector<WordType> reserved_token_str_;
private:
std::unordered_map<WordType, WordIdType> word2id_;
std::vector<WordType> id2word_; // reverse lookup
};
} // namespace dataset

View File

@ -894,9 +894,9 @@ class Dataset:
return ProjectDataset(self, columns)
def build_vocab(self, vocab, columns, freq_range, top_k):
def build_vocab(self, vocab, columns, freq_range, top_k, special_tokens, special_first):
""" Internal function for building a vocab"""
return BuildVocabDataset(self, vocab, columns, freq_range, top_k)
return BuildVocabDataset(self, vocab, columns, freq_range, top_k, special_tokens, special_first)
def apply(self, apply_func):
"""
@ -4869,9 +4869,15 @@ class BuildVocabDataset(DatasetOp):
top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are
taken. The top_k is taken after freq_range. If not enough top_k, all words will be taken (default=None,
all words are included).
special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
(default=None, no special tokens will be added).
special_first(bool): whether special_tokens will be prepended/appended to vocab, If special_tokens is
specified and special_first is set to None, special_tokens will be prepended. (default=None).
prefetch_size (int, optional): prefetch number of records ahead of the user's request (default=None).
"""
def __init__(self, input_dataset, vocab, columns, freq_range, top_k, prefetch_size=None):
def __init__(self, input_dataset, vocab, columns, freq_range, top_k, special_tokens, special_first,
prefetch_size=None):
super().__init__()
self.columns = columns
self.input.append(input_dataset)
@ -4879,6 +4885,8 @@ class BuildVocabDataset(DatasetOp):
self.vocab = vocab
self.freq_range = freq_range
self.top_k = top_k
self.special_tokens = special_tokens
self.special_first = special_first
input_dataset.output.append(self)
def get_args(self):
@ -4888,6 +4896,8 @@ class BuildVocabDataset(DatasetOp):
args["freq_range"] = self.freq_range
args["prefetch_size"] = self.prefetch_size
args["top_k"] = self.top_k
args["special_tokens"] = self.special_tokens
args["special_first"] = self.special_first
return args
def __deepcopy__(self, memodict):
@ -4904,4 +4914,7 @@ class BuildVocabDataset(DatasetOp):
new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
new_op.top_k = copy.deepcopy(self.top_k, memodict)
new_op.vocab = self.vocab
new_op.special_tokens = copy.deepcopy(self.special_tokens)
new_op.special_first = copy.deepcopy(self.special_first)
return new_op

View File

@ -28,10 +28,12 @@ from .validators import check_lookup, check_jieba_add_dict, \
class Lookup(cde.LookupOp):
"""
Lookup operator that looks up a word to an id
Lookup operator that looks up a word to an id.
Args:
vocab(Vocab): a Vocab object.
unknown(int): default id to lookup a word that is out of vocab (default is None).
unknown(int, optional): default id to lookup a word that is out of vocab. If no argument is passed, 1 will be
used to be the default id which is the convention for unknown_token <unk>. Otherwise, user is strongly
encouraged to pass in the id for <unk> (default=None).
"""
@check_lookup

View File

@ -25,12 +25,13 @@ from .validators import check_from_file, check_from_list, check_from_dict, check
class Vocab(cde.Vocab):
"""
Vocab object that is used for lookup word.
Vocab object that is used to lookup a word. It contains a map that maps each word(str) to an id (int)
"""
@classmethod
@check_from_dataset
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None):
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
special_first=None):
"""
Build a vocab from a dataset. This would collect all unique words in a dataset and return a vocab within
the frequency range specified by user in freq_range. User would be warned if no words fall into the frequency.
@ -49,11 +50,16 @@ class Vocab(cde.Vocab):
top_k(int, optional): top_k > 0. Number of words to be built into vocab. top_k most frequent words are
taken. top_k is taken after freq_range. If not enough top_k, all words will be taken. (default=None
all words are included).
special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
(default=None, no special tokens will be added).
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens
is specified and special_first is set to None, special_tokens will be prepended. (default=None).
return:
text.Vocab: Vocab object built from dataset.
"""
vocab = Vocab()
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k)
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
for d in root.create_dict_iterator():
if d is not None:
raise ValueError("from_dataset should receive data other than None.")
@ -61,17 +67,21 @@ class Vocab(cde.Vocab):
@classmethod
@check_from_list
def from_list(cls, word_list):
def from_list(cls, word_list, special_tokens=None, special_first=None):
"""
build a vocab object from a list of word.
Args:
word_list(list): a list of string where each element is a word.
word_list(list): a list of string where each element is a word of type string.
special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
(default=None, no special tokens will be added).
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
is specified and special_first is set to None, special_tokens will be prepended. (default=None).
"""
return super().from_list(word_list)
return super().from_list(word_list, special_tokens, special_first)
@classmethod
@check_from_file
def from_file(cls, file_path, delimiter=None, vocab_size=None):
def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None):
"""
build a vocab object from a list of word.
Args:
@ -79,8 +89,12 @@ class Vocab(cde.Vocab):
delimiter(str, optional): a delimiter to break up each line in file, the first element is taken to be
the word (default=None).
vocab_size(int, optional): number of words to read from file_path (default=None, all words are taken).
special_tokens(list): a list of strings, each one is a special token. for e.g. ["<pad>","<unk>"]
(default=None, no special tokens will be added).
special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
is specified and special_first is set to None, special_tokens will be prepended. (default=None).
"""
return super().from_file(file_path, delimiter, vocab_size)
return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first)
@classmethod
@check_from_dict
@ -88,7 +102,8 @@ class Vocab(cde.Vocab):
"""
build a vocab object from a dict.
Args:
word_dict(dict): dict contains word, id pairs. id should start from 2 and be continuous.
word_dict(dict): dict contains word, id pairs where word should be str and id int. id is recommended to
start from 0 and be continuous. ValueError will be raised if id is negative.
"""
return super().from_dict(word_dict)

View File

@ -23,6 +23,21 @@ import mindspore._c_dataengine as cde
from ..transforms.validators import check_uint32, check_pos_int64
def check_unique_list_of_words(words, arg_name):
"""Check that words is a list and each element is a str without any duplication"""
if not isinstance(words, list):
raise ValueError(arg_name + " needs to be a list of words of type string.")
words_set = set()
for word in words:
if not isinstance(word, str):
raise ValueError("each word in " + arg_name + " needs to be type str.")
if word in words_set:
raise ValueError(arg_name + " contains duplicate word: " + word + ".")
words_set.add(word)
return words_set
def check_lookup(method):
"""A wrapper that wrap a parameter checker to the original function."""
@ -52,13 +67,17 @@ def check_from_file(method):
@wraps(method)
def new_method(self, *args, **kwargs):
file_path, delimiter, vocab_size = (list(args) + 3 * [None])[:3]
file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5]
if "file_path" in kwargs:
file_path = kwargs.get("file_path")
if "delimiter" in kwargs:
delimiter = kwargs.get("delimiter")
if "vocab_size" in kwargs:
vocab_size = kwargs.get("vocab_size")
if "special_tokens" in kwargs:
special_tokens = kwargs.get("special_tokens")
if "special_first" in kwargs:
special_first = kwargs.get("special_first")
if not isinstance(file_path, str):
raise ValueError("file_path needs to be str.")
@ -73,9 +92,24 @@ def check_from_file(method):
raise ValueError("vocab size needs to be a positive integer.")
else:
vocab_size = -1
if special_first is None:
special_first = True
if not isinstance(special_first, bool):
raise ValueError("special_first needs to be a boolean value")
if special_tokens is None:
special_tokens = []
check_unique_list_of_words(special_tokens, "special_tokens")
kwargs["file_path"] = file_path
kwargs["delimiter"] = delimiter
kwargs["vocab_size"] = vocab_size
kwargs["special_tokens"] = special_tokens
kwargs["special_first"] = special_first
return method(self, **kwargs)
return new_method
@ -86,16 +120,32 @@ def check_from_list(method):
@wraps(method)
def new_method(self, *args, **kwargs):
word_list, = (list(args) + [None])[:1]
word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3]
if "word_list" in kwargs:
word_list = kwargs.get("word_list")
if not isinstance(word_list, list):
raise ValueError("word_list needs to be a list of words.")
for word in word_list:
if not isinstance(word, str):
raise ValueError("each word in word list needs to be type str.")
if "special_tokens" in kwargs:
special_tokens = kwargs.get("special_tokens")
if "special_first" in kwargs:
special_first = kwargs.get("special_first")
if special_tokens is None:
special_tokens = []
word_set = check_unique_list_of_words(word_list, "word_list")
token_set = check_unique_list_of_words(special_tokens, "special_tokens")
intersect = word_set.intersection(token_set)
if intersect != set():
raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
if special_first is None:
special_first = True
if not isinstance(special_first, bool):
raise ValueError("special_first needs to be a boolean value.")
kwargs["word_list"] = word_list
kwargs["special_tokens"] = special_tokens
kwargs["special_first"] = special_first
return method(self, **kwargs)
return new_method
@ -113,9 +163,9 @@ def check_from_dict(method):
raise ValueError("word_dict needs to be a list of word,id pairs.")
for word, word_id in word_dict.items():
if not isinstance(word, str):
raise ValueError("each word in word_dict needs to be type str.")
raise ValueError("Each word in word_dict needs to be type string.")
if not (isinstance(word_id, int) and word_id >= 0):
raise ValueError("each word id needs to be positive integer.")
raise ValueError("Each word id needs to be positive integer.")
kwargs["word_dict"] = word_dict
return method(self, **kwargs)
@ -135,11 +185,11 @@ def check_jieba_init(method):
mp_path = kwargs.get("mp_path")
if hmm_path is None:
raise ValueError(
"the dict of HMMSegment in cppjieba is not provided.")
"The dict of HMMSegment in cppjieba is not provided.")
kwargs["hmm_path"] = hmm_path
if mp_path is None:
raise ValueError(
"the dict of MPSegment in cppjieba is not provided.")
"The dict of MPSegment in cppjieba is not provided.")
kwargs["mp_path"] = mp_path
if model is not None:
kwargs["model"] = model
@ -171,7 +221,7 @@ def check_jieba_add_word(method):
def check_jieba_add_dict(method):
"""Wrapper method to check the parameters of add dict"""
"""Wrapper method to check the parameters of add dict."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -189,10 +239,10 @@ def check_jieba_add_dict(method):
def check_from_dataset(method):
"""A wrapper that wrap a parameter checker to the original function."""
# def from_dataset(cls, dataset, columns, freq_range=None, top_k=None):
@wraps(method)
def new_method(self, *args, **kwargs):
dataset, columns, freq_range, top_k = (list(args) + 4 * [None])[:4]
dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6]
if "dataset" in kwargs:
dataset = kwargs.get("dataset")
if "columns" in kwargs:
@ -201,6 +251,10 @@ def check_from_dataset(method):
freq_range = kwargs.get("freq_range")
if "top_k" in kwargs:
top_k = kwargs.get("top_k")
if "special_tokens" in kwargs:
special_tokens = kwargs.get("special_tokens")
if "special_first" in kwargs:
special_first = kwargs.get("special_first")
if columns is None:
columns = []
@ -232,10 +286,23 @@ def check_from_dataset(method):
if isinstance(top_k, int) and top_k <= 0:
raise ValueError("top_k needs to be a positive integer.")
if special_first is None:
special_first = True
if special_tokens is None:
special_tokens = []
if not isinstance(special_first, bool):
raise ValueError("special_first needs to be a boolean value.")
check_unique_list_of_words(special_tokens, "special_tokens")
kwargs["dataset"] = dataset
kwargs["columns"] = columns
kwargs["freq_range"] = freq_range
kwargs["top_k"] = top_k
kwargs["special_tokens"] = special_tokens
kwargs["special_first"] = special_first
return method(self, **kwargs)

View File

@ -0,0 +1,8 @@
w1
w2
w3
w4
w5
w6
w7
w8

View File

@ -23,19 +23,21 @@ import mindspore.dataset.text as text
def test_demo_basic_from_dataset():
""" this is a tutorial on how from_dataset should be used in a normal use case"""
data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None)
vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
special_first=True)
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
res = []
for d in data.create_dict_iterator():
res.append(d["text"].item())
assert res == [4, 5, 3, 6, 7, 2]
assert res == [4, 5, 3, 6, 7, 2], res
def test_demo_basic_from_dataset_with_tokenizer():
""" this is a tutorial on how from_dataset should be used in a normal use case with tokenizer"""
data = ds.TextFileDataset("../data/dataset/testTokenizerData/1.txt", shuffle=False)
data = data.map(input_columns=["text"], operations=text.UnicodeCharTokenizer())
vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None)
vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"],
special_first=True)
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
res = []
for d in data.create_dict_iterator():
@ -55,7 +57,8 @@ def test_from_dataset():
def test_config(freq_range, top_k):
corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"])
vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k)
vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k, special_tokens=["<pad>", "<unk>"],
special_first=True)
corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab))
res = []
for d in corpus_dataset.create_dict_iterator():
@ -87,6 +90,35 @@ def test_from_dataset():
assert test6_res == [[4, 4, 4, 4], [3, 3, 3, 3], [2, 2, 2, 2], [1, 1, 1], [1, 1, 1], [1, 1], [1]], str(test6_res)
def test_from_dataset_special_token():
""" test build vocab with generator dataset """
def gen_corpus():
# key: word, value: number of occurrences, reason for using letters is so their order is apparent
corpus = {"D": 1, "C": 1, "B": 1, "A": 1}
for k, v in corpus.items():
yield (np.array([k] * v, dtype='S'),)
def gen_input(texts):
for word in texts.split(" "):
yield (np.array(word, dtype='S'),)
def test_config(texts, top_k, special_tokens, special_first):
corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"])
vocab = text.Vocab.from_dataset(corpus_dataset, None, None, top_k, special_tokens, special_first)
data = ds.GeneratorDataset(gen_input(texts), column_names=["text"])
data = data.map(input_columns="text", operations=text.Lookup(vocab))
res = []
for d in data.create_dict_iterator():
res.append(d["text"].item())
return res
# test special tokens are inserted before
assert test_config("A B C D <pad> <unk>", 4, ["<pad>", "<unk>"], True) == [2, 3, 4, 5, 0, 1]
# test special tokens are inserted after
assert test_config("A B C D <pad> <unk>", 4, ["<pad>", "<unk>"], False) == [0, 1, 2, 3, 4, 5]
def test_from_dataset_exceptions():
""" test various exceptions during that are checked in validator """
@ -105,8 +137,10 @@ def test_from_dataset_exceptions():
test_config("text", (2, 3), 0, "top_k needs to be a positive integer")
test_config([123], (2, 3), 0, "columns need to be a list of strings")
if __name__ == '__main__':
test_demo_basic_from_dataset()
test_from_dataset()
test_from_dataset_exceptions()
test_demo_basic_from_dataset_with_tokenizer()
test_from_dataset_special_token()

View File

@ -33,7 +33,7 @@ def test_on_tokenized_line():
word = line.split(',')[0]
jieba_op.add_word(word)
data = data.map(input_columns=["text"], operations=jieba_op)
vocab = text.Vocab.from_file(VOCAB_FILE, ",")
vocab = text.Vocab.from_file(VOCAB_FILE, ",", special_tokens=["<pad>", "<unk>"])
lookup = text.Lookup(vocab)
data = data.map(input_columns=["text"], operations=lookup)
res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14],

View File

@ -1,13 +1,31 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
# this file contains "home is behind the world head" each word is 1 line
DATA_FILE = "../data/dataset/testVocab/words.txt"
VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
SIMPLE_VOCAB_FILE = "../data/dataset/testVocab/simple_vocab_list.txt"
def test_from_list():
vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "))
def test_from_list_tutorial():
vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", "<unk>"], True)
lookup = text.Lookup(vocab)
data = ds.TextFileDataset(DATA_FILE, shuffle=False)
data = data.map(input_columns=["text"], operations=lookup)
@ -18,8 +36,8 @@ def test_from_list():
ind += 1
def test_from_file():
vocab = text.Vocab.from_file(VOCAB_FILE, ",")
def test_from_file_tutorial():
vocab = text.Vocab.from_file(VOCAB_FILE, ",", None, ["<pad>", "<unk>"], True)
lookup = text.Lookup(vocab)
data = ds.TextFileDataset(DATA_FILE, shuffle=False)
data = data.map(input_columns=["text"], operations=lookup)
@ -30,7 +48,7 @@ def test_from_file():
ind += 1
def test_from_dict():
def test_from_dict_tutorial():
vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
lookup = text.Lookup(vocab, 6) # default value is -1
data = ds.TextFileDataset(DATA_FILE, shuffle=False)
@ -41,7 +59,61 @@ def test_from_dict():
assert d["text"] == res[ind], ind
ind += 1
def test_from_list():
def gen(texts):
for word in texts.split(" "):
yield (np.array(word, dtype='S'),)
def test_config(lookup_str, vocab_input, special_tokens, special_first):
try:
vocab = text.Vocab.from_list(vocab_input, special_tokens, special_first)
data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
res = []
for d in data.create_dict_iterator():
res.append(d["text"].item())
return res
except ValueError as e:
return str(e)
# test normal operations
assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], True) == [2, 3, 4, 0, 1]
assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False) == [0, 1, 2, 3, 4]
assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, True) == [2, 1, 0]
assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, False) == [2, 1, 0]
# test exceptions
assert "word_list contains duplicate" in test_config("w1", ["w1", "w1"], [], True)
assert "special_tokens contains duplicate" in test_config("w1", ["w1", "w2"], ["s1", "s1"], True)
assert "special_tokens and word_list contain duplicate" in test_config("w1", ["w1", "w2"], ["s1", "w1"], True)
def test_from_file():
def gen(texts):
for word in texts.split(" "):
yield (np.array(word, dtype='S'),)
def test_config(lookup_str, special_tokens, special_first):
try:
vocab = text.Vocab.from_file(SIMPLE_VOCAB_FILE, special_tokens=special_tokens, special_first=special_first)
data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
data = data.map(input_columns=["text"], operations=text.Lookup(vocab))
res = []
for d in data.create_dict_iterator():
res.append(d["text"].item())
return res
except ValueError as e:
return str(e)
assert test_config("w1 w2 w3", ["s1", "s2", "s3"], True) == [3, 4, 5]
assert test_config("w1 w2 w3", ["s1", "s2", "s3"], False) == [0, 1, 2]
assert "special_tokens contains duplicate" in test_config("w1", ["s1", "s1"], True)
if __name__ == '__main__':
test_from_list_tutorial()
test_from_file_tutorial()
test_from_dict_tutorial()
test_from_list()
test_from_file()
test_from_dict()