From 8921a6099c38856fa2f46e0fa4b8c24b694d19ed Mon Sep 17 00:00:00 2001 From: Cathy Wong Date: Mon, 20 Jul 2020 17:32:36 -0400 Subject: [PATCH] C++ API Support for Skip Dataset Op and UTs --- .../ccsrc/minddata/dataset/api/datasets.cc | 37 +++++++++++++ .../ccsrc/minddata/dataset/include/datasets.h | 27 ++++++++++ mindspore/dataset/engine/datasets.py | 4 +- tests/ut/cpp/dataset/c_api_test.cc | 53 +++++++++++++++++++ tests/ut/python/dataset/test_skip.py | 28 ++++++++++ 5 files changed, 147 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 2f899e7f537..12bdcd5d85b 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -27,6 +27,7 @@ #include "minddata/dataset/engine/datasetops/map_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/engine/datasetops/shuffle_op.h" +#include "minddata/dataset/engine/datasetops/skip_op.h" #include "minddata/dataset/engine/datasetops/project_op.h" #include "minddata/dataset/engine/datasetops/zip_op.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" @@ -173,6 +174,20 @@ std::shared_ptr Dataset::Shuffle(int32_t shuffle_size) { return ds; } +// Function to create a SkipDataset. +std::shared_ptr Dataset::Skip(int32_t count) { + auto ds = std::make_shared(count); + + // Call derived class validation method. + if (!ds->ValidateParams()) { + return nullptr; + } + + ds->children.push_back(shared_from_this()); + + return ds; +} + // Function to create a ProjectDataset. 
std::shared_ptr Dataset::Project(const std::vector &columns) { auto ds = std::make_shared(columns); @@ -400,6 +415,28 @@ bool ShuffleDataset::ValidateParams() { return true; } +// Constructor for SkipDataset +SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {} + +// Function to build the SkipOp +std::shared_ptr>> SkipDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(skip_count_, connector_que_size_)); + return std::make_shared>>(node_ops); +} + +// Function to validate the parameters for SkipDataset +bool SkipDataset::ValidateParams() { + if (skip_count_ <= -1) { + MS_LOG(ERROR) << "Skip: Invalid input, skip_count: " << skip_count_; + return false; + } + + return true; +} + // Constructor for Cifar10Dataset Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr sampler) : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {} diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 7588a25f06e..9cf9787841b 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -46,6 +46,7 @@ class BatchDataset; class RepeatDataset; class MapDataset; class ShuffleDataset; +class SkipDataset; class Cifar10Dataset; class ProjectDataset; class ZipDataset; @@ -160,6 +161,12 @@ class Dataset : public std::enable_shared_from_this { /// \return Shared pointer to the current ShuffleDataset std::shared_ptr Shuffle(int32_t shuffle_size); + /// \brief Function to create a SkipDataset + /// \notes Skips count elements in this dataset. + /// \param[in] count Number of elements in the dataset to be skipped.
+ /// \return Shared pointer to the current SkipDataset + std::shared_ptr Skip(int32_t count); + /// \brief Function to create a Project Dataset /// \notes Applies project to the dataset /// \param[in] columns The name of columns to project @@ -293,6 +300,26 @@ class ShuffleDataset : public Dataset { bool reset_every_epoch_; }; +class SkipDataset : public Dataset { + public: + /// \brief Constructor + explicit SkipDataset(int32_t count); + + /// \brief Destructor + ~SkipDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + int32_t skip_count_; +}; + class MapDataset : public Dataset { public: /// \brief Constructor diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 6f11a230916..33b9028115f 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2094,8 +2094,8 @@ class SkipDataset(DatasetOp): The result of applying Skip operator to the input Dataset. Args: - input_dataset (tuple): A tuple of datasets to be skipped. - count (int): Number of rows the dataset should be skipped. + input_dataset (Dataset): Input dataset to have rows skipped. + count (int): Number of rows in the dataset to be skipped. 
""" def __init__(self, input_dataset, count): diff --git a/tests/ut/cpp/dataset/c_api_test.cc b/tests/ut/cpp/dataset/c_api_test.cc index 03c7c023a5e..560958e2a81 100644 --- a/tests/ut/cpp/dataset/c_api_test.cc +++ b/tests/ut/cpp/dataset/c_api_test.cc @@ -573,6 +573,59 @@ TEST_F(MindDataTestPipeline, TestShuffleDataset) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestSkipDataset) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDataset."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Skip operation on ds + int32_t count = 3; + ds = ds->Skip(count); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + MS_LOG(INFO) << "Number of rows: " << i; + + // Expect 10-3=7 rows + EXPECT_TRUE(i == 7); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestSkipDatasetError1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDatasetError1."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Skip operation on ds with invalid count input + int32_t count = -1; + ds = ds->Skip(count); + // Expect nullptr for invalid input skip_count + EXPECT_TRUE(ds == nullptr); +} + TEST_F(MindDataTestPipeline, TestCifar10Dataset) 
{ // Create a Cifar10 Dataset diff --git a/tests/ut/python/dataset/test_skip.py b/tests/ut/python/dataset/test_skip.py index 5dd7faa66a3..87e4122f848 100644 --- a/tests/ut/python/dataset/test_skip.py +++ b/tests/ut/python/dataset/test_skip.py @@ -13,9 +13,12 @@ # limitations under the License. # ============================================================================== import numpy as np +import pytest import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as vision +from mindspore import log as logger + DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" @@ -196,6 +199,29 @@ def test_skip_filter_2(): assert buf == [5, 6, 7, 8, 9, 10] +def test_skip_exception_1(): + data1 = ds.GeneratorDataset(generator_md, ["data"]) + + try: + data1 = data1.skip(count=-1) + num_iter = 0 + for _ in data1.create_dict_iterator(): + num_iter += 1 + + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Skip count must be positive integer or 0." in str(e) + + +def test_skip_exception_2(): + ds1 = ds.GeneratorDataset(generator_md, ["data"]) + + with pytest.raises(ValueError) as e: + ds1 = ds1.skip(-2) + assert "Input count is not within the required interval" in str(e.value) + + + if __name__ == "__main__": test_tf_skip() test_generator_skip() @@ -208,3 +234,5 @@ if __name__ == "__main__": test_skip_take_2() test_skip_filter_1() test_skip_filter_2() + test_skip_exception_1() + test_skip_exception_2()