diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index f4ed295940f..2f899e7f537 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -28,6 +28,7 @@ #include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/engine/datasetops/shuffle_op.h" #include "minddata/dataset/engine/datasetops/project_op.h" +#include "minddata/dataset/engine/datasetops/zip_op.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" @@ -53,6 +54,7 @@ std::shared_ptr Dataset::CreateIterator() { iter = std::make_shared(); Status rc = iter->BuildAndLaunchTree(shared_from_this()); if (rc.IsError()) { + MS_LOG(ERROR) << rc; MS_LOG(ERROR) << "CreateIterator failed."; return nullptr; } @@ -184,6 +186,21 @@ std::shared_ptr Dataset::Project(const std::vector return ds; } +// Function to create a Zip dataset +std::shared_ptr Dataset::Zip(const std::vector> &datasets) { + // Default values + auto ds = std::make_shared(); + + if (!ds->ValidateParams()) { + return nullptr; + } + for (auto dataset : datasets) { + ds->children.push_back(dataset); + } + + return ds; +} + // Helper function to create default RandomSampler. std::shared_ptr CreateDefaultSampler() { const int32_t num_samples = 0; // 0 means to sample all ids. @@ -441,6 +458,19 @@ std::shared_ptr>> ProjectDataset::Build() return std::make_shared>>(node_ops); } +// Function to build ZipOp +ZipDataset::ZipDataset() {} + +bool ZipDataset::ValidateParams() { return true; } + +std::shared_ptr>> ZipDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(rows_per_buffer_, connector_que_size_)); + return std::make_shared>>(node_ops); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/iterator.cc b/mindspore/ccsrc/minddata/dataset/api/iterator.cc index 068bcfaa047..8136601ed84 100644 --- a/mindspore/ccsrc/minddata/dataset/api/iterator.cc +++ b/mindspore/ccsrc/minddata/dataset/api/iterator.cc @@ -52,7 +52,9 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr ds) { // Iterative BFS converting Dataset tree into runtime Execution tree. std::queue, std::shared_ptr>> q; - if (ds != nullptr) { + if (ds == nullptr) { + RETURN_STATUS_UNEXPECTED("Input is null pointer"); + } else { // Convert the current root node. auto root_op = ds->Build()->front(); RETURN_UNEXPECTED_IF_NULL(root_op); diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 22fc4110903..7588a25f06e 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -48,6 +48,7 @@ class MapDataset; class ShuffleDataset; class Cifar10Dataset; class ProjectDataset; +class ZipDataset; /// \brief Function to create an ImageFolderDataset /// \notes A source dataset that reads images from a tree of directories @@ -165,6 +166,12 @@ class Dataset : public std::enable_shared_from_this { /// \return Shared pointer to the current Dataset std::shared_ptr Project(const std::vector &columns); + /// \brief Function to create a Zip Dataset + /// \notes Applies zip to the dataset + /// \param[in] datasets A list of shared pointer to the datasets that we want to zip + /// \return Shared pointer to the current Dataset + std::shared_ptr Zip(const std::vector> &datasets); + protected: std::vector> children; std::shared_ptr parent; @@ -351,6 +358,24 @@ class ProjectDataset : public Dataset { private: std::vector columns_; }; + +class ZipDataset : public Dataset { + public: + /// \brief Constructor + ZipDataset(); + + /// \brief Destructor + ~ZipDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; +}; + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/tests/ut/cpp/dataset/c_api_test.cc b/tests/ut/cpp/dataset/c_api_test.cc index 902bc9a43b8..03c7c023a5e 100644 --- a/tests/ut/cpp/dataset/c_api_test.cc +++ b/tests/ut/cpp/dataset/c_api_test.cc @@ -764,8 +764,60 @@ TEST_F(MindDataTestPipeline, TestProjectMap) { iter->GetNextRow(&row); } - EXPECT_TRUE(i == 20); + EXPECT_EQ(i, 20); // Manually terminate the pipeline iter->Stop(); -} \ No newline at end of file +} + +TEST_F(MindDataTestPipeline, TestZip) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Project operation on ds + std::vector column_project = {"image"}; + ds = ds->Project(column_project); + EXPECT_TRUE(ds != nullptr); + + folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds1 = Cifar10(folder_path, 0, RandomSampler(false, 10)); + EXPECT_TRUE(ds1 != nullptr); + + // Create a Project operation on ds + column_project = {"label"}; + ds1 = ds1->Project(column_project); + EXPECT_TRUE(ds1 != nullptr); + + // Create a Zip operation on the datasets + ds = ds->Zip({ds, ds1}); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); +}