forked from mindspore-Ecosystem/mindspore
C++ API: Provide validate param functions
This commit is contained in:
parent
7d70fb4dc4
commit
9c8af0d1cf
|
@ -105,7 +105,7 @@ Dataset::Dataset() {
|
|||
|
||||
// Function to create a CelebADataset.
|
||||
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type,
|
||||
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
|
||||
const std::shared_ptr<SamplerObj> &sampler, bool decode,
|
||||
const std::set<std::string> &extensions) {
|
||||
auto ds = std::make_shared<CelebADataset>(dataset_dir, dataset_type, sampler, decode, extensions);
|
||||
|
||||
|
@ -114,7 +114,7 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
|
|||
}
|
||||
|
||||
// Function to create a Cifar10Dataset.
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
|
@ -122,7 +122,7 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha
|
|||
}
|
||||
|
||||
// Function to create a Cifar100Dataset.
|
||||
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
|
||||
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
|
@ -131,8 +131,8 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::s
|
|||
|
||||
// Function to create a CLUEDataset.
|
||||
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
|
||||
const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int num_shards,
|
||||
int shard_id) {
|
||||
const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id) {
|
||||
auto ds = std::make_shared<CLUEDataset>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id);
|
||||
|
||||
// Call derived class validation method.
|
||||
|
@ -150,9 +150,10 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
|
|||
}
|
||||
|
||||
// Function to create a ImageFolderDataset.
|
||||
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
|
||||
std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
|
||||
std::map<std::string, int32_t> class_indexing) {
|
||||
std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode,
|
||||
const std::shared_ptr<SamplerObj> &sampler,
|
||||
const std::set<std::string> &extensions,
|
||||
const std::map<std::string, int32_t> &class_indexing) {
|
||||
// This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
|
||||
bool recursive = false;
|
||||
|
||||
|
@ -164,7 +165,7 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
|
|||
}
|
||||
|
||||
// Function to create a MnistDataset.
|
||||
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler) {
|
||||
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
|
@ -181,7 +182,7 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
|
|||
}
|
||||
|
||||
// Function to create a TextFileDataset.
|
||||
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples,
|
||||
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
|
||||
auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id);
|
||||
|
||||
|
@ -191,9 +192,9 @@ std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files
|
|||
|
||||
// Function to create a VOCDataset.
|
||||
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode,
|
||||
const std::map<std::string, int32_t> &class_index, bool decode,
|
||||
std::shared_ptr<SamplerObj> sampler) {
|
||||
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_index, decode, sampler);
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode,
|
||||
const std::shared_ptr<SamplerObj> &sampler) {
|
||||
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_indexing, decode, sampler);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
|
@ -402,16 +403,57 @@ Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, in
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to validate dataset params
|
||||
bool ValidateCommonDatasetParams(std::string dataset_dir) {
|
||||
// Helper function to validate dataset directory parameter
|
||||
bool ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
|
||||
if (dataset_dir.empty()) {
|
||||
MS_LOG(ERROR) << "No dataset path is specified";
|
||||
MS_LOG(ERROR) << dataset_name << ": dataset_dir is not specified.";
|
||||
return false;
|
||||
}
|
||||
|
||||
Path dir(dataset_dir);
|
||||
if (!dir.IsDirectory()) {
|
||||
MS_LOG(ERROR) << dataset_name << ": dataset_dir: [" << dataset_dir << "] is an invalid directory path.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (access(dataset_dir.c_str(), R_OK) == -1) {
|
||||
MS_LOG(ERROR) << "No access to specified dataset path: " << dataset_dir;
|
||||
MS_LOG(ERROR) << dataset_name << ": No access to specified dataset path: " << dataset_dir;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Helper function to validate dataset dataset files parameter
|
||||
bool ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
|
||||
if (dataset_files.empty()) {
|
||||
MS_LOG(ERROR) << dataset_name << ": dataset_files is not specified.";
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto f : dataset_files) {
|
||||
Path dataset_file(f);
|
||||
if (!dataset_file.Exists()) {
|
||||
MS_LOG(ERROR) << dataset_name << ": dataset file: [" << f << "] is invalid or does not exist.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Helper function to validate dataset num_shards and shard_id parameters
|
||||
bool ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
|
||||
if (num_shards <= 0) {
|
||||
MS_LOG(ERROR) << dataset_name << ": Invalid num_shards: " << num_shards;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (shard_id < 0 || shard_id >= num_shards) {
|
||||
MS_LOG(ERROR) << dataset_name << ": Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -431,9 +473,7 @@ CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &
|
|||
extensions_(extensions) {}
|
||||
|
||||
bool CelebADataset::ValidateParams() {
|
||||
Path dir(dataset_dir_);
|
||||
if (!dir.IsDirectory()) {
|
||||
MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
|
||||
if (!ValidateDatasetDirParam("CelebADataset", dataset_dir_)) {
|
||||
return false;
|
||||
}
|
||||
std::set<std::string> dataset_type_list = {"all", "train", "valid", "test"};
|
||||
|
@ -471,7 +511,7 @@ std::vector<std::shared_ptr<DatasetOp>> CelebADataset::Build() {
|
|||
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
||||
|
||||
bool Cifar10Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
||||
bool Cifar10Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_); }
|
||||
|
||||
// Function to build CifarOp for Cifar10
|
||||
std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
|
||||
|
@ -500,7 +540,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
|
|||
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
||||
|
||||
bool Cifar100Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
||||
bool Cifar100Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_); }
|
||||
|
||||
// Function to build CifarOp for Cifar100
|
||||
std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
|
||||
|
@ -529,7 +569,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
|
|||
|
||||
// Constructor for CLUEDataset
|
||||
CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string task, std::string usage,
|
||||
int64_t num_samples, ShuffleMode shuffle, int num_shards, int shard_id)
|
||||
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id)
|
||||
: dataset_files_(clue_files),
|
||||
task_(task),
|
||||
usage_(usage),
|
||||
|
@ -539,19 +579,10 @@ CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string
|
|||
shard_id_(shard_id) {}
|
||||
|
||||
bool CLUEDataset::ValidateParams() {
|
||||
if (dataset_files_.empty()) {
|
||||
MS_LOG(ERROR) << "CLUEDataset: dataset_files is not specified.";
|
||||
if (!ValidateDatasetFilesParam("CLUEDataset", dataset_files_)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto f : dataset_files_) {
|
||||
Path clue_file(f);
|
||||
if (!clue_file.Exists()) {
|
||||
MS_LOG(ERROR) << "dataset file: [" << f << "] is invalid or not exist";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"};
|
||||
std::vector<std::string> usage_list = {"train", "test", "eval"};
|
||||
|
||||
|
@ -570,13 +601,7 @@ bool CLUEDataset::ValidateParams() {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (num_shards_ <= 0) {
|
||||
MS_LOG(ERROR) << "CLUEDataset: Invalid num_shards: " << num_shards_;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
|
||||
MS_LOG(ERROR) << "CLUEDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
|
||||
if (!ValidateDatasetShardParams("CLUEDataset", num_shards_, shard_id_)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -734,9 +759,7 @@ CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &anno
|
|||
: dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {}
|
||||
|
||||
bool CocoDataset::ValidateParams() {
|
||||
Path dir(dataset_dir_);
|
||||
if (!dir.IsDirectory()) {
|
||||
MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
|
||||
if (!ValidateDatasetDirParam("CocoDataset", dataset_dir_)) {
|
||||
return false;
|
||||
}
|
||||
Path annotation_file(annotation_file_);
|
||||
|
@ -829,7 +852,7 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std
|
|||
class_indexing_(class_indexing),
|
||||
exts_(extensions) {}
|
||||
|
||||
bool ImageFolderDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
||||
bool ImageFolderDataset::ValidateParams() { return ValidateDatasetDirParam("ImageFolderDataset", dataset_dir_); }
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
|
@ -857,7 +880,7 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
|
|||
MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir), sampler_(sampler) {}
|
||||
|
||||
bool MnistDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
|
||||
bool MnistDataset::ValidateParams() { return ValidateDatasetDirParam("MnistDataset", dataset_dir_); }
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
|
@ -890,31 +913,16 @@ TextFileDataset::TextFileDataset(std::vector<std::string> dataset_files, int32_t
|
|||
shard_id_(shard_id) {}
|
||||
|
||||
bool TextFileDataset::ValidateParams() {
|
||||
if (dataset_files_.empty()) {
|
||||
MS_LOG(ERROR) << "TextFileDataset: dataset_files is not specified.";
|
||||
if (!ValidateDatasetFilesParam("TextFileDataset", dataset_files_)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto file : dataset_files_) {
|
||||
std::ifstream handle(file);
|
||||
if (!handle.is_open()) {
|
||||
MS_LOG(ERROR) << "TextFileDataset: Failed to open file: " << file;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_shards_ <= 0) {
|
||||
MS_LOG(ERROR) << "TextFileDataset: Invalid num_shards: " << num_shards_;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
|
||||
MS_LOG(ERROR) << "TextFileDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
|
||||
if (!ValidateDatasetShardParams("TextfileDataset", num_shards_, shard_id_)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -960,12 +968,12 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
|
|||
|
||||
// Constructor for VOCDataset
|
||||
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
|
||||
const std::map<std::string, int32_t> &class_index, bool decode,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode,
|
||||
std::shared_ptr<SamplerObj> sampler)
|
||||
: dataset_dir_(dataset_dir),
|
||||
task_(task),
|
||||
mode_(mode),
|
||||
class_index_(class_index),
|
||||
class_index_(class_indexing),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
|
|
|
@ -75,7 +75,7 @@ class ZipDataset;
|
|||
/// will be used to randomly iterate the entire dataset
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all",
|
||||
const std::shared_ptr<SamplerObj> &sampler = nullptr, const bool &decode = false,
|
||||
const std::shared_ptr<SamplerObj> &sampler = nullptr, bool decode = false,
|
||||
const std::set<std::string> &extensions = {});
|
||||
|
||||
/// \brief Function to create a Cifar10 Dataset
|
||||
|
@ -84,7 +84,8 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
|
|||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
|
||||
/// will be used to randomly iterate the entire dataset
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
|
||||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
|
||||
const std::shared_ptr<SamplerObj> &sampler = nullptr);
|
||||
|
||||
/// \brief Function to create a Cifar100 Dataset
|
||||
/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
|
||||
|
@ -93,7 +94,7 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha
|
|||
/// will be used to randomly iterate the entire dataset
|
||||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
|
||||
std::shared_ptr<SamplerObj> sampler = nullptr);
|
||||
const std::shared_ptr<SamplerObj> &sampler = nullptr);
|
||||
|
||||
/// \brief Function to create a CLUEDataset
|
||||
/// \notes The generated dataset has a variable number of columns depending on the task and usage
|
||||
|
@ -114,7 +115,8 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
|
|||
/// \return Shared pointer to the current CLUEDataset
|
||||
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
|
||||
const std::string &usage = "train", int64_t num_samples = 0,
|
||||
ShuffleMode shuffle = ShuffleMode::kGlobal, int num_shards = 1, int shard_id = 0);
|
||||
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
|
||||
int32_t shard_id = 0);
|
||||
|
||||
/// \brief Function to create a CocoDataset
|
||||
/// \notes The generated dataset has multi-columns :
|
||||
|
@ -147,10 +149,10 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
|
|||
/// \param[in] extensions File extensions to be read
|
||||
/// \param[in] class_indexing a class name to label map
|
||||
/// \return Shared pointer to the current ImageFolderDataset
|
||||
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode = false,
|
||||
std::shared_ptr<SamplerObj> sampler = nullptr,
|
||||
std::set<std::string> extensions = {},
|
||||
std::map<std::string, int32_t> class_indexing = {});
|
||||
std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode = false,
|
||||
const std::shared_ptr<SamplerObj> &sampler = nullptr,
|
||||
const std::set<std::string> &extensions = {},
|
||||
const std::map<std::string, int32_t> &class_indexing = {});
|
||||
|
||||
/// \brief Function to create a MnistDataset
|
||||
/// \notes The generated dataset has two columns ['image', 'label']
|
||||
|
@ -158,7 +160,8 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
|
|||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
|
||||
/// A `RandomSampler` will be used to randomly iterate the entire dataset
|
||||
/// \return Shared pointer to the current MnistDataset
|
||||
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
|
||||
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir,
|
||||
const std::shared_ptr<SamplerObj> &sampler = nullptr);
|
||||
|
||||
/// \brief Function to create a ConcatDataset
|
||||
/// \notes Reload "+" operator to concat two datasets
|
||||
|
@ -183,7 +186,7 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
|
|||
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
||||
/// specified only when num_shards is also specified. (Default = 0)
|
||||
/// \return Shared pointer to the current TextFileDataset
|
||||
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples = 0,
|
||||
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples = 0,
|
||||
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
|
||||
int32_t shard_id = 0);
|
||||
|
||||
|
@ -202,8 +205,8 @@ std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files
|
|||
/// \return Shared pointer to the current Dataset
|
||||
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
|
||||
const std::string &mode = "train",
|
||||
const std::map<std::string, int32_t> &class_index = {}, bool decode = false,
|
||||
std::shared_ptr<SamplerObj> sampler = nullptr);
|
||||
const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
|
||||
const std::shared_ptr<SamplerObj> &sampler = nullptr);
|
||||
|
||||
/// \brief Function to create a ZipDataset
|
||||
/// \notes Applies zip to the dataset
|
||||
|
@ -417,7 +420,7 @@ class CLUEDataset : public Dataset {
|
|||
public:
|
||||
/// \brief Constructor
|
||||
CLUEDataset(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
|
||||
ShuffleMode shuffle, int num_shards, int shard_id);
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id);
|
||||
|
||||
/// \brief Destructor
|
||||
~CLUEDataset() = default;
|
||||
|
@ -440,8 +443,8 @@ class CLUEDataset : public Dataset {
|
|||
std::string usage_;
|
||||
int64_t num_samples_;
|
||||
ShuffleMode shuffle_;
|
||||
int num_shards_;
|
||||
int shard_id_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
};
|
||||
|
||||
class CocoDataset : public Dataset {
|
||||
|
@ -549,7 +552,7 @@ class VOCDataset : public Dataset {
|
|||
public:
|
||||
/// \brief Constructor
|
||||
VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
|
||||
const std::map<std::string, int32_t> &class_index, bool decode, std::shared_ptr<SamplerObj> sampler);
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler);
|
||||
|
||||
/// \brief Destructor
|
||||
~VOCDataset() = default;
|
||||
|
|
|
@ -111,7 +111,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetFail3) {
|
|||
|
||||
// Attempt to create a TextFile Dataset
|
||||
// with non-existent dataset_files input
|
||||
std::shared_ptr<Dataset> ds = TextFile({"notexist.txt"}, 0, ShuffleMode::kFalse);
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1, "notexist.txt"}, 0, ShuffleMode::kFalse);
|
||||
|
||||
// Expect failure: specified dataset_files does not exist
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
|
|
Loading…
Reference in New Issue