forked from OSSInnovation/mindspore
C++ API: Reorder code contents alphabetically
This commit is contained in:
parent
e07f74367d
commit
81005a3095
|
@ -17,12 +17,14 @@
|
|||
#include <fstream>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/include/samplers.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/engine/dataset_iterator.h"
|
||||
// Source dataset headers (in alphabetical order)
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
// Dataset operator headers (in alphabetical order)
|
||||
#include "minddata/dataset/engine/datasetops/batch_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/map_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||
|
@ -31,6 +33,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/project_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/zip_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/rename_op.h"
|
||||
// Sampler headers (in alphabetical order)
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
|
||||
|
@ -79,6 +82,18 @@ Dataset::Dataset() {
|
|||
connector_que_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
|
||||
// (In alphabetical order)
|
||||
|
||||
// Function to create a Cifar10Dataset.
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
|
||||
std::shared_ptr<SamplerObj> sampler) {
|
||||
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, num_samples, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
}
|
||||
|
||||
// Function to create a ImageFolderDataset.
|
||||
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
|
||||
std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
|
||||
|
@ -101,14 +116,8 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
|
|||
return ds->ValidateParams() ? ds : nullptr;
|
||||
}
|
||||
|
||||
// Function to create a Cifar10Dataset.
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
|
||||
std::shared_ptr<SamplerObj> sampler) {
|
||||
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, num_samples, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
}
|
||||
// FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
|
||||
// (In alphabetical order)
|
||||
|
||||
// Function to create a Batch dataset
|
||||
std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
|
||||
|
@ -127,14 +136,12 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
|
|||
return ds;
|
||||
}
|
||||
|
||||
// Function to create Repeat dataset.
|
||||
std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
|
||||
// Workaround for repeat == 1, do not inject repeat.
|
||||
if (count == 1) {
|
||||
return shared_from_this();
|
||||
}
|
||||
|
||||
auto ds = std::make_shared<RepeatDataset>(count);
|
||||
// Function to create a Map dataset.
|
||||
std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
|
||||
std::vector<std::string> input_columns,
|
||||
std::vector<std::string> output_columns,
|
||||
const std::vector<std::string> &project_columns) {
|
||||
auto ds = std::make_shared<MapDataset>(operations, input_columns, output_columns, project_columns);
|
||||
|
||||
if (!ds->ValidateParams()) {
|
||||
return nullptr;
|
||||
|
@ -145,12 +152,41 @@ std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
|
|||
return ds;
|
||||
}
|
||||
|
||||
// Function to create a Map dataset.
|
||||
std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
|
||||
std::vector<std::string> input_columns,
|
||||
std::vector<std::string> output_columns,
|
||||
const std::vector<std::string> &project_columns) {
|
||||
auto ds = std::make_shared<MapDataset>(operations, input_columns, output_columns, project_columns);
|
||||
// Function to create a ProjectDataset.
|
||||
std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
|
||||
auto ds = std::make_shared<ProjectDataset>(columns);
|
||||
// Call derived class validation method.
|
||||
if (!ds->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ds->children.push_back(shared_from_this());
|
||||
|
||||
return ds;
|
||||
}
|
||||
|
||||
// Function to create a RenameDataset.
|
||||
std::shared_ptr<RenameDataset> Dataset::Rename(const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns) {
|
||||
auto ds = std::make_shared<RenameDataset>(input_columns, output_columns);
|
||||
// Call derived class validation method.
|
||||
if (!ds->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ds->children.push_back(shared_from_this());
|
||||
|
||||
return ds;
|
||||
}
|
||||
|
||||
// Function to create Repeat dataset.
|
||||
std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
|
||||
// Workaround for repeat == 1, do not inject repeat.
|
||||
if (count == 1) {
|
||||
return shared_from_this();
|
||||
}
|
||||
|
||||
auto ds = std::make_shared<RepeatDataset>(count);
|
||||
|
||||
if (!ds->ValidateParams()) {
|
||||
return nullptr;
|
||||
|
@ -189,33 +225,6 @@ std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
|
|||
return ds;
|
||||
}
|
||||
|
||||
// Function to create a ProjectDataset.
|
||||
std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
|
||||
auto ds = std::make_shared<ProjectDataset>(columns);
|
||||
// Call derived class validation method.
|
||||
if (!ds->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ds->children.push_back(shared_from_this());
|
||||
|
||||
return ds;
|
||||
}
|
||||
|
||||
// Function to create a RenameDataset.
|
||||
std::shared_ptr<RenameDataset> Dataset::Rename(const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns) {
|
||||
auto ds = std::make_shared<RenameDataset>(input_columns, output_columns);
|
||||
// Call derived class validation method.
|
||||
if (!ds->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ds->children.push_back(shared_from_this());
|
||||
|
||||
return ds;
|
||||
}
|
||||
|
||||
// Function to create a Zip dataset
|
||||
std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
||||
// Default values
|
||||
|
@ -231,6 +240,9 @@ std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Datas
|
|||
return ds;
|
||||
}
|
||||
|
||||
// OTHER FUNCTIONS
|
||||
// (In alphabetical order)
|
||||
|
||||
// Helper function to create default RandomSampler.
|
||||
std::shared_ptr<SamplerObj> CreateDefaultSampler() {
|
||||
const int32_t num_samples = 0; // 0 means to sample all ids.
|
||||
|
@ -240,6 +252,48 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() {
|
|||
|
||||
/* ####################################### Derived Dataset classes ################################# */
|
||||
|
||||
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
|
||||
// (In alphabetical order)
|
||||
|
||||
// Constructor for Cifar10Dataset
|
||||
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
|
||||
|
||||
bool Cifar10Dataset::ValidateParams() {
|
||||
if (dataset_dir_.empty()) {
|
||||
MS_LOG(ERROR) << "No dataset path is specified.";
|
||||
return false;
|
||||
}
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "Number of samples cannot be negative";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Function to build CifarOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
|
||||
if (sampler_ == nullptr) {
|
||||
sampler_ = CreateDefaultSampler();
|
||||
}
|
||||
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
|
||||
bool recursive, std::set<std::string> extensions,
|
||||
std::map<std::string, int32_t> class_indexing)
|
||||
|
@ -315,6 +369,9 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
|
||||
// (In alphabetical order)
|
||||
|
||||
BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
|
||||
: batch_size_(batch_size),
|
||||
|
@ -347,24 +404,6 @@ bool BatchDataset::ValidateParams() {
|
|||
return true;
|
||||
}
|
||||
|
||||
RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
bool RepeatDataset::ValidateParams() {
|
||||
if (repeat_count_ <= 0) {
|
||||
MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns,
|
||||
std::vector<std::string> output_columns, const std::vector<std::string> &project_columns)
|
||||
: operations_(operations),
|
||||
|
@ -409,6 +448,69 @@ bool MapDataset::ValidateParams() {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Function to build ProjectOp
|
||||
ProjectDataset::ProjectDataset(const std::vector<std::string> &columns) : columns_(columns) {}
|
||||
|
||||
bool ProjectDataset::ValidateParams() {
|
||||
if (columns_.empty()) {
|
||||
MS_LOG(ERROR) << "No columns are specified.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ProjectDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<ProjectOp>(columns_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to build RenameOp
|
||||
RenameDataset::RenameDataset(const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns)
|
||||
: input_columns_(input_columns), output_columns_(output_columns) {}
|
||||
|
||||
bool RenameDataset::ValidateParams() {
|
||||
if (input_columns_.empty() || output_columns_.empty()) {
|
||||
MS_LOG(ERROR) << "input and output columns must be specified";
|
||||
return false;
|
||||
}
|
||||
if (input_columns_.size() != output_columns_.size()) {
|
||||
MS_LOG(ERROR) << "input and output columns must be the same size";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
bool RepeatDataset::ValidateParams() {
|
||||
if (repeat_count_ <= 0) {
|
||||
MS_LOG(ERROR) << "Repeat: Repeat count cannot be negative";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Constructor for ShuffleDataset
|
||||
ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch)
|
||||
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {}
|
||||
|
@ -455,64 +557,6 @@ bool SkipDataset::ValidateParams() {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Constructor for Cifar10Dataset
|
||||
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
|
||||
|
||||
bool Cifar10Dataset::ValidateParams() {
|
||||
if (dataset_dir_.empty()) {
|
||||
MS_LOG(ERROR) << "No dataset path is specified.";
|
||||
return false;
|
||||
}
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "Number of samples cannot be negative";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Function to build CifarOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// If user does not specify Sampler, create a default sampler based on the shuffle variable.
|
||||
if (sampler_ == nullptr) {
|
||||
sampler_ = CreateDefaultSampler();
|
||||
}
|
||||
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to build ProjectOp
|
||||
ProjectDataset::ProjectDataset(const std::vector<std::string> &columns) : columns_(columns) {}
|
||||
|
||||
bool ProjectDataset::ValidateParams() {
|
||||
if (columns_.empty()) {
|
||||
MS_LOG(ERROR) << "No columns are specified.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ProjectDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<ProjectOp>(columns_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to build ZipOp
|
||||
ZipDataset::ZipDataset() {}
|
||||
|
||||
|
@ -526,31 +570,6 @@ std::vector<std::shared_ptr<DatasetOp>> ZipDataset::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to build RenameOp
|
||||
RenameDataset::RenameDataset(const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns)
|
||||
: input_columns_(input_columns), output_columns_(output_columns) {}
|
||||
|
||||
bool RenameDataset::ValidateParams() {
|
||||
if (input_columns_.empty() || output_columns_.empty()) {
|
||||
MS_LOG(ERROR) << "input and output columns must be specified";
|
||||
return false;
|
||||
}
|
||||
if (input_columns_.size() != output_columns_.size()) {
|
||||
MS_LOG(ERROR) << "input and output columns must be the same size";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,18 +16,19 @@
|
|||
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/kernels/image/normalize_op.h"
|
||||
#include "minddata/dataset/kernels/image/decode_op.h"
|
||||
#include "minddata/dataset/kernels/image/resize_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
||||
|
||||
#include "minddata/dataset/kernels/image/center_crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/uniform_aug_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
||||
#include "minddata/dataset/kernels/image/cut_out_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
|
||||
#include "minddata/dataset/kernels/image/decode_op.h"
|
||||
#include "minddata/dataset/kernels/image/normalize_op.h"
|
||||
#include "minddata/dataset/kernels/image/pad_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
|
||||
#include "minddata/dataset/kernels/image/resize_op.h"
|
||||
#include "minddata/dataset/kernels/image/uniform_aug_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -38,9 +39,19 @@ TensorOperation::TensorOperation() {}
|
|||
// Transform operations for computer vision.
|
||||
namespace vision {
|
||||
|
||||
// Function to create NormalizeOperation.
|
||||
std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std) {
|
||||
auto op = std::make_shared<NormalizeOperation>(mean, std);
|
||||
// Function to create CenterCropOperation.
|
||||
std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size) {
|
||||
auto op = std::make_shared<CenterCropOperation>(size);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create CutOutOp.
|
||||
std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches) {
|
||||
auto op = std::make_shared<CutOutOperation>(length, num_patches);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
|
@ -58,73 +69,9 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb) {
|
|||
return op;
|
||||
}
|
||||
|
||||
// Function to create ResizeOperation.
|
||||
std::shared_ptr<ResizeOperation> Resize(std::vector<int32_t> size, InterpolationMode interpolation) {
|
||||
auto op = std::make_shared<ResizeOperation>(size, interpolation);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomCropOperation.
|
||||
std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding,
|
||||
bool pad_if_needed, std::vector<uint8_t> fill_value) {
|
||||
auto op = std::make_shared<RandomCropOperation>(size, padding, pad_if_needed, fill_value);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create CenterCropOperation.
|
||||
std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size) {
|
||||
auto op = std::make_shared<CenterCropOperation>(size);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create UniformAugOperation.
|
||||
std::shared_ptr<UniformAugOperation> UniformAugment(std::vector<std::shared_ptr<TensorOperation>> transforms,
|
||||
int32_t num_ops) {
|
||||
auto op = std::make_shared<UniformAugOperation>(transforms, num_ops);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomHorizontalFlipOperation.
|
||||
std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob) {
|
||||
auto op = std::make_shared<RandomHorizontalFlipOperation>(prob);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomVerticalFlipOperation.
|
||||
std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) {
|
||||
auto op = std::make_shared<RandomVerticalFlipOperation>(prob);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomRotationOperation.
|
||||
std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degrees, InterpolationMode resample,
|
||||
bool expand, std::vector<float> center,
|
||||
std::vector<uint8_t> fill_value) {
|
||||
auto op = std::make_shared<RandomRotationOperation>(degrees, resample, expand, center, fill_value);
|
||||
// Function to create NormalizeOperation.
|
||||
std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std) {
|
||||
auto op = std::make_shared<NormalizeOperation>(mean, std);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
|
@ -143,16 +90,6 @@ std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint
|
|||
return op;
|
||||
}
|
||||
|
||||
// Function to create CutOutOp.
|
||||
std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches) {
|
||||
auto op = std::make_shared<CutOutOperation>(length, num_patches);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomColorAdjustOperation.
|
||||
std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness,
|
||||
std::vector<float> contrast,
|
||||
|
@ -165,106 +102,72 @@ std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float>
|
|||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomCropOperation.
|
||||
std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding,
|
||||
bool pad_if_needed, std::vector<uint8_t> fill_value) {
|
||||
auto op = std::make_shared<RandomCropOperation>(size, padding, pad_if_needed, fill_value);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomHorizontalFlipOperation.
|
||||
std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob) {
|
||||
auto op = std::make_shared<RandomHorizontalFlipOperation>(prob);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomRotationOperation.
|
||||
std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degrees, InterpolationMode resample,
|
||||
bool expand, std::vector<float> center,
|
||||
std::vector<uint8_t> fill_value) {
|
||||
auto op = std::make_shared<RandomRotationOperation>(degrees, resample, expand, center, fill_value);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomVerticalFlipOperation.
|
||||
std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) {
|
||||
auto op = std::make_shared<RandomVerticalFlipOperation>(prob);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create ResizeOperation.
|
||||
std::shared_ptr<ResizeOperation> Resize(std::vector<int32_t> size, InterpolationMode interpolation) {
|
||||
auto op = std::make_shared<ResizeOperation>(size, interpolation);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create UniformAugOperation.
|
||||
std::shared_ptr<UniformAugOperation> UniformAugment(std::vector<std::shared_ptr<TensorOperation>> transforms,
|
||||
int32_t num_ops) {
|
||||
auto op = std::make_shared<UniformAugOperation>(transforms, num_ops);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
/* ####################################### Derived TensorOperation classes ################################# */
|
||||
|
||||
// NormalizeOperation
|
||||
NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
|
||||
|
||||
bool NormalizeOperation::ValidateParams() {
|
||||
if (mean_.size() != 3) {
|
||||
MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (std_.size() != 3) {
|
||||
MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> NormalizeOperation::Build() {
|
||||
return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
|
||||
}
|
||||
|
||||
// DecodeOperation
|
||||
DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {}
|
||||
|
||||
bool DecodeOperation::ValidateParams() { return true; }
|
||||
|
||||
std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); }
|
||||
|
||||
// ResizeOperation
|
||||
ResizeOperation::ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation)
|
||||
: size_(size), interpolation_(interpolation) {}
|
||||
|
||||
bool ResizeOperation::ValidateParams() {
|
||||
if (size_.empty() || size_.size() > 2) {
|
||||
MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> ResizeOperation::Build() {
|
||||
int32_t height = size_[0];
|
||||
int32_t width = 0;
|
||||
|
||||
// User specified the width value.
|
||||
if (size_.size() == 2) {
|
||||
width = size_[1];
|
||||
}
|
||||
|
||||
return std::make_shared<ResizeOp>(height, width, interpolation_);
|
||||
}
|
||||
|
||||
// RandomCropOperation
|
||||
RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
|
||||
std::vector<uint8_t> fill_value)
|
||||
: size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value) {}
|
||||
|
||||
bool RandomCropOperation::ValidateParams() {
|
||||
if (size_.empty() || size_.size() > 2) {
|
||||
MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (padding_.empty() || padding_.size() != 4) {
|
||||
MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: padding.size()";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (fill_value_.empty() || fill_value_.size() != 3) {
|
||||
MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: fill_value.size()";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomCropOperation::Build() {
|
||||
int32_t crop_height = size_[0];
|
||||
int32_t crop_width = 0;
|
||||
|
||||
int32_t pad_top = padding_[0];
|
||||
int32_t pad_bottom = padding_[1];
|
||||
int32_t pad_left = padding_[2];
|
||||
int32_t pad_right = padding_[3];
|
||||
|
||||
uint8_t fill_r = fill_value_[0];
|
||||
uint8_t fill_g = fill_value_[1];
|
||||
uint8_t fill_b = fill_value_[2];
|
||||
|
||||
// User has specified the crop_width value.
|
||||
if (size_.size() == 2) {
|
||||
crop_width = size_[1];
|
||||
}
|
||||
|
||||
auto tensor_op = std::make_shared<RandomCropOp>(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right,
|
||||
BorderType::kConstant, pad_if_needed_, fill_r, fill_g, fill_b);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// CenterCropOperation
|
||||
CenterCropOperation::CenterCropOperation(std::vector<int32_t> size) : size_(size) {}
|
||||
|
||||
|
@ -289,73 +192,54 @@ std::shared_ptr<TensorOp> CenterCropOperation::Build() {
|
|||
return tensor_op;
|
||||
}
|
||||
|
||||
// UniformAugOperation
|
||||
UniformAugOperation::UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> transforms, int32_t num_ops)
|
||||
: transforms_(transforms), num_ops_(num_ops) {}
|
||||
// CutOutOperation
|
||||
CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {}
|
||||
|
||||
bool UniformAugOperation::ValidateParams() { return true; }
|
||||
|
||||
std::shared_ptr<TensorOp> UniformAugOperation::Build() {
|
||||
std::vector<std::shared_ptr<TensorOp>> tensor_ops;
|
||||
(void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
|
||||
[](std::shared_ptr<TensorOperation> op) -> std::shared_ptr<TensorOp> { return op->Build(); });
|
||||
std::shared_ptr<UniformAugOp> tensor_op = std::make_shared<UniformAugOp>(tensor_ops, num_ops_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomHorizontalFlipOperation
|
||||
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {}
|
||||
|
||||
bool RandomHorizontalFlipOperation::ValidateParams() { return true; }
|
||||
|
||||
std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
|
||||
std::shared_ptr<RandomHorizontalFlipOp> tensor_op = std::make_shared<RandomHorizontalFlipOp>(probability_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomVerticalFlipOperation
|
||||
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
|
||||
|
||||
bool RandomVerticalFlipOperation::ValidateParams() { return true; }
|
||||
|
||||
std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {
|
||||
std::shared_ptr<RandomVerticalFlipOp> tensor_op = std::make_shared<RandomVerticalFlipOp>(probability_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// Function to create RandomRotationOperation.
|
||||
RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
|
||||
bool expand, std::vector<float> center,
|
||||
std::vector<uint8_t> fill_value)
|
||||
: degrees_(degrees),
|
||||
interpolation_mode_(interpolation_mode),
|
||||
expand_(expand),
|
||||
center_(center),
|
||||
fill_value_(fill_value) {}
|
||||
|
||||
bool RandomRotationOperation::ValidateParams() {
|
||||
if (degrees_.empty() || degrees_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: degrees.size()";
|
||||
bool CutOutOperation::ValidateParams() {
|
||||
if (length_ < 0) {
|
||||
MS_LOG(ERROR) << "CutOut: length cannot be negative";
|
||||
return false;
|
||||
}
|
||||
if (center_.empty() || center_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: center.size()";
|
||||
return false;
|
||||
}
|
||||
if (fill_value_.empty() || fill_value_.size() != 3) {
|
||||
MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: fill_value.size()";
|
||||
if (num_patches_ < 0) {
|
||||
MS_LOG(ERROR) << "CutOut: number of patches cannot be negative";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
|
||||
std::shared_ptr<RandomRotationOp> tensor_op =
|
||||
std::make_shared<RandomRotationOp>(degrees_[0], degrees_[1], center_[0], center_[1], interpolation_mode_, expand_,
|
||||
fill_value_[0], fill_value_[1], fill_value_[2]);
|
||||
std::shared_ptr<TensorOp> CutOutOperation::Build() {
|
||||
std::shared_ptr<CutOutOp> tensor_op = std::make_shared<CutOutOp>(length_, length_, num_patches_, false, 0, 0, 0);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// DecodeOperation
|
||||
DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {}
|
||||
|
||||
bool DecodeOperation::ValidateParams() { return true; }
|
||||
|
||||
std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); }
|
||||
|
||||
// NormalizeOperation
|
||||
NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
|
||||
|
||||
bool NormalizeOperation::ValidateParams() {
|
||||
if (mean_.size() != 3) {
|
||||
MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (std_.size() != 3) {
|
||||
MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> NormalizeOperation::Build() {
|
||||
return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
|
||||
}
|
||||
|
||||
// PadOperation
|
||||
PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
|
||||
: padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}
|
||||
|
@ -411,26 +295,6 @@ std::shared_ptr<TensorOp> PadOperation::Build() {
|
|||
return tensor_op;
|
||||
}
|
||||
|
||||
// CutOutOperation
|
||||
CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {}
|
||||
|
||||
bool CutOutOperation::ValidateParams() {
|
||||
if (length_ < 0) {
|
||||
MS_LOG(ERROR) << "CutOut: length cannot be negative";
|
||||
return false;
|
||||
}
|
||||
if (num_patches_ < 0) {
|
||||
MS_LOG(ERROR) << "CutOut: number of patches cannot be negative";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> CutOutOperation::Build() {
|
||||
std::shared_ptr<CutOutOp> tensor_op = std::make_shared<CutOutOp>(length_, length_, num_patches_, false, 0, 0, 0);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomColorAdjustOperation.
|
||||
RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
|
||||
std::vector<float> saturation, std::vector<float> hue)
|
||||
|
@ -485,6 +349,143 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
|
|||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomCropOperation
|
||||
RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
|
||||
std::vector<uint8_t> fill_value)
|
||||
: size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value) {}
|
||||
|
||||
bool RandomCropOperation::ValidateParams() {
|
||||
if (size_.empty() || size_.size() > 2) {
|
||||
MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (padding_.empty() || padding_.size() != 4) {
|
||||
MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: padding.size()";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (fill_value_.empty() || fill_value_.size() != 3) {
|
||||
MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: fill_value.size()";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomCropOperation::Build() {
|
||||
int32_t crop_height = size_[0];
|
||||
int32_t crop_width = 0;
|
||||
|
||||
int32_t pad_top = padding_[0];
|
||||
int32_t pad_bottom = padding_[1];
|
||||
int32_t pad_left = padding_[2];
|
||||
int32_t pad_right = padding_[3];
|
||||
|
||||
uint8_t fill_r = fill_value_[0];
|
||||
uint8_t fill_g = fill_value_[1];
|
||||
uint8_t fill_b = fill_value_[2];
|
||||
|
||||
// User has specified the crop_width value.
|
||||
if (size_.size() == 2) {
|
||||
crop_width = size_[1];
|
||||
}
|
||||
|
||||
auto tensor_op = std::make_shared<RandomCropOp>(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right,
|
||||
BorderType::kConstant, pad_if_needed_, fill_r, fill_g, fill_b);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomHorizontalFlipOperation
|
||||
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {}
|
||||
|
||||
bool RandomHorizontalFlipOperation::ValidateParams() { return true; }
|
||||
|
||||
std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
|
||||
std::shared_ptr<RandomHorizontalFlipOp> tensor_op = std::make_shared<RandomHorizontalFlipOp>(probability_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// Function to create RandomRotationOperation.
|
||||
RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
|
||||
bool expand, std::vector<float> center,
|
||||
std::vector<uint8_t> fill_value)
|
||||
: degrees_(degrees),
|
||||
interpolation_mode_(interpolation_mode),
|
||||
expand_(expand),
|
||||
center_(center),
|
||||
fill_value_(fill_value) {}
|
||||
|
||||
bool RandomRotationOperation::ValidateParams() {
|
||||
if (degrees_.empty() || degrees_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: degrees.size()";
|
||||
return false;
|
||||
}
|
||||
if (center_.empty() || center_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: center.size()";
|
||||
return false;
|
||||
}
|
||||
if (fill_value_.empty() || fill_value_.size() != 3) {
|
||||
MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: fill_value.size()";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
|
||||
std::shared_ptr<RandomRotationOp> tensor_op =
|
||||
std::make_shared<RandomRotationOp>(degrees_[0], degrees_[1], center_[0], center_[1], interpolation_mode_, expand_,
|
||||
fill_value_[0], fill_value_[1], fill_value_[2]);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomVerticalFlipOperation
|
||||
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
|
||||
|
||||
bool RandomVerticalFlipOperation::ValidateParams() { return true; }
|
||||
|
||||
std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {
|
||||
std::shared_ptr<RandomVerticalFlipOp> tensor_op = std::make_shared<RandomVerticalFlipOp>(probability_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// ResizeOperation
|
||||
ResizeOperation::ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation)
|
||||
: size_(size), interpolation_(interpolation) {}
|
||||
|
||||
bool ResizeOperation::ValidateParams() {
|
||||
if (size_.empty() || size_.size() > 2) {
|
||||
MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> ResizeOperation::Build() {
|
||||
int32_t height = size_[0];
|
||||
int32_t width = 0;
|
||||
|
||||
// User specified the width value.
|
||||
if (size_.size() == 2) {
|
||||
width = size_[1];
|
||||
}
|
||||
|
||||
return std::make_shared<ResizeOp>(height, width, interpolation_);
|
||||
}
|
||||
|
||||
// UniformAugOperation
|
||||
UniformAugOperation::UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> transforms, int32_t num_ops)
|
||||
: transforms_(transforms), num_ops_(num_ops) {}
|
||||
|
||||
bool UniformAugOperation::ValidateParams() { return true; }
|
||||
|
||||
std::shared_ptr<TensorOp> UniformAugOperation::Build() {
|
||||
std::vector<std::shared_ptr<TensorOp>> tensor_ops;
|
||||
(void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
|
||||
[](std::shared_ptr<TensorOperation> op) -> std::shared_ptr<TensorOp> { return op->Build(); });
|
||||
std::shared_ptr<UniformAugOp> tensor_op = std::make_shared<UniformAugOp>(tensor_ops, num_ops_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
|
|
|
@ -40,17 +40,29 @@ namespace api {
|
|||
|
||||
class TensorOperation;
|
||||
class SamplerObj;
|
||||
// Datasets classes (in alphabetical order)
|
||||
class Cifar10Dataset;
|
||||
class ImageFolderDataset;
|
||||
class MnistDataset;
|
||||
// Dataset Op classes (in alphabetical order)
|
||||
class BatchDataset;
|
||||
class RepeatDataset;
|
||||
class MapDataset;
|
||||
class ProjectDataset;
|
||||
class RenameDataset;
|
||||
class RepeatDataset;
|
||||
class ShuffleDataset;
|
||||
class SkipDataset;
|
||||
class Cifar10Dataset;
|
||||
class ProjectDataset;
|
||||
class ZipDataset;
|
||||
class RenameDataset;
|
||||
|
||||
/// \brief Function to create a Cifar10 Dataset
|
||||
/// \notes The generated dataset has two columns ['image', 'label']
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset
|
||||
/// \param[in] num_samples The number of images to be included in the dataset
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
|
||||
/// will be used to randomly iterate the entire dataset
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
|
||||
std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \brief Function to create an ImageFolderDataset
|
||||
/// \notes A source dataset that reads images from a tree of directories
|
||||
|
@ -76,16 +88,6 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
|
|||
/// \return Shared pointer to the current MnistDataset
|
||||
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
|
||||
|
||||
/// \brief Function to create a Cifar10 Dataset
|
||||
/// \notes The generated dataset has two columns ['image', 'label']
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset
|
||||
/// \param[in] num_samples The number of images to be included in the dataset
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
|
||||
/// will be used to randomly iterate the entire dataset
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
|
||||
std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \class Dataset datasets.h
|
||||
/// \brief A base class to represent a dataset in the data pipeline.
|
||||
class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||
|
@ -128,14 +130,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return Shared pointer to the current BatchDataset
|
||||
std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
|
||||
|
||||
/// \brief Function to create a RepeatDataset
|
||||
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
|
||||
/// \param[in] count Number of times the dataset should be repeated
|
||||
/// \return Shared pointer to the current Dataset
|
||||
/// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
|
||||
/// due to a limitation in the current implementation
|
||||
std::shared_ptr<Dataset> Repeat(int32_t count = -1);
|
||||
|
||||
/// \brief Function to create a MapDataset
|
||||
/// \notes Applies each operation in operations to this dataset
|
||||
/// \param[in] operations Vector of operations to be applied on the dataset. Operations are
|
||||
|
@ -156,6 +150,28 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
std::vector<std::string> output_columns = {},
|
||||
const std::vector<std::string> &project_columns = {});
|
||||
|
||||
/// \brief Function to create a Project Dataset
|
||||
/// \notes Applies project to the dataset
|
||||
/// \param[in] columns The name of columns to project
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns);
|
||||
|
||||
/// \brief Function to create a Rename Dataset
|
||||
/// \notes Renames the columns in the input dataset
|
||||
/// \param[in] input_columns List of the input columns to rename
|
||||
/// \param[in] output_columns List of the output columns
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<RenameDataset> Rename(const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns);
|
||||
|
||||
/// \brief Function to create a RepeatDataset
|
||||
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
|
||||
/// \param[in] count Number of times the dataset should be repeated
|
||||
/// \return Shared pointer to the current Dataset
|
||||
/// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset`
|
||||
/// due to a limitation in the current implementation
|
||||
std::shared_ptr<Dataset> Repeat(int32_t count = -1);
|
||||
|
||||
/// \brief Function to create a Shuffle Dataset
|
||||
/// \notes Randomly shuffles the rows of this dataset
|
||||
/// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
|
||||
|
@ -168,26 +184,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return Shared pointer to the current SkipDataset
|
||||
std::shared_ptr<SkipDataset> Skip(int32_t count);
|
||||
|
||||
/// \brief Function to create a Project Dataset
|
||||
/// \notes Applies project to the dataset
|
||||
/// \param[in] columns The name of columns to project
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns);
|
||||
|
||||
/// \brief Function to create a Zip Dataset
|
||||
/// \notes Applies zip to the dataset
|
||||
/// \param[in] datasets A list of shared pointer to the datasets that we want to zip
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets);
|
||||
|
||||
/// \brief Function to create a Rename Dataset
|
||||
/// \notes Renames the columns in the input dataset
|
||||
/// \param[in] input_columns List of the input columns to rename
|
||||
/// \param[in] output_columns List of the output columns
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<RenameDataset> Rename(const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns);
|
||||
|
||||
protected:
|
||||
std::vector<std::shared_ptr<Dataset>> children;
|
||||
std::shared_ptr<Dataset> parent;
|
||||
|
@ -199,6 +201,28 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
|
||||
/* ####################################### Derived Dataset classes ################################# */
|
||||
|
||||
class Cifar10Dataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar10Dataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
int32_t num_samples_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
/// \class ImageFolderDataset
|
||||
/// \brief A Dataset derived class to represent ImageFolder dataset
|
||||
class ImageFolderDataset : public Dataset {
|
||||
|
@ -273,6 +297,71 @@ class BatchDataset : public Dataset {
|
|||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
|
||||
};
|
||||
|
||||
class MapDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns = {},
|
||||
std::vector<std::string> output_columns = {}, const std::vector<std::string> &columns = {});
|
||||
|
||||
/// \brief Destructor
|
||||
~MapDataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<TensorOperation>> operations_;
|
||||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
std::vector<std::string> project_columns_;
|
||||
};
|
||||
|
||||
class ProjectDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ProjectDataset(const std::vector<std::string> &columns);
|
||||
|
||||
/// \brief Destructor
|
||||
~ProjectDataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> columns_;
|
||||
};
|
||||
|
||||
class RenameDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit RenameDataset(const std::vector<std::string> &input_columns, const std::vector<std::string> &output_columns);
|
||||
|
||||
/// \brief Destructor
|
||||
~RenameDataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
};
|
||||
|
||||
class RepeatDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -329,72 +418,6 @@ class SkipDataset : public Dataset {
|
|||
int32_t skip_count_;
|
||||
};
|
||||
|
||||
class MapDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns = {},
|
||||
std::vector<std::string> output_columns = {}, const std::vector<std::string> &columns = {});
|
||||
|
||||
/// \brief Destructor
|
||||
~MapDataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<TensorOperation>> operations_;
|
||||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
std::vector<std::string> project_columns_;
|
||||
};
|
||||
|
||||
class Cifar10Dataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar10Dataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
int32_t num_samples_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
class ProjectDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ProjectDataset(const std::vector<std::string> &columns);
|
||||
|
||||
/// \brief Destructor
|
||||
~ProjectDataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> columns_;
|
||||
};
|
||||
|
||||
class ZipDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -412,27 +435,6 @@ class ZipDataset : public Dataset {
|
|||
bool ValidateParams() override;
|
||||
};
|
||||
|
||||
class RenameDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit RenameDataset(const std::vector<std::string> &input_columns, const std::vector<std::string> &output_columns);
|
||||
|
||||
/// \brief Destructor
|
||||
~RenameDataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue