forked from OSSInnovation/mindspore
!3449 Cifar100 for C API
Merge pull request !3449 from MahdiRahmaniHanzaki/cifar100-c-api
This commit is contained in:
commit
ee67f70b73
|
@ -86,9 +86,16 @@ Dataset::Dataset() {
|
|||
// (In alphabetical order)
|
||||
|
||||
// Function to create a Cifar10Dataset.
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
|
||||
std::shared_ptr<SamplerObj> sampler) {
|
||||
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, num_samples, sampler);
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
|
||||
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
}
|
||||
|
||||
// Function to create a Cifar100Dataset.
|
||||
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
|
||||
auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
|
@ -250,28 +257,27 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() {
|
|||
return std::make_shared<RandomSamplerObj>(replacement, num_samples);
|
||||
}
|
||||
|
||||
// Helper function to validate dataset params
|
||||
bool ValidateCommonDatasetParams(std::string dataset_dir) {
|
||||
if (dataset_dir.empty()) {
|
||||
MS_LOG(ERROR) << "No dataset path is specified";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/* ####################################### Derived Dataset classes ################################# */
|
||||
|
||||
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
|
||||
// (In alphabetical order)
|
||||
|
||||
// Constructor for Cifar10Dataset
|
||||
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
|
||||
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
||||
|
||||
bool Cifar10Dataset::ValidateParams() {
|
||||
if (dataset_dir_.empty()) {
|
||||
MS_LOG(ERROR) << "No dataset path is specified.";
|
||||
return false;
|
||||
}
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "Number of samples cannot be negative";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool Cifar10Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
||||
|
||||
// Function to build CifarOp
|
||||
// Function to build CifarOp for Cifar10
|
||||
std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
@ -294,6 +300,37 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Constructor for Cifar100Dataset
|
||||
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
||||
|
||||
bool Cifar100Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
||||
|
||||
// Function to build CifarOp for Cifar100
|
||||
std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
|
||||
if (sampler_ == nullptr) {
|
||||
sampler_ = CreateDefaultSampler();
|
||||
}
|
||||
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
|
||||
bool recursive, std::set<std::string> extensions,
|
||||
std::map<std::string, int32_t> class_indexing)
|
||||
|
@ -304,14 +341,7 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std
|
|||
class_indexing_(class_indexing),
|
||||
exts_(extensions) {}
|
||||
|
||||
bool ImageFolderDataset::ValidateParams() {
|
||||
if (dataset_dir_.empty()) {
|
||||
MS_LOG(ERROR) << "No dataset path is specified.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
bool ImageFolderDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
|
@ -339,14 +369,7 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
|
|||
MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
||||
|
||||
bool MnistDataset::ValidateParams() {
|
||||
if (dataset_dir_.empty()) {
|
||||
MS_LOG(ERROR) << "No dataset path is specified.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
bool MnistDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
|
|
|
@ -42,6 +42,7 @@ class TensorOperation;
|
|||
class SamplerObj;
|
||||
// Datasets classes (in alphabetical order)
|
||||
class Cifar10Dataset;
|
||||
class Cifar100Dataset;
|
||||
class ImageFolderDataset;
|
||||
class MnistDataset;
|
||||
// Dataset Op classes (in alphabetical order)
|
||||
|
@ -57,12 +58,19 @@ class ZipDataset;
|
|||
/// \brief Function to create a Cifar10 Dataset
|
||||
/// \notes The generated dataset has two columns ['image', 'label']
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset
|
||||
/// \param[in] num_samples The number of images to be included in the dataset
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
|
||||
/// will be used to randomly iterate the entire dataset
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
|
||||
std::shared_ptr<SamplerObj> sampler);
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
|
||||
|
||||
/// \brief Function to create a Cifar100 Dataset
|
||||
/// \notes The generated dataset has two columns ['image', 'coarse_label', 'fine_label']
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
|
||||
/// will be used to randomly iterate the entire dataset
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
|
||||
std::shared_ptr<SamplerObj> sampler = nullptr);
|
||||
|
||||
/// \brief Function to create an ImageFolderDataset
|
||||
/// \notes A source dataset that reads images from a tree of directories
|
||||
|
@ -204,7 +212,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
class Cifar10Dataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler);
|
||||
Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar10Dataset() = default;
|
||||
|
@ -219,7 +227,27 @@ class Cifar10Dataset : public Dataset {
|
|||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
int32_t num_samples_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
class Cifar100Dataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar100Dataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
|
|
|
@ -84,6 +84,12 @@ TEST_F(MindDataTestPipeline, TestBatchAndRepeat) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestMnistFail1) {
|
||||
// Create a Mnist Dataset
|
||||
std::shared_ptr<Dataset> ds = Mnist("", RandomSampler(false, 10));
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) {
|
||||
// Create a Mnist Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
|
@ -274,6 +280,12 @@ TEST_F(MindDataTestPipeline, TestImageFolderBatchAndRepeat) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestImageFolderFail1) {
|
||||
// Create an ImageFolder Dataset
|
||||
std::shared_ptr<Dataset> ds = ImageFolder("", true, nullptr);
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) {
|
||||
std::shared_ptr<SamplerObj> sampl = DistributedSampler(2, 1);
|
||||
EXPECT_NE(sampl, nullptr);
|
||||
|
@ -630,17 +642,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
|
|||
|
||||
// Create a Cifar10 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, 0, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
int32_t repeat_num = 2;
|
||||
ds = ds->Repeat(repeat_num);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds
|
||||
int32_t batch_size = 2;
|
||||
ds = ds->Batch(batch_size);
|
||||
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
|
@ -652,6 +654,9 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
|
|||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
|
@ -666,6 +671,54 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCifar10DatasetFail1) {
|
||||
|
||||
// Create a Cifar10 Dataset
|
||||
std::shared_ptr<Dataset> ds = Cifar10("", RandomSampler(false, 10));
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
|
||||
|
||||
// Create a Cifar100 Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
|
||||
std::shared_ptr<Dataset> ds = Cifar100(folder_path, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("coarse_label"), row.end());
|
||||
EXPECT_NE(row.find("fine_label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 10);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCifar100DatasetFail1) {
|
||||
|
||||
// Create a Cifar100 Dataset
|
||||
std::shared_ptr<Dataset> ds = Cifar100("", RandomSampler(false, 10));
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomColorAdjust) {
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
|
@ -843,7 +896,7 @@ TEST_F(MindDataTestPipeline, TestZipSuccess) {
|
|||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, 0, RandomSampler(false, 10));
|
||||
std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Project operation on ds
|
||||
|
|
Loading…
Reference in New Issue