diff --git a/mindspore/ccsrc/dataset/CMakeLists.txt b/mindspore/ccsrc/dataset/CMakeLists.txt
index abea7a7c47f..6abd9286c24 100644
--- a/mindspore/ccsrc/dataset/CMakeLists.txt
+++ b/mindspore/ccsrc/dataset/CMakeLists.txt
@@ -52,6 +52,7 @@ add_subdirectory(core)
 add_subdirectory(kernels)
 add_subdirectory(engine)
 add_subdirectory(api)
+add_subdirectory(nlp)
 
 ######################################################################
 ################### Create _c_dataengine Library ######################
@@ -68,6 +69,8 @@ set(submodules
         $
         $
         $
+        $<TARGET_OBJECTS:nlp>
+        $<TARGET_OBJECTS:nlp-kernels>
         )
 
 if (ENABLE_TDTQUE)
diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc
index 951aaaeccf5..bf77f67fad8 100644
--- a/mindspore/ccsrc/dataset/api/python_bindings.cc
+++ b/mindspore/ccsrc/dataset/api/python_bindings.cc
@@ -40,6 +40,8 @@
 #include "dataset/kernels/data/type_cast_op.h"
 #include "dataset/kernels/text/jieba_tokenizer_op.h"
 #include "dataset/kernels/text/unicode_char_tokenizer_op.h"
+#include "dataset/nlp/vocab.h"
+#include "dataset/nlp/kernels/lookup_op.h"
 #include "dataset/engine/datasetops/source/cifar_op.h"
 #include "dataset/engine/datasetops/source/image_folder_op.h"
 #include "dataset/engine/datasetops/source/io_block.h"
@@ -414,10 +416,13 @@ void bindTensorOps5(py::module *m) {
          py::arg("mode") = JiebaMode::kMix)
     .def("add_word",
          [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
-
   (void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
     *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
     .def(py::init<>());
+  (void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp",
+                                                                  "Tensor operation to LookUp each word")
+    .def(py::init<std::shared_ptr<Vocab>, WordIdType>(), py::arg("vocab"), py::arg("unknown"))
+    .def(py::init<std::shared_ptr<Vocab>>(), py::arg("vocab"));
 }
 
 void bindSamplerOps(py::module *m) {
@@ -479,6 +484,27 @@ void bindInfoObjects(py::module *m) {
     .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
 }
 
+void bindVocabObjects(py::module *m) {
+  (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
+    .def_static("from_list",
+                [](const py::list &words) {
+                  std::shared_ptr<Vocab> v;
+                  THROW_IF_ERROR(Vocab::BuildFromPyList(words, &v));
+                  return v;
+                })
+    .def_static("from_file",
+                [](const std::string &path, const std::string &dlm, int32_t vocab_size) {
+                  std::shared_ptr<Vocab> v;
+                  THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, &v));
+                  return v;
+                })
+    .def_static("from_dict", [](const py::dict &words) {
+      std::shared_ptr<Vocab> v;
+      THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v));
+      return v;
+    });
+}
+
 // This is where we externalize the C logic as python modules
 PYBIND11_MODULE(_c_dataengine, m) {
   m.doc() = "pybind11 for _c_dataengine";
@@ -543,6 +569,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
   bindSamplerOps(&m);
   bindDatasetOps(&m);
   bindInfoObjects(&m);
+  bindVocabObjects(&m);
 }
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/nlp/CMakeLists.txt b/mindspore/ccsrc/dataset/nlp/CMakeLists.txt
new file mode 100644
index 00000000000..b6a7b79107d
--- /dev/null
+++ b/mindspore/ccsrc/dataset/nlp/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_subdirectory(kernels)
+
+add_library(nlp OBJECT
+    vocab.cc
+    )
+
+add_dependencies(nlp nlp-kernels)
\ No newline at end of file
diff --git a/mindspore/ccsrc/dataset/nlp/kernels/CMakeLists.txt b/mindspore/ccsrc/dataset/nlp/kernels/CMakeLists.txt
new file mode 100644
index 00000000000..d0019334b2f
--- /dev/null
+++ b/mindspore/ccsrc/dataset/nlp/kernels/CMakeLists.txt
@@ -0,0 +1,3 @@
+add_library(nlp-kernels OBJECT
+    lookup_op.cc
+    )
diff --git a/mindspore/ccsrc/dataset/nlp/kernels/lookup_op.cc b/mindspore/ccsrc/dataset/nlp/kernels/lookup_op.cc
new file mode 100644
index 00000000000..b9789633c86
--- /dev/null
+++ b/mindspore/ccsrc/dataset/nlp/kernels/lookup_op.cc
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dataset/nlp/kernels/lookup_op.h"
+
+#include <string>
+
+namespace mindspore {
+namespace dataset {
+
+LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id)
+    : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {}
+
+Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  RETURN_UNEXPECTED_IF_NULL(vocab_);
+  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "input is not a string tensor");
+  std::vector<WordIdType> word_ids;
+  word_ids.reserve(input->Size());
+  for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) {
+    word_ids.push_back(vocab_->Lookup(std::string(*itr), default_id_));
+  }
+
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_,
+                                        reinterpret_cast<unsigned char *>(word_ids.data())));
+  return Status::OK();
+}
+Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
+  CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match");
+  CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "input is not a string tensor type");
+  outputs[0] = type_;
+  return Status::OK();
+}
+
+void LookupOp::Print(std::ostream &out) const {
+  out << "LookupOp: "
+      << "type: " << type_ << "\n default lookup id: " << default_id_ << "\n";
+}
+
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/nlp/kernels/lookup_op.h b/mindspore/ccsrc/dataset/nlp/kernels/lookup_op.h
new file mode 100644
index 00000000000..e9fdeb33517
--- /dev/null
+++ b/mindspore/ccsrc/dataset/nlp/kernels/lookup_op.h
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_NLP_KERNELS_LOOKUP_OP_H_
+#define DATASET_NLP_KERNELS_LOOKUP_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+#include "dataset/util/status.h"
+#include "dataset/nlp/vocab.h"
+
+namespace mindspore {
+namespace dataset {
+class LookupOp : public TensorOp {
+ public:
+  // constructor for lookup, takes in a vocab object
+  // @param std::shared_ptr<Vocab> vocab - vocab to look words up in
+  // @param WordIdType default_id - id to return if a word is not in the vocab
+  explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = Vocab::kSpecialTokens::unk);
+
+  // perform the actual lookup on each tensor
+  // @param const std::shared_ptr<Tensor> &input - input tensor of strings
+  // @param std::shared_ptr<Tensor> *output - output tensor of word ids
+  // @return error code
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  // print method
+  // @param std::ostream out - stream to print to
+  void Print(std::ostream &out) const override;
+
+  // @param std::vector<DataType> &inputs - input data types
+  // @param std::vector<DataType> &outputs - output data types
+  // @return error code
+  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
+
+ private:
+  std::shared_ptr<Vocab> vocab_;
+  WordIdType default_id_;
+  DataType type_;  // type of tensor after lookup
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_NLP_KERNELS_LOOKUP_OP_H_
diff --git a/mindspore/ccsrc/dataset/nlp/vocab.cc b/mindspore/ccsrc/dataset/nlp/vocab.cc
new file mode 100644
index 00000000000..291d0d2686d
--- /dev/null
+++ b/mindspore/ccsrc/dataset/nlp/vocab.cc
@@ -0,0 +1,101 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <fstream>
+#include <map>
+#include <utility>
+
+#include "dataset/nlp/vocab.h"
+
+namespace mindspore {
+namespace dataset {
+Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) {
+  word2id_ = std::move(word2id);
+  id2word_.resize(word2id_.size());
+  for (auto p : word2id_) {
+    id2word_[p.second - kSpecialTokens::num_tokens] = p.first;
+  }
+}
+
+WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const {
+  auto itr = word2id_.find(word);
+  return itr == word2id_.end() ? default_id : itr->second;
+}
+
+WordType Vocab::Lookup(WordIdType id) const {
+  if (id < kSpecialTokens::num_tokens) {
+    return reserved_token_str_[id];
+  } else if (id - kSpecialTokens::num_tokens >= id2word_.size()) {
+    return reserved_token_str_[kSpecialTokens::unk];
+  } else {
+    return id2word_[id - kSpecialTokens::num_tokens];
+  }
+}
+
+Status Vocab::BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab) {
+  std::unordered_map<WordType, WordIdType> word2id;
+  WordIdType word_id = kSpecialTokens::num_tokens;
+  for (auto word : words) {
+    const std::string s = py::str(word);
+    CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(s) == word2id.end(), "duplicate word: " + s);
+    word2id[s] = word_id++;
+  }
+  *vocab = std::make_shared<Vocab>(std::move(word2id));
+  return Status::OK();
+}
+
+Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
+                            std::shared_ptr<Vocab> *vocab) {
+  std::unordered_map<WordType, WordIdType> word2id;
+  WordIdType word_id = kSpecialTokens::num_tokens;
+  std::fstream handle(path, std::ios::in);
+  CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open: " + path);
+  std::string word;
+  while (std::getline(handle, word)) {
+    if (!delimiter.empty()) {
+      // if the delimiter is not found, find_first_of returns std::string::npos and substr keeps the whole line
+      word = word.substr(0, word.find_first_of(delimiter));
+    }
+    CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word: " + word);
+    word2id[word] = word_id++;
+    // stop once vocab_size words have been read; a negative vocab_size never matches, so the whole file is read
+    if (word_id == vocab_size + kSpecialTokens::num_tokens) break;
+  }
+  *vocab = std::make_shared<Vocab>(std::move(word2id));
+  return Status::OK();
+}
+
+Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) {
+  std::unordered_map<WordType, WordIdType> word2id;
+  std::map<WordIdType, WordType> id2word;
+  for (auto p : words) {
+    WordIdType word_id = py::reinterpret_borrow<py::int_>(p.second);
+    if (word_id < kSpecialTokens::num_tokens) continue;  // skip ids that are reserved
+    std::string word = py::str(p.first);
+    CHECK_FAIL_RETURN_UNEXPECTED(id2word.find(word_id) == id2word.end(), "duplicate id: " + word);
+    id2word[word_id] = word;
+  }
+
+  WordIdType cnt = kSpecialTokens::num_tokens;
+  for (auto p : id2word) {
+    CHECK_FAIL_RETURN_UNEXPECTED(p.first == cnt++, "word id needs to be continuous starting from 2");
+    word2id[p.second] = p.first;
+  }
+
+  *vocab = std::make_shared<Vocab>(std::move(word2id));
+  return Status::OK();
+}
+const std::vector<WordType> Vocab::reserved_token_str_ = {"<pad>", "<unk>"};
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/nlp/vocab.h b/mindspore/ccsrc/dataset/nlp/vocab.h
new file mode 100644
index 00000000000..4cc6dcaa8cf
--- /dev/null
+++ b/mindspore/ccsrc/dataset/nlp/vocab.h
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_NLP_VOCAB_H_
+#define DATASET_NLP_VOCAB_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "dataset/util/status.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+
+namespace mindspore {
+namespace dataset {
+namespace py = pybind11;
+
+using WordIdType = int32_t;
+using WordType = std::string;
+
+class Vocab {
+ public:
+  // Build a vocab from a python dictionary where each key is a word; ids need to start from 2 and be unique
+  // and continuous
+  // @param const py::dict &words - a dictionary containing word, word id pairs
+  // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
+  // @return error code
+  static Status BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab);
+
+  // Build a vocab from a python list; ids are assigned automatically, starting from 2
+  // @param const py::list &words - a list of strings used to build the vocab
+  // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
+  // @return error code
+  static Status BuildFromPyList(const py::list &words, std::shared_ptr<Vocab> *vocab);
+
+  // Build a vocab by reading a vocab file; ids are assigned automatically, starting from 2
+  // @param const std::string &path - path to the vocab file, each line is assumed to contain one word
+  // @param const std::string &delimiter - delimiter used to break each line
+  // @param int32_t vocab_size - number of words to read from the file
+  // @param std::shared_ptr<Vocab> *vocab - return value, vocab object
+  // @return error code
+  static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
+                              std::shared_ptr<Vocab> *vocab);
+
+  // Lookup the id of a word; if the word doesn't exist in the vocab, return default_id
+  // @param const WordType &word - word to look up
+  // @param WordIdType default_id - word id to return when the word is not in the vocab
+  // @return WordIdType, the word id
+  WordIdType Lookup(const WordType &word, WordIdType default_id) const;
+
+  // Reverse lookup: find the word based on its id
+  // @param WordIdType id - word id to look up
+  // @return WordType, the word
+  WordType Lookup(WordIdType id) const;
+
+  // Constructor; shouldn't be called directly, but can't be private due to std::make_shared()
+  // @param std::unordered_map<WordType, WordIdType> map - sanitized word2id map
+  explicit Vocab(std::unordered_map<WordType, WordIdType> map);
+
+  // enum type that holds all special tokens, add more if needed
+  enum kSpecialTokens : WordIdType { pad = 0, unk = 1, num_tokens = 2 };
+
+  // reverse lookup table for the reserved tokens
+  static const std::vector<WordType> reserved_token_str_;
+
+ private:
+  std::unordered_map<WordType, WordIdType> word2id_;
+  std::vector<WordType> id2word_;  // reverse lookup
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_NLP_VOCAB_H_
diff --git a/mindspore/dataset/text/__init__.py b/mindspore/dataset/text/__init__.py
new file mode 100644
index 00000000000..3ce73b45d09
--- /dev/null
+++ b/mindspore/dataset/text/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +""" +mindspore.dataset.text +""" + +from .c_transforms import * diff --git a/mindspore/dataset/text/c_transforms.py b/mindspore/dataset/text/c_transforms.py new file mode 100644 index 00000000000..6bb609e8f1a --- /dev/null +++ b/mindspore/dataset/text/c_transforms.py @@ -0,0 +1,77 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +c transforms for all text related operators +""" + +import mindspore._c_dataengine as cde +from .validators import check_lookup, check_from_list, check_from_dict, check_from_file + + +class Vocab(cde.Vocab): + """ + Vocab object that is used for lookup word + Args: + """ + + def __init__(self): + pass + + @classmethod + @check_from_list + def from_list(cls, word_list): + """ + build a vocab object from a list of word + Args: + word_list(list): a list of string where each element is a word + """ + return super().from_list(word_list) + + @classmethod + @check_from_file + def from_file(cls, file_path, delimiter=None, vocab_size=None): + """ + build a vocab object from a list of word + Args: + file_path(str): path to the file which contains the vocab list + delimiter(None, str): a delimiter to break up each line in file, the first element is taken to be the word + vocab_size(None, int): number of words to read from file_path + """ + return super().from_file(file_path, delimiter, vocab_size) + + @classmethod + @check_from_dict + def from_dict(cls, word_dict): + """ + build a vocab object from a dict. + Args: + word_dict(dict): dict contains word, id pairs. id should start from 2 and continuous + """ + return super().from_dict(word_dict) + + +class Lookup(cde.LookupOp): + """ + Lookup operator that looks up a word to an id + Args: + vocab(Vocab): a Vocab object + unknown(None,int): default id to lookup a word that is out of vocab + """ + + @check_lookup + def __init__(self, vocab, unknown=None): + if unknown is None: + super().__init__(vocab) + else: + super().__init__(vocab, unknown) diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py new file mode 100644 index 00000000000..d043a230f66 --- /dev/null +++ b/mindspore/dataset/text/validators.py @@ -0,0 +1,108 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""
+validators for text ops
+"""
+
+from functools import wraps
+
+import mindspore._c_dataengine as cde
+
+
+def check_lookup(method):
+    """A wrapper that wraps a parameter checker around the original function (Lookup operation)."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        vocab, unknown = (list(args) + 2 * [None])[:2]
+        if "vocab" in kwargs:
+            vocab = kwargs.get("vocab")
+        if "unknown" in kwargs:
+            unknown = kwargs.get("unknown")
+        if unknown is not None:
+            assert isinstance(unknown, int) and unknown >= 0, "unknown needs to be a non-negative integer"
+
+        assert isinstance(vocab, cde.Vocab), "vocab is not an instance of cde.Vocab"
+
+        kwargs["vocab"] = vocab
+        kwargs["unknown"] = unknown
+        return method(self, **kwargs)
+
+    return new_method
+
+
+def check_from_file(method):
+    """A wrapper that wraps a parameter checker around the original function (Vocab.from_file)."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        file_path, delimiter, vocab_size = (list(args) + 3 * [None])[:3]
+        if "file_path" in kwargs:
+            file_path = kwargs.get("file_path")
+        if "delimiter" in kwargs:
+            delimiter = kwargs.get("delimiter")
+        if "vocab_size" in kwargs:
+            vocab_size = kwargs.get("vocab_size")
+
+        assert isinstance(file_path, str), "file_path needs to be str"
+        if delimiter is not None:
+            assert isinstance(delimiter, str), "delimiter needs to be str"
+        else:
+            delimiter = ""
+        if vocab_size is not None:
+            assert isinstance(vocab_size, int) and vocab_size > 0, "vocab_size needs to be a positive integer"
+        else:
+            vocab_size = -1
+        kwargs["file_path"] = file_path
+        kwargs["delimiter"] = delimiter
+        kwargs["vocab_size"] = vocab_size
+        return method(self, **kwargs)
+
+    return new_method
+
+
+def check_from_list(method):
+    """A wrapper that wraps a parameter checker around the original function (Vocab.from_list)."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        word_list, = (list(args) + [None])[:1]
+        if "word_list" in kwargs:
+            word_list = kwargs.get("word_list")
+        assert isinstance(word_list, list), "word_list needs to be a list of words"
+        for word in word_list:
+            assert isinstance(word, str), "each word in word_list needs to be of type str"
+
+        kwargs["word_list"] = word_list
+        return method(self, **kwargs)
+
+    return new_method
+
+
+def check_from_dict(method):
+    """A wrapper that wraps a parameter checker around the original function (Vocab.from_dict)."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        word_dict, = (list(args) + [None])[:1]
+        if "word_dict" in kwargs:
+            word_dict = kwargs.get("word_dict")
+        assert isinstance(word_dict, dict), "word_dict needs to be a dict of word, id pairs"
+        for word, word_id in word_dict.items():
+            assert isinstance(word, str), "each word in word_dict needs to be of type str"
+            assert isinstance(word_id, int) and word_id >= 0, "each word id needs to be a non-negative integer"
+        kwargs["word_dict"] = word_dict
+        return method(self, **kwargs)
+
+    return new_method
diff --git a/tests/ut/data/dataset/testVocab/vocab_list.txt b/tests/ut/data/dataset/testVocab/vocab_list.txt
new file mode 100644
index 00000000000..eaaa7d4e444
--- /dev/null
+++ b/tests/ut/data/dataset/testVocab/vocab_list.txt
@@ -0,0 +1,14 @@
+not,1
+all,2
+those,3
+who,4
+wonder,5
+are,6
+lost,7
+Tolkein,8
+home,9
+is,10
+behind,11
+world,12
+ahead,13
+the,14
diff --git a/tests/ut/data/dataset/testVocab/words.txt b/tests/ut/data/dataset/testVocab/words.txt
new file mode 100644
index 00000000000..6535e86e8a7
--- /dev/null
+++ b/tests/ut/data/dataset/testVocab/words.txt
@@ -0,0 +1,6 @@
+home
+is
+behind
+the
+world
+ahead
diff --git a/tests/ut/python/dataset/test_vocab.py b/tests/ut/python/dataset/test_vocab.py
new file mode 100644
index 00000000000..4230f9324b7
--- /dev/null
+++ b/tests/ut/python/dataset/test_vocab.py
@@ -0,0 +1,47 @@
+import mindspore.dataset as ds
+import mindspore.dataset.text as text
+
+# this file contains "home is behind the world ahead", one word per line
+DATA_FILE = "../data/dataset/testVocab/words.txt"
+VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
+
+
+def test_from_list():
+    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "))
+    lookup = text.Lookup(vocab)
+    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    data = data.map(input_columns=["text"], operations=lookup)
+    ind = 0
+    res = [2, 1, 4, 5, 6, 7]
+    for d in data.create_dict_iterator():
+        assert d["text"] == res[ind], ind
+        ind += 1
+
+
+def test_from_file():
+    vocab = text.Vocab.from_file(VOCAB_FILE, ",")
+    lookup = text.Lookup(vocab)
+    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    data = data.map(input_columns=["text"], operations=lookup)
+    ind = 0
+    res = [10, 11, 12, 15, 13, 14]
+    for d in data.create_dict_iterator():
+        assert d["text"] == res[ind], ind
+        ind += 1
+
+
+def test_from_dict():
+    vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
+    lookup = text.Lookup(vocab, 6)  # words not in the vocab are looked up as id 6
+    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    data = data.map(input_columns=["text"], operations=lookup)
+    res = [3, 6, 2, 4, 5, 6]
+    ind = 0
+    for d in data.create_dict_iterator():
+        assert d["text"] == res[ind], ind
+        ind += 1
+
+if __name__ == '__main__':
+    test_from_list()
+    test_from_file()
+    test_from_dict()
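
Reviewer note: a minimal end-to-end sketch of the Python API this patch adds, for context. The word list and the "words.txt" path below are placeholder inputs, not part of the patch; the id layout follows vocab.h, where <pad>=0 and <unk>=1 are reserved and user words start at 2.

    import mindspore.dataset as ds
    import mindspore.dataset.text as text

    # build a vocab from an in-memory word list; ids are assigned from 2 onward
    vocab = text.Vocab.from_list(["home", "behind", "the", "world", "ahead"])

    # Lookup maps each string to its int32 id; out-of-vocab words fall back to <unk> (id 1)
    lookup = text.Lookup(vocab)

    # words.txt is a placeholder text file with one word per line
    data = ds.TextFileDataset("words.txt", shuffle=False)
    data = data.map(input_columns=["text"], operations=lookup)
    for row in data.create_dict_iterator():
        print(row["text"])  # prints the id of each word, e.g. 2 for "home"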