diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/ir/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/ir/bindings.cc
index 48121570af2..d9a8b4b28bb 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/ir/bindings.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/ir/bindings.cc
@@ -154,8 +154,8 @@ PYBIND_REGISTER(
 PYBIND_REGISTER(LookupOperation, 1, ([](const py::module *m) {
                   (void)py::class_<LookupOperation, TensorOperation, std::shared_ptr<LookupOperation>>(
                     *m, "LookupOperation")
-                    .def(py::init([](const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
-                                     const std::string &data_type) {
+                    .def(py::init([](const std::shared_ptr<Vocab> &vocab,
+                                     const std::optional<std::string> &unknown_token, const std::string &data_type) {
                       auto lookup = std::make_shared<LookupOperation>(vocab, unknown_token, data_type);
                       THROW_IF_ERROR(lookup->ValidateParams());
                       return lookup;
diff --git a/mindspore/ccsrc/minddata/dataset/api/text.cc b/mindspore/ccsrc/minddata/dataset/api/text.cc
index 2544f835220..5d998415c15 100644
--- a/mindspore/ccsrc/minddata/dataset/api/text.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/text.cc
@@ -87,8 +87,8 @@ std::shared_ptr<TensorOperation> JiebaTokenizer(const std::string &hmm_path,
   return op->ValidateParams() ? op : nullptr;
 }
 
-std::shared_ptr<TensorOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
-                                        const std::string &data_type) {
+std::shared_ptr<TensorOperation> Lookup(const std::shared_ptr<Vocab> &vocab,
+                                        const std::optional<std::string> &unknown_token, const std::string &data_type) {
   auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type);
   return op->ValidateParams() ? op : nullptr;
 }
@@ -340,7 +340,7 @@ Status JiebaTokenizerOperation::AddWord(const std::string &word, int64_t freq) {
 }
 
 // LookupOperation
-LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
                                  const std::string &data_type)
     : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}
 
@@ -352,10 +352,10 @@ Status LookupOperation::ValidateParams() {
     MS_LOG(ERROR) << err_msg;
     RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
-  if (!unknown_token_.empty()) {
-    default_id_ = vocab_->Lookup(unknown_token_);
+  if (unknown_token_ != std::nullopt) {
+    default_id_ = vocab_->Lookup(*unknown_token_);
     if (default_id_ == Vocab::kNoTokenExists) {
-      std::string err_msg = "Lookup: \"" + unknown_token_ + "\" doesn't exist in vocab.";
+      std::string err_msg = "Lookup: \"" + *unknown_token_ + "\" doesn't exist in vocab.";
       MS_LOG(ERROR) << err_msg;
       RETURN_STATUS_SYNTAX_ERROR(err_msg);
     }
diff --git a/mindspore/ccsrc/minddata/dataset/include/text.h b/mindspore/ccsrc/minddata/dataset/include/text.h
index bf5522fe678..146fb5e799f 100644
--- a/mindspore/ccsrc/minddata/dataset/include/text.h
+++ b/mindspore/ccsrc/minddata/dataset/include/text.h
@@ -18,6 +18,7 @@
 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TEXT_H_
 
 #include <memory>
+#include <optional>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -143,11 +144,13 @@ std::shared_ptr<TensorOperation> JiebaTokenizer(const std::string &hmm_path,
 
 /// \brief Lookup operator that looks up a word to an id.
 /// \param[in] vocab a Vocab object.
/// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov).
-///    If unknown_token is oov, runtime error will be thrown.
+///    If unknown_token is oov, runtime error will be thrown. If unknown_token is {}, which means that not to
+//    specify unknown_token when word being out of Vocabulary (default={}).
 /// \param[in] data_type type of the tensor after lookup, typically int32.
 /// \return Shared pointer to the current TensorOperation.
-std::shared_ptr<TensorOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+std::shared_ptr<TensorOperation> Lookup(const std::shared_ptr<Vocab> &vocab,
+                                        const std::optional<std::string> &unknown_token = {},
                                         const std::string &data_type = "int32");
 
 /// \brief TensorOp to generate n-gram from a 1-D string Tensor.
@@ -343,7 +346,7 @@ class JiebaTokenizerOperation : public TensorOperation {
 
 class LookupOperation : public TensorOperation {
  public:
-  explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
+  explicit LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token,
                            const std::string &data_type);
 
   ~LookupOperation();
@@ -356,7 +359,7 @@ class LookupOperation : public TensorOperation {
 
  private:
   std::shared_ptr<Vocab> vocab_;
-  std::string unknown_token_;
+  std::optional<std::string> unknown_token_;
   int32_t default_id_;
   std::string data_type_;
 };
diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py
index 858a8f576aa..fb0f3295344 100644
--- a/mindspore/dataset/text/transforms.py
+++ b/mindspore/dataset/text/transforms.py
@@ -295,7 +295,7 @@ class Lookup(TextTensorOperation):
     @check_lookup
     def __init__(self, vocab, unknown_token=None, data_type=mstype.int32):
         self.vocab = vocab
-        self.unknown_token = replace_none(unknown_token, '')
+        self.unknown_token = unknown_token
         self.data_type = data_type
 
     def parse(self):
diff --git a/tests/ut/python/dataset/test_vocab.py b/tests/ut/python/dataset/test_vocab.py
index a6818ac2e7e..7c1d2568c39 100644
--- a/tests/ut/python/dataset/test_vocab.py
+++ b/tests/ut/python/dataset/test_vocab.py
@@ -119,6 +119,31 @@ def test_from_list():
     assert "is not of type" in test_config("w1", ["w1", "w2"], ["s1"], True, 123)
 
 
+def test_from_list_lookup_empty_string():
+    # "" is a valid word in vocab, which can be looked up by LookupOp
+    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", ""], True)
+    lookup = text.Lookup(vocab, "")
+    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    data = data.map(operations=lookup, input_columns=["text"])
+    ind = 0
+    res = [2, 1, 4, 5, 6, 7]
+    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
+        assert d["text"] == res[ind], ind
+        ind += 1
+
+    # unknown_token of LookUp is None, it will convert to std::nullopt in C++,
+    # so it has nothing to do with "" in vocab and C++ will skip looking up unknown_token
+    vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", ""], True)
+    lookup = text.Lookup(vocab)
+    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    data = data.map(operations=lookup, input_columns=["text"])
+    try:
+        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
+            pass
+    except RuntimeError as e:
+        assert "token: \"is\" doesn't exist in vocab and no unknown token is specified" in str(e)
+
+
 def test_from_file():
     def gen(texts):
         for word in texts.split(" "):