forked from mindspore-Ecosystem/mindspore
Added ZipOp
This commit is contained in:
parent
eeba046115
commit
91b4d90716
|
@ -28,6 +28,7 @@
|
||||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||||
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
|
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
|
||||||
#include "minddata/dataset/engine/datasetops/project_op.h"
|
#include "minddata/dataset/engine/datasetops/project_op.h"
|
||||||
|
#include "minddata/dataset/engine/datasetops/zip_op.h"
|
||||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||||
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
|
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||||
|
|
||||||
|
@ -53,6 +54,7 @@ std::shared_ptr<Iterator> Dataset::CreateIterator() {
|
||||||
iter = std::make_shared<Iterator>();
|
iter = std::make_shared<Iterator>();
|
||||||
Status rc = iter->BuildAndLaunchTree(shared_from_this());
|
Status rc = iter->BuildAndLaunchTree(shared_from_this());
|
||||||
if (rc.IsError()) {
|
if (rc.IsError()) {
|
||||||
|
MS_LOG(ERROR) << rc;
|
||||||
MS_LOG(ERROR) << "CreateIterator failed.";
|
MS_LOG(ERROR) << "CreateIterator failed.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -184,6 +186,21 @@ std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string>
|
||||||
return ds;
|
return ds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function to create a Zip dataset
|
||||||
|
std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
||||||
|
// Default values
|
||||||
|
auto ds = std::make_shared<ZipDataset>();
|
||||||
|
|
||||||
|
if (!ds->ValidateParams()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
for (auto dataset : datasets) {
|
||||||
|
ds->children.push_back(dataset);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ds;
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to create default RandomSampler.
|
// Helper function to create default RandomSampler.
|
||||||
std::shared_ptr<SamplerObj> CreateDefaultSampler() {
|
std::shared_ptr<SamplerObj> CreateDefaultSampler() {
|
||||||
int32_t num_samples = 0; // 0 means to sample all ids.
|
int32_t num_samples = 0; // 0 means to sample all ids.
|
||||||
|
@ -441,6 +458,19 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ProjectDataset::Build()
|
||||||
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
|
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function to build ZipOp
|
||||||
|
ZipDataset::ZipDataset() {}
|
||||||
|
|
||||||
|
bool ZipDataset::ValidateParams() { return true; }
|
||||||
|
|
||||||
|
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ZipDataset::Build() {
|
||||||
|
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||||
|
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||||
|
|
||||||
|
node_ops.push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_));
|
||||||
|
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace api
|
} // namespace api
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -52,7 +52,9 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
|
||||||
// Iterative BFS converting Dataset tree into runtime Execution tree.
|
// Iterative BFS converting Dataset tree into runtime Execution tree.
|
||||||
std::queue<std::pair<std::shared_ptr<Dataset>, std::shared_ptr<DatasetOp>>> q;
|
std::queue<std::pair<std::shared_ptr<Dataset>, std::shared_ptr<DatasetOp>>> q;
|
||||||
|
|
||||||
if (ds != nullptr) {
|
if (ds == nullptr) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Input is null pointer");
|
||||||
|
} else {
|
||||||
// Convert the current root node.
|
// Convert the current root node.
|
||||||
auto root_op = ds->Build()->front();
|
auto root_op = ds->Build()->front();
|
||||||
RETURN_UNEXPECTED_IF_NULL(root_op);
|
RETURN_UNEXPECTED_IF_NULL(root_op);
|
||||||
|
|
|
@ -48,6 +48,7 @@ class MapDataset;
|
||||||
class ShuffleDataset;
|
class ShuffleDataset;
|
||||||
class Cifar10Dataset;
|
class Cifar10Dataset;
|
||||||
class ProjectDataset;
|
class ProjectDataset;
|
||||||
|
class ZipDataset;
|
||||||
|
|
||||||
/// \brief Function to create an ImageFolderDataset
|
/// \brief Function to create an ImageFolderDataset
|
||||||
/// \notes A source dataset that reads images from a tree of directories
|
/// \notes A source dataset that reads images from a tree of directories
|
||||||
|
@ -165,6 +166,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// \return Shared pointer to the current Dataset
|
/// \return Shared pointer to the current Dataset
|
||||||
std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns);
|
std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns);
|
||||||
|
|
||||||
|
/// \brief Function to create a Zip Dataset
|
||||||
|
/// \notes Applies zip to the dataset
|
||||||
|
/// \param[in] datasets A list of shared pointer to the datasets that we want to zip
|
||||||
|
/// \return Shared pointer to the current Dataset
|
||||||
|
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<std::shared_ptr<Dataset>> children;
|
std::vector<std::shared_ptr<Dataset>> children;
|
||||||
std::shared_ptr<Dataset> parent;
|
std::shared_ptr<Dataset> parent;
|
||||||
|
@ -351,6 +358,24 @@ class ProjectDataset : public Dataset {
|
||||||
private:
|
private:
|
||||||
std::vector<std::string> columns_;
|
std::vector<std::string> columns_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ZipDataset : public Dataset {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor
|
||||||
|
ZipDataset();
|
||||||
|
|
||||||
|
/// \brief Destructor
|
||||||
|
~ZipDataset() = default;
|
||||||
|
|
||||||
|
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||||
|
/// \return shared pointer to the list of newly created DatasetOps
|
||||||
|
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
|
||||||
|
|
||||||
|
/// \brief Parameters validation
|
||||||
|
/// \return bool true if all the params are valid
|
||||||
|
bool ValidateParams() override;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace api
|
} // namespace api
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -764,8 +764,60 @@ TEST_F(MindDataTestPipeline, TestProjectMap) {
|
||||||
iter->GetNextRow(&row);
|
iter->GetNextRow(&row);
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_TRUE(i == 20);
|
EXPECT_EQ(i, 20);
|
||||||
|
|
||||||
// Manually terminate the pipeline
|
// Manually terminate the pipeline
|
||||||
iter->Stop();
|
iter->Stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestPipeline, TestZip) {
|
||||||
|
// Create an ImageFolder Dataset
|
||||||
|
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||||
|
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||||
|
EXPECT_TRUE(ds != nullptr);
|
||||||
|
|
||||||
|
// Create a Project operation on ds
|
||||||
|
std::vector<std::string> column_project = {"image"};
|
||||||
|
ds = ds->Project(column_project);
|
||||||
|
EXPECT_TRUE(ds != nullptr);
|
||||||
|
|
||||||
|
folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||||
|
std::shared_ptr<Dataset> ds1 = Cifar10(folder_path, 0, RandomSampler(false, 10));
|
||||||
|
EXPECT_TRUE(ds1 != nullptr);
|
||||||
|
|
||||||
|
// Create a Project operation on ds
|
||||||
|
column_project = {"label"};
|
||||||
|
ds1 = ds1->Project(column_project);
|
||||||
|
EXPECT_TRUE(ds1 != nullptr);
|
||||||
|
|
||||||
|
// Create a Zip operation on the datasets
|
||||||
|
ds = ds->Zip({ds, ds1});
|
||||||
|
EXPECT_TRUE(ds != nullptr);
|
||||||
|
|
||||||
|
// Create a Batch operation on ds
|
||||||
|
int32_t batch_size = 1;
|
||||||
|
ds = ds->Batch(batch_size);
|
||||||
|
EXPECT_TRUE(ds != nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_TRUE(iter != nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||||
|
iter->GetNextRow(&row);
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
i++;
|
||||||
|
auto image = row["image"];
|
||||||
|
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||||
|
iter->GetNextRow(&row);
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 10);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue