diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/bindings.cc index 995e8789c56..efcfa640f31 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/text/kernels/bindings.cc @@ -121,12 +121,13 @@ PYBIND_REGISTER(UnicodeCharTokenizerOp, 1, ([](const py::module *m) { PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) { (void)py::class_>(*m, "LookupOp") - .def(py::init([](std::shared_ptr vocab, const py::object &py_word) { + .def(py::init([](std::shared_ptr vocab, const py::object &py_word, + const DataType &data_type) { if (vocab == nullptr) { THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null.")); } if (py_word.is_none()) { - return std::make_shared(vocab, Vocab::kNoTokenExists); + return std::make_shared(vocab, Vocab::kNoTokenExists, data_type); } std::string word = py::reinterpret_borrow(py_word); WordIdType default_id = vocab->Lookup(word); @@ -134,7 +135,7 @@ PYBIND_REGISTER(LookupOp, 1, ([](const py::module *m) { THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "default unknown token: " + word + " doesn't exist in vocab.")); } - return std::make_shared(vocab, default_id); + return std::make_shared(vocab, default_id, data_type); })); })); diff --git a/mindspore/ccsrc/minddata/dataset/api/text.cc b/mindspore/ccsrc/minddata/dataset/api/text.cc index 5e5846ff064..594a4410cf4 100644 --- a/mindspore/ccsrc/minddata/dataset/api/text.cc +++ b/mindspore/ccsrc/minddata/dataset/api/text.cc @@ -22,8 +22,9 @@ namespace dataset { namespace api { namespace text { -std::shared_ptr Lookup(const std::shared_ptr &vocab, const std::string &unknown_token) { - auto op = std::make_shared(vocab, unknown_token); +std::shared_ptr Lookup(const std::shared_ptr &vocab, const std::string &unknown_token, + const DataType &data_type) { + auto op = std::make_shared(vocab, unknown_token, data_type); if (!op->ValidateParams()) { return nullptr; @@ -32,8 +33,9 @@ std::shared_ptr Lookup(const std::shared_ptr &vocab, con } // LookupOperation -LookupOperation::LookupOperation(const std::shared_ptr &vocab, const std::string &unknown_token) - : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists) {} +LookupOperation::LookupOperation(const std::shared_ptr &vocab, const std::string &unknown_token, + const DataType &data_type) + : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {} bool LookupOperation::ValidateParams() { if (vocab_ == nullptr) { @@ -54,7 +56,7 @@ bool LookupOperation::ValidateParams() { } std::shared_ptr LookupOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(vocab_, default_id_); + std::shared_ptr tensor_op = std::make_shared(vocab_, default_id_, data_type_); return tensor_op; } diff --git a/mindspore/ccsrc/minddata/dataset/include/text.h b/mindspore/ccsrc/minddata/dataset/include/text.h index 7edcdc027cd..3b9caddafee 100644 --- a/mindspore/ccsrc/minddata/dataset/include/text.h +++ b/mindspore/ccsrc/minddata/dataset/include/text.h @@ -20,9 +20,11 @@ #include #include #include + #include "minddata/dataset/core/constants.h" #include "minddata/dataset/include/transforms.h" #include "minddata/dataset/text/vocab.h" +#include "mindspore/ccsrc/minddata/dataset/core/data_type.h" namespace mindspore { namespace dataset { @@ -37,15 +39,18 @@ class LookupOperation; /// \brief Lookup operator that looks up a word to an id. /// \param[in] vocab a Vocab object. /// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov). -/// If unknown_token is oov, runtime error will be thrown +/// If unknown_token is oov, runtime error will be thrown. +/// \param[in] DataType type of the tensor after lookup, typically int32. /// \return Shared pointer to the current TensorOperation. -std::shared_ptr Lookup(const std::shared_ptr &vocab, const std::string &unknown_token); +std::shared_ptr Lookup(const std::shared_ptr &vocab, const std::string &unknown_token, + const mindspore::dataset::DataType &data_type = DataType("int32")); /* ####################################### Derived TensorOperation classes ################################# */ class LookupOperation : public TensorOperation { public: - explicit LookupOperation(const std::shared_ptr &vocab, const std::string &unknown_token); + explicit LookupOperation(const std::shared_ptr &vocab, const std::string &unknown_token, + const DataType &data_type); ~LookupOperation() = default; @@ -57,6 +62,7 @@ class LookupOperation : public TensorOperation { std::shared_ptr vocab_; std::string unknown_token_; int32_t default_id_; + DataType data_type_; }; } // namespace text } // namespace api diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc index 03178044160..802731b8fc1 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "minddata/dataset/text/kernels/lookup_op.h" - #include +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/text/kernels/lookup_op.h" + namespace mindspore { namespace dataset { -LookupOp::LookupOp(std::shared_ptr vocab, WordIdType default_id) - : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {} +LookupOp::LookupOp(std::shared_ptr vocab, WordIdType default_id, const DataType &data_type) + : vocab_(vocab), default_id_(default_id), type_(data_type) {} Status LookupOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); @@ -37,6 +38,14 @@ Status LookupOp::Compute(const std::shared_ptr &input, std::shared_ptrshape(), output)); + + // type cast to user's requirements if what user wants isn't int32_t + if ((*output)->type() != type_) { + std::shared_ptr cast_to; + RETURN_IF_NOT_OK(TypeCast(*output, &cast_to, type_)); + *output = cast_to; + } + return Status::OK(); } Status LookupOp::OutputType(const std::vector &inputs, std::vector &outputs) { diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h index bd1bf67cd30..1b6ecf2c2af 100644 --- a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h @@ -18,9 +18,9 @@ #define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_KERNELS_LOOKUP_OP_H_ #include -#include -#include #include +#include +#include #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/kernels/tensor_op.h" @@ -31,26 +31,27 @@ namespace mindspore { namespace dataset { class LookupOp : public TensorOp { public: - // constructor for lookup, takes in a vocab object - // @param std::shared_ptr vocab - - // @param WordIdType default_id, id to lookup if a word is not in vocab - explicit LookupOp(std::shared_ptr vocab, WordIdType default_id = 1); + /// \brief constructor for lookup, takes in a vocab object. + /// \param[in] std::shared_ptr vocab - vocab used for lookup. + /// \param[in] WordIdType default_id, id to lookup if a word is not in vocab. + /// \param[in] DataType type of the tensor after lookup, mostly int32. + explicit LookupOp(std::shared_ptr vocab, WordIdType default_id, const DataType &data_type); ~LookupOp() = default; - // perform actual lookup on each tensor - // @param const std::shared_ptr &input - // @param std::shared_ptr *output - // @return error code + /// \brief perform actual lookup on each tensor. + /// \param[in] const std::shared_ptr &input + /// \param[in] std::shared_ptr *output + /// \return[out] error code. Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - // print method - // @param std::ostream out + /// \brief print method. + /// \param[in] std::ostream out void Print(std::ostream &out) const override; - // @param std::vector &inputs - - // @param std::vector &outputs - - // @return error code + /// \param[in] std::vector &inputs - + /// \param[in] std::vector &outputs - + /// \return[out] error code. Status OutputType(const std::vector &inputs, std::vector &outputs) override; std::string Name() const override { return kLookupOp; } diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py index 39b5ea78af2..bcd95de46ce 100644 --- a/mindspore/dataset/text/transforms.py +++ b/mindspore/dataset/text/transforms.py @@ -49,6 +49,7 @@ import platform import numpy as np import mindspore._c_dataengine as cde +import mindspore.common.dtype as mstype from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType from .validators import check_lookup, check_jieba_add_dict, \ @@ -66,11 +67,12 @@ class Lookup(cde.LookupOp): vocab(Vocab): a Vocab object. unknown_token(str, optional): word to use for lookup if the word being looked up is out of Vocabulary (oov). If unknown_token is oov, runtime error will be thrown (default=None). + data_type (mindspore.dtype, optional): mindspore.dtype lookup maps string to (default=mstype.int32) """ @check_lookup - def __init__(self, vocab, unknown_token=None): - super().__init__(vocab, unknown_token) + def __init__(self, vocab, unknown_token=None, data_type=mstype.int32): + super().__init__(vocab, unknown_token, mstype_to_detype(data_type)) class SlidingWindow(cde.SlidingWindowOp): @@ -103,7 +105,6 @@ class SlidingWindow(cde.SlidingWindowOp): super().__init__(width, axis) - class Ngram(cde.NgramOp): """ TensorOp to generate n-gram from a 1-D string Tensor. diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index bb4118c3374..7054a95087b 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -44,12 +44,13 @@ def check_lookup(method): @wraps(method) def new_method(self, *args, **kwargs): - [vocab, unknown_token], _ = parse_user_args(method, *args, **kwargs) + [vocab, unknown_token, data_type], _ = parse_user_args(method, *args, **kwargs) if unknown_token is not None: type_check(unknown_token, (str,), "unknown_token") type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.") + type_check(data_type, (typing.Type,), "data_type") return method(self, *args, **kwargs) @@ -327,6 +328,7 @@ def check_from_dataset(method): return new_method + def check_slidingwindow(method): """A wrapper that wraps a parameter checker to the original function(sliding window operation).""" @@ -339,6 +341,7 @@ def check_slidingwindow(method): return new_method + def check_ngram(method): """A wrapper that wraps a parameter checker to the original function.""" diff --git a/tests/ut/cpp/dataset/c_api_dataset_vocab.cc b/tests/ut/cpp/dataset/c_api_dataset_vocab.cc index 78276702dfd..3d15b3cc9e4 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_vocab.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_vocab.cc @@ -26,9 +26,10 @@ #include "minddata/dataset/include/text.h" using namespace mindspore::dataset::api; +using mindspore::dataset::DataType; using mindspore::dataset::ShuffleMode; -using mindspore::dataset::Tensor; using mindspore::dataset::Status; +using mindspore::dataset::Tensor; using mindspore::dataset::Vocab; class MindDataTestPipeline : public UT::DatasetOpTesting { @@ -50,7 +51,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOp) { EXPECT_EQ(s, Status::OK()); // Create Lookup operation on ds - std::shared_ptr lookup = text::Lookup(vocab, ""); + std::shared_ptr lookup = text::Lookup(vocab, "", DataType("int32")); EXPECT_NE(lookup, nullptr); // Create Map operation on ds @@ -94,7 +95,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail1) { // Create lookup op for ds // Expected failure: "" is not a word of vocab - std::shared_ptr lookup = text::Lookup(vocab, ""); + std::shared_ptr lookup = text::Lookup(vocab, "", DataType("int32")); EXPECT_EQ(lookup, nullptr); } @@ -105,7 +106,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpFail2) { // Create lookup op // Expected failure: vocab is null - std::shared_ptr lookup = text::Lookup(vocab, ""); + std::shared_ptr lookup = text::Lookup(vocab, "", DataType("int32")); EXPECT_EQ(lookup, nullptr); } @@ -126,7 +127,7 @@ TEST_F(MindDataTestPipeline, TestVocabLookupOpWithEmptyUnknownToken) { // Create Lookup operation on ds // Expected failure: "" is not a word of vocab - std::shared_ptr lookup = text::Lookup(vocab, ""); + std::shared_ptr lookup = text::Lookup(vocab, "", DataType("int32")); EXPECT_EQ(lookup, nullptr); } @@ -148,7 +149,7 @@ TEST_F(MindDataTestPipeline, TestVocabFromDataset) { EXPECT_EQ(home_index, 4); // Create Lookup operation on ds - std::shared_ptr lookup = text::Lookup(vocab, ""); + std::shared_ptr lookup = text::Lookup(vocab, "", DataType("int32")); EXPECT_NE(lookup, nullptr); // Create Map operation on ds @@ -212,12 +213,15 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetDefault) { uint64_t i = 0; std::vector expected = {2, 3, 1, 4, 5, 0}; + std::vector not_expected = {2, 3, 1, 4, 5, 0}; while (row.size() != 0) { auto ind = row["text"]; MS_LOG(INFO) << ind->shape() << " " << *ind; - std::shared_ptr expected_item; + std::shared_ptr expected_item, not_expected_item; Tensor::CreateScalar(expected[i], &expected_item); + Tensor::CreateScalar(not_expected[i], ¬_expected_item); EXPECT_EQ(*ind, *expected_item); + EXPECT_NE(*ind, *not_expected_item); iter->GetNextRow(&row); i++; } @@ -233,8 +237,8 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail1) { // Create vocab from dataset // Expected failure: top_k can not be negative - std::shared_ptr vocab = ds->BuildVocab({"text"}, {0, std::numeric_limits::max()}, - -2, {"", ""}, true); + std::shared_ptr vocab = + ds->BuildVocab({"text"}, {0, std::numeric_limits::max()}, -2, {"", ""}, true); EXPECT_EQ(vocab, nullptr); } @@ -247,9 +251,9 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail2) { EXPECT_NE(ds, nullptr); // Create vocab from dataset - // Expected failure: requency_range [a,b] should be 0 <= a <= b - std::shared_ptr vocab = ds->BuildVocab({"text"}, {4, 1}, - std::numeric_limits::max(), {"", ""}, true); + // Expected failure: frequency_range [a,b] should be 0 <= a <= b + std::shared_ptr vocab = + ds->BuildVocab({"text"}, {4, 1}, std::numeric_limits::max(), {"", ""}, true); EXPECT_EQ(vocab, nullptr); } @@ -266,3 +270,52 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail3) { std::shared_ptr vocab = ds->BuildVocab({"ColumnNotExist"}); EXPECT_EQ(vocab, nullptr); } + +TEST_F(MindDataTestPipeline, TestVocabFromDatasetInt64) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVocabFromDatasetInt64."; + + // Create a TextFile dataset + std::string data_file = datasets_root_path_ + "/testVocab/words.txt"; + std::shared_ptr ds = TextFile({data_file}, 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create vocab from dataset + std::shared_ptr vocab = ds->BuildVocab(); + EXPECT_NE(vocab, nullptr); + + // Check if vocab has words or not + int32_t home_index = vocab->Lookup("home"); + EXPECT_EQ(home_index, 2); + + // Create Lookup operation on ds + std::shared_ptr lookup = text::Lookup(vocab, "home", DataType("int64")); + EXPECT_NE(lookup, nullptr); + + // Create Map operation on ds + ds = ds->Map({lookup}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + std::vector expected = {2, 3, 1, 4, 5, 0}; + std::vector not_expected = {2, 3, 1, 4, 5, 0}; + while (row.size() != 0) { + auto ind = row["text"]; + MS_LOG(INFO) << ind->shape() << " " << *ind; + std::shared_ptr expected_item, not_expected_item; + Tensor::CreateScalar(expected[i], &expected_item); + Tensor::CreateScalar(not_expected[i], ¬_expected_item); + EXPECT_EQ(*ind, *expected_item); + EXPECT_NE(*ind, *not_expected_item); + iter->GetNextRow(&row); + i++; + } +} \ No newline at end of file diff --git a/tests/ut/python/dataset/test_vocab.py b/tests/ut/python/dataset/test_vocab.py index 04cb463eb8c..5ced80d7fe0 100644 --- a/tests/ut/python/dataset/test_vocab.py +++ b/tests/ut/python/dataset/test_vocab.py @@ -17,6 +17,7 @@ import numpy as np import mindspore.dataset as ds import mindspore.dataset.text as text +import mindspore.common.dtype as mstype # this file contains "home is behind the world head" each word is 1 line DATA_FILE = "../data/dataset/testVocab/words.txt" @@ -137,6 +138,36 @@ def test_from_file(): assert "Input vocab_size must be greater than 0" in test_config("w1 w2", 0, [], True) assert "Input vocab_size must be greater than 0" in test_config("w1 w2", -1, [], True) + +def test_lookup_cast_type(): + def gen(texts): + for word in texts.split(" "): + yield (np.array(word, dtype='S'),) + + def test_config(lookup_str, data_type=None): + try: + vocab = text.Vocab.from_list(["w1", "w2", "w3"], special_tokens=[""], special_first=True) + data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"]) + # if data_type is None, test the default value of data_type + op = text.Lookup(vocab, "") if data_type is None else text.Lookup(vocab, "", data_type) + data = data.map(input_columns=["text"], operations=op) + res = [] + for d in data.create_dict_iterator(num_epochs=1): + res.append(d["text"]) + return res[0].dtype + except (ValueError, RuntimeError, TypeError) as e: + return str(e) + + # test result is correct + assert test_config("w1", mstype.int8) == np.dtype("int8") + assert test_config("w2", mstype.int32) == np.dtype("int32") + assert test_config("w3", mstype.int64) == np.dtype("int64") + assert test_config("unk", mstype.float32) != np.dtype("int32") + assert test_config("unk") == np.dtype("int32") + # test exception, data_type isn't the correct type + assert "tldr is not of type (,)" in test_config("unk", "tldr") + + if __name__ == '__main__': test_from_dict_exception() test_from_list_tutorial() @@ -144,3 +175,4 @@ if __name__ == '__main__': test_from_dict_tutorial() test_from_list() test_from_file() + test_lookup_cast_type()