forked from mindspore-Ecosystem/mindspore
C++ API Support for Skip Dataset Op and UTs
This commit is contained in:
parent
4bbbf2dc7a
commit
8921a6099c
|
@ -27,6 +27,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/map_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/skip_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/project_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/zip_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
|
@ -173,6 +174,20 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
|
|||
return ds;
|
||||
}
|
||||
|
||||
// Function to create a SkipDataset.
|
||||
std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
|
||||
auto ds = std::make_shared<SkipDataset>(count);
|
||||
|
||||
// Call derived class validation method.
|
||||
if (!ds->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ds->children.push_back(shared_from_this());
|
||||
|
||||
return ds;
|
||||
}
|
||||
|
||||
// Function to create a ProjectDataset.
|
||||
std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
|
||||
auto ds = std::make_shared<ProjectDataset>(columns);
|
||||
|
@ -400,6 +415,28 @@ bool ShuffleDataset::ValidateParams() {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Constructor for SkipDataset
|
||||
SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {}
|
||||
|
||||
// Function to build the SkipOp
|
||||
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> SkipDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
|
||||
return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
|
||||
}
|
||||
|
||||
// Function to validate the parameters for SkipDataset
|
||||
bool SkipDataset::ValidateParams() {
|
||||
if (skip_count_ <= -1) {
|
||||
MS_LOG(ERROR) << "Skip: Invalid input, skip_count: " << skip_count_;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Constructor for Cifar10Dataset
|
||||
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
|
||||
|
|
|
@ -46,6 +46,7 @@ class BatchDataset;
|
|||
class RepeatDataset;
|
||||
class MapDataset;
|
||||
class ShuffleDataset;
|
||||
class SkipDataset;
|
||||
class Cifar10Dataset;
|
||||
class ProjectDataset;
|
||||
class ZipDataset;
|
||||
|
@ -160,6 +161,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return Shared pointer to the current ShuffleDataset
|
||||
std::shared_ptr<ShuffleDataset> Shuffle(int32_t shuffle_size);
|
||||
|
||||
/// \brief Function to create a SkipDataset
|
||||
/// \notes Skips count elements in this dataset.
|
||||
/// \param[in] count Number of elements the dataset to be skipped.
|
||||
/// \return Shared pointer to the current SkipDataset
|
||||
std::shared_ptr<SkipDataset> Skip(int32_t count);
|
||||
|
||||
/// \brief Function to create a Project Dataset
|
||||
/// \notes Applies project to the dataset
|
||||
/// \param[in] columns The name of columns to project
|
||||
|
@ -293,6 +300,26 @@ class ShuffleDataset : public Dataset {
|
|||
bool reset_every_epoch_;
|
||||
};
|
||||
|
||||
class SkipDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit SkipDataset(int32_t count);
|
||||
|
||||
/// \brief Destructor
|
||||
~SkipDataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t skip_count_;
|
||||
};
|
||||
|
||||
class MapDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
|
|
@ -2094,8 +2094,8 @@ class SkipDataset(DatasetOp):
|
|||
The result of applying Skip operator to the input Dataset.
|
||||
|
||||
Args:
|
||||
input_dataset (tuple): A tuple of datasets to be skipped.
|
||||
count (int): Number of rows the dataset should be skipped.
|
||||
input_dataset (Dataset): Input dataset to have rows skipped.
|
||||
count (int): Number of rows in the dataset to be skipped.
|
||||
"""
|
||||
|
||||
def __init__(self, input_dataset, count):
|
||||
|
|
|
@ -573,6 +573,59 @@ TEST_F(MindDataTestPipeline, TestShuffleDataset) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestSkipDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDataset.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_TRUE(ds != nullptr);
|
||||
|
||||
// Create a Skip operation on ds
|
||||
int32_t count = 3;
|
||||
ds = ds->Skip(count);
|
||||
EXPECT_TRUE(ds != nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_TRUE(iter != nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
MS_LOG(INFO) << "Number of rows: " << i;
|
||||
|
||||
// Expect 10-3=7 rows
|
||||
EXPECT_TRUE(i == 7);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestSkipDatasetError1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipDatasetError1.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_TRUE(ds != nullptr);
|
||||
|
||||
// Create a Skip operation on ds with invalid count input
|
||||
int32_t count = -1;
|
||||
ds = ds->Skip(count);
|
||||
// Expect nullptr for invalid input skip_count
|
||||
EXPECT_TRUE(ds == nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
|
||||
|
||||
// Create a Cifar10 Dataset
|
||||
|
|
|
@ -13,9 +13,12 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
@ -196,6 +199,29 @@ def test_skip_filter_2():
|
|||
assert buf == [5, 6, 7, 8, 9, 10]
|
||||
|
||||
|
||||
def test_skip_exception_1():
|
||||
data1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
try:
|
||||
data1 = data1.skip(count=-1)
|
||||
num_iter = 0
|
||||
for _ in data1.create_dict_iterator():
|
||||
num_iter += 1
|
||||
|
||||
except RuntimeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Skip count must be positive integer or 0." in str(e)
|
||||
|
||||
|
||||
def test_skip_exception_2():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
ds1 = ds1.skip(-2)
|
||||
assert "Input count is not within the required interval" in str(e.value)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tf_skip()
|
||||
test_generator_skip()
|
||||
|
@ -208,3 +234,5 @@ if __name__ == "__main__":
|
|||
test_skip_take_2()
|
||||
test_skip_filter_1()
|
||||
test_skip_filter_2()
|
||||
test_skip_exception_1()
|
||||
test_skip_exception_2()
|
||||
|
|
Loading…
Reference in New Issue