!1808 consistent design for num_samples

Merge pull request !1808 from Jamie/numsamples
This commit is contained in:
mindspore-ci-bot 2020-06-04 22:12:55 +08:00 committed by Gitee
commit 769ae609b4
55 changed files with 618 additions and 1155 deletions

View File

@ -856,9 +856,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -893,9 +891,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -930,9 +926,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -966,9 +960,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -1001,9 +993,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -1039,10 +1029,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
(void)builder.SetNumWorkers(ToInt(value));
} else if (key == "schema_file_path" || key == "schema_json_string") {
schema_exists = true;
} else if (key == "num_samples") {
(void)builder.SetTotalRows(ToInt(value));
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
} else if (key == "num_samples") {
// This is not sampling here. The random data op needs to know how much data to
// generate. It does not currently support sampling.
(void)builder.SetTotalRows(ToInt(value));
}
}
if (schema_exists) {
@ -1077,9 +1069,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -1121,8 +1111,6 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
(void)builder->SetDecode(ToBool(value));
} else if (key == "extensions") {
(void)builder->SetExtensions(ToStringSet(value));
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "dataset_type") {
(void)builder->SetDatasetType(ToString(value));
}
@ -1153,7 +1141,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
(void)builder->SetTotalRows(ToInt(value));
} else if (key == "num_shards") {
(void)builder->SetNumDevices(ToInt(value));
} else if (key == "shard_id") {

View File

@ -49,7 +49,6 @@
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
@ -143,17 +142,16 @@ void bindDatasetOps(py::module *m) {
});
(void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
.def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
int64_t count = 0;
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
return count;
});
(void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
.def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
.def_static("get_num_rows_and_classes", [](const std::string &path) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(
ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
return py::make_tuple(count, num_classes);
});
@ -172,22 +170,21 @@ void bindDatasetOps(py::module *m) {
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
.def_static("get_num_rows_and_classes",
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
[](const std::string &file, const py::dict &dict, const std::string &usage) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
return py::make_tuple(count, num_classes);
})
.def_static("get_class_indexing",
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
return output_class_indexing;
});
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
return output_class_indexing;
});
(void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
.def_static("get_num_rows", [](const std::string &dir) {
int64_t count = 0;
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
return count;
});
@ -206,13 +203,13 @@ void bindDatasetOps(py::module *m) {
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples) {
int64_t count = 0;
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
return count;
})
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
const std::string &task_mode, const py::dict &dict) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing));
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
return output_class_indexing;
});
}
@ -452,25 +449,19 @@ void bindSamplerOps(py::module *m) {
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
.def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
py::arg("seed"));
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
.def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
.def(py::init<int64_t, int64_t, bool>());
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
.def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
py::arg("num_samples"))
.def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
.def(py::init<int64_t, bool, bool>());
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
.def(py::init<>());
(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
.def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
.def(py::init<int64_t, int64_t>());
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
.def(py::init<int64_t, std::vector<int64_t>>());
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
*m, "MindrecordSubsetRandomSampler")
@ -487,11 +478,10 @@ void bindSamplerOps(py::module *m) {
}));
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
py::arg("replacement"));
.def(py::init<int64_t, std::vector<double>, bool>());
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
.def(py::init<py::object>(), py::arg("pySampler"));
.def(py::init<int64_t, py::object>());
}
void bindInfoObjects(py::module *m) {

View File

@ -26,7 +26,7 @@
namespace mindspore {
namespace dataset {
CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr), builder_num_samples_(0) {
CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
@ -38,7 +38,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << ".";
RETURN_IF_NOT_OK(SanityCheck());
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0;
int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
@ -47,10 +49,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
// label is like this:0 1 0 0 1......
RETURN_IF_NOT_OK(
builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
*op =
std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, builder_op_connector_size_,
builder_decode_, builder_dataset_type_, builder_extensions_, std::move(builder_schema_),
std::move(builder_sampler_), builder_num_samples_);
*op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
builder_op_connector_size_, builder_decode_, builder_dataset_type_,
builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
if (*op == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null");
}
@ -68,7 +69,7 @@ Status CelebAOp::Builder::SanityCheck() {
CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler, int64_t num_samples)
std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, queue_size),
rows_per_buffer_(rows_per_buffer),
folder_path_(dir),
@ -77,8 +78,6 @@ CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::stri
data_schema_(std::move(schema)),
sampler_(std::move(sampler)),
num_rows_in_attr_file_(0),
num_rows_exact_(0),
num_samples_(num_samples),
dataset_type_(dataset_type) {
// Set the column name map (base class field)
for (int32_t index = 0; index < data_schema_->NumColumns(); index++) {
@ -202,13 +201,6 @@ Status CelebAOp::ParseImageAttrInfo() {
RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
while (!image_infos.empty() && needMoreData) {
for (uint32_t index = 0; index < image_infos.size(); index++) {
if (num_samples_ != 0 && image_labels_vec_.size() >= num_samples_) {
MS_LOG(WARNING) << "Image number(" << image_labels_vec_.size() << " is more than"
<< " rows num eval attr file(" << num_rows_in_attr_file_ << ") or num samples(" << num_samples_
<< ").";
needMoreData = false;
break;
}
std::string image_info = image_infos[index];
std::vector<std::string> split = Split(image_info);
std::pair<std::string, std::vector<int32_t>> image_labels;
@ -239,14 +231,13 @@ Status CelebAOp::ParseImageAttrInfo() {
RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
}
num_rows_exact_ = image_labels_vec_.size();
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_exact_) ? num_rows_exact_ : num_samples_;
if (num_rows_exact_ == 0) {
num_rows_ = image_labels_vec_.size();
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
"validation first.");
}
MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_exact_ << ".";
MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << ".";
return Status::OK();
}
@ -268,28 +259,6 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) {
return split;
}
// Derived from RandomAccessOp
Status CelebAOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_samples_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_samples_;
return Status::OK();
}
Status CelebAOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_exact_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
"validation first.");
}
*num = num_rows_exact_;
return Status::OK();
}
// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
Status CelebAOp::operator()() {
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
@ -310,9 +279,8 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
RETURN_IF_NOT_OK((*data_buffer)->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
if ((*itr) >= num_rows_exact_) {
MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_exact_
<< ".";
if ((*itr) >= num_rows_) {
MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_ << ".";
continue;
}
keys.push_back(*itr);
@ -446,7 +414,7 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nNumber of rows:" << num_rows_exact_ << "\nceleba dir: " << folder_path_ << "\n\n";
out << "\nNumber of rows:" << num_rows_ << "\nceleba dir: " << folder_path_ << "\n\n";
}
}

View File

@ -108,14 +108,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
return *this;
}
// Setter method
// @param int64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method
// @param const std::string dataset_type: type to be read
// @return Builder setter method returns reference to the builder.
@ -141,7 +133,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
std::set<std::string> builder_extensions_;
std::shared_ptr<Sampler> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
int64_t builder_num_samples_;
std::string builder_dataset_type_;
};
@ -153,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
std::shared_ptr<Sampler> sampler, int64_t num_samples);
std::shared_ptr<Sampler> sampler);
~CelebAOp() override = default;
@ -163,16 +154,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumRowsInDataset(int64_t *num) const override;
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t worker_id - id of each worker
// @return Status - The error code return
@ -233,11 +214,9 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
std::shared_ptr<Sampler> sampler_;
std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_;
int64_t num_rows_in_attr_file_; // rows number specified in attr file
int64_t num_rows_exact_; // exact rows number,maybe is less than rows_num_in_attr_file_
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
WaitPost wp_;
std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
int64_t num_samples_;
std::string dataset_type_;
std::ifstream partition_file_;
};

View File

@ -35,7 +35,7 @@ constexpr uint32_t kCifarImageChannel = 3;
constexpr uint32_t kCifarBlockImageNum = 5;
constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel;
CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) {
CifarOp::Builder::Builder() : sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
@ -46,7 +46,9 @@ CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) {
Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
if (sampler_ == nullptr) {
sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0;
int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
schema_ = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar();
@ -62,7 +64,7 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
}
*ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_, num_samples_,
*ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
std::move(schema_), std::move(sampler_));
return Status::OK();
}
@ -76,16 +78,13 @@ Status CifarOp::Builder::SanityCheck() {
}
CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
int32_t queue_size, int64_t num_samples, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_works, queue_size),
cifar_type_(type),
rows_per_buffer_(rows_per_buf),
folder_path_(file_dir),
num_samples_(num_samples),
data_schema_(std::move(data_schema)),
sampler_(std::move(sampler)),
num_rows_(0),
row_cnt_(0),
buf_cnt_(0) {
// set the column name map (base class field)
@ -112,8 +111,7 @@ Status CifarOp::operator()() {
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
row_cnt_++;
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
if (row_cnt_ % rows_per_buffer_ == 0) {
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
@ -255,30 +253,6 @@ Status CifarOp::InitSampler() {
return Status::OK();
}
// Derived from RandomAccessOp
Status CifarOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
std::string err_msg = "There is no valid data matching the dataset API " + api +
".Please check file path or dataset API validation first.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
(*num) = num_samples_;
return Status::OK();
}
// Derived from RandomAccessOp
Status CifarOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
std::string err_msg = "There is no valid data matching the dataset API " + api +
".Please check file path or dataset API validation first.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
(*num) = num_rows_;
return Status::OK();
}
Status CifarOp::ReadCifarBlockDataAsync() {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(GetCifarFiles());
@ -404,7 +378,6 @@ Status CifarOp::ParseCifarData() {
}
cifar_image_label_pairs_.shrink_to_fit();
num_rows_ = cifar_image_label_pairs_.size();
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
if (num_rows_ == 0) {
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
std::string err_msg = "There is no valid data matching the dataset API " + api +
@ -432,11 +405,11 @@ Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
return Status::OK();
}
Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count) {
Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) {
// the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block()
std::shared_ptr<CifarOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetNumSamples(numSamples).SetCifarType(isCIFAR10).Build(&op));
RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op));
RETURN_IF_NOT_OK(op->GetCifarFiles());
if (op->cifar_type_ == kCifar10) {
constexpr int64_t num_cifar10_records = 10000;
@ -448,7 +421,6 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool
}
*count = *count + num_cifar10_records;
}
*count = *count < numSamples || numSamples == 0 ? *count : numSamples;
return Status::OK();
} else {
int64_t num_cifar100_records = 0;
@ -470,7 +442,7 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
*count = num_cifar100_records < numSamples || numSamples == 0 ? num_cifar100_records : numSamples;
*count = num_cifar100_records;
return Status::OK();
}
}

View File

@ -73,14 +73,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Setter method
// @param uint64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(uint64_t num_samples) {
num_samples_ = num_samples;
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
@ -121,7 +113,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
private:
std::string dir_;
int32_t num_workers_;
uint64_t num_samples_;
int32_t rows_per_buffer_;
int32_t op_connect_size_;
std::shared_ptr<Sampler> sampler_;
@ -137,7 +128,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @param uint32_t - queueSize - connector queue size
// @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size,
int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
// Destructor.
~CifarOp() = default;
@ -152,16 +143,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param uint64_t num - to return numRows
// @return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
// @param uint64_t num - to return numRows
// @return Status - The error code return
Status GetNumRowsInDataset(int64_t *num) const override;
// A print method typically used for debugging
// @param out
// @param show_all
@ -169,11 +150,10 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// Function to count the number of samples in the CIFAR dataset
// @param dir path to the CIFAR directory
// @param numSamples maximum number of samples requested
// @param isCIFAR10 true if CIFAR10 and false if CIFAR100
// @param count output arg that will hold the minimum of the actual dataset size and numSamples
// @param count output arg that will hold the actual dataset size
// @return
static Status CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count);
static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
private:
// Initialize Sampler, calls sampler->Init() within
@ -227,10 +207,8 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
CifarType cifar_type_;
int32_t rows_per_buffer_;
std::string folder_path_;
int64_t num_samples_;
std::unique_ptr<DataSchema> data_schema_;
std::shared_ptr<Sampler> sampler_;
int64_t num_rows_;
int64_t row_cnt_;
int64_t buf_cnt_;

View File

@ -26,8 +26,7 @@
namespace mindspore {
namespace dataset {
ImageFolderOp::Builder::Builder()
: builder_decode_(false), builder_recursive_(false), builder_num_samples_(0), builder_sampler_(nullptr) {
ImageFolderOp::Builder::Builder() : builder_decode_(false), builder_recursive_(false), builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
@ -37,7 +36,9 @@ ImageFolderOp::Builder::Builder()
Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data
int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar();
@ -46,9 +47,9 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
*ptr = std::make_shared<ImageFolderOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
builder_op_connector_size_, builder_num_samples_, builder_recursive_,
builder_decode_, builder_extensions_, builder_labels_to_read_,
std::move(builder_schema_), std::move(builder_sampler_));
builder_op_connector_size_, builder_recursive_, builder_decode_,
builder_extensions_, builder_labels_to_read_, std::move(builder_schema_),
std::move(builder_sampler_));
return Status::OK();
}
@ -61,20 +62,18 @@ Status ImageFolderOp::Builder::SanityCheck() {
}
ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts,
bool recursive, bool do_decode, const std::set<std::string> &exts,
const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
: ParallelOp(num_wkrs, queue_size),
rows_per_buffer_(rows_per_buffer),
folder_path_(file_dir),
num_samples_(num_samples),
recursive_(recursive),
decode_(do_decode),
extensions_(exts),
class_index_(map),
data_schema_(std::move(data_schema)),
sampler_(std::move(sampler)),
num_rows_(0),
row_cnt_(0),
buf_cnt_(0),
sampler_ind_(0),
@ -117,7 +116,6 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
}
image_label_pairs_.shrink_to_fit();
num_rows_ = image_label_pairs_.size();
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
// free memory of two queues used for pre-scan
folder_name_queue_->Reset();
image_name_queue_->Reset();
@ -138,8 +136,7 @@ Status ImageFolderOp::operator()() {
std::shared_ptr<Tensor> sample_ids = sample_row[0];
if (sample_ids->type() != DataType(DataType::DE_INT64)) RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64");
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
keys.push_back(*itr);
row_cnt_++;
if (row_cnt_ % rows_per_buffer_ == 0) {
@ -272,28 +269,6 @@ Status ImageFolderOp::InitSampler() {
return Status::OK();
}
// Derived from RandomAccessOp
Status ImageFolderOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_samples_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API "
"validation first.");
}
(*num) = num_samples_;
return Status::OK();
}
// Derived from RandomAccessOp
Status ImageFolderOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API "
"validation first.");
}
(*num) = num_rows_;
return Status::OK();
}
// Derived from RandomAccessOp
Status ImageFolderOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
@ -413,16 +388,14 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() {
return Status::OK();
}
Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t &num_samples,
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
int64_t dev_id, int64_t num_dev) {
Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
int64_t *num_classes, int64_t dev_id, int64_t num_dev) {
Path dir(path);
std::string err_msg = "";
int64_t row_cnt = 0;
err_msg += (dir.Exists() == false || dir.IsDirectory() == false) ? "unable to open dir " + path : "";
err_msg += (num_classes == nullptr || num_rows == nullptr) ? "num_class/num_rows is null\n" : "";
err_msg += (dev_id >= num_dev || num_dev <= 0) ? "invalid sharding config\n" : "";
err_msg += num_samples < 0 ? "num_samples can't be negative! set it to 0 to use all samples\n" : "";
if (err_msg.empty() == false) {
RETURN_STATUS_UNEXPECTED(err_msg);
}
@ -441,10 +414,6 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t
while (dir_itr->hasNext()) {
if (exts.empty() || exts.find(subdir.Extension()) != exts.end()) {
++row_cnt;
if (row_cnt == num_samples * num_dev) {
(*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
return Status::OK();
}
}
}
foldernames.pop();

View File

@ -107,14 +107,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Setter method
// @param int64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
@ -153,7 +145,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
bool builder_recursive_;
std::string builder_dir_;
int32_t builder_num_workers_;
int64_t builder_num_samples_;
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::set<std::string> builder_extensions_;
@ -169,10 +160,9 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// @param int32_t queue_size - connector queue size
// @param std::set<std::string> exts - set of file extensions to read, if empty, read everything under the dir
// @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts,
const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema>,
std::shared_ptr<Sampler> sampler);
ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive,
bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler);
// Destructor.
~ImageFolderOp() = default;
@ -198,16 +188,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumRowsInDataset(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
@ -221,9 +201,8 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// This function is a hack! It is to return the num_class and num_rows the old storageOp does. The result
// returned by this function may not be consistent with what image_folder_op is going to return
// user this at your own risk!
static Status CountRowsAndClasses(const std::string &path, const int64_t &num_samples,
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
int64_t dev_id = 0, int64_t num_dev = 1);
static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1);
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
@ -266,14 +245,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
int32_t rows_per_buffer_;
std::string folder_path_; // directory of image folder
int64_t num_samples_;
bool recursive_;
bool decode_;
std::set<std::string> extensions_; // extensions allowed
std::map<std::string, int32_t> class_index_;
std::unique_ptr<DataSchema> data_schema_;
std::shared_ptr<Sampler> sampler_;
int64_t num_rows_; // total number of images in ImageFolder
int64_t row_cnt_;
int64_t buf_cnt_;
int64_t sampler_ind_;

View File

@ -29,7 +29,7 @@
namespace mindspore {
namespace dataset {
ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_(0), builder_decode_(false) {
ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_decode_(false) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
@ -39,16 +39,18 @@ ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_
Status ManifestOp::Builder::Build(std::shared_ptr<ManifestOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0;
int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(
builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_IF_NOT_OK(
builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
*ptr = std::make_shared<ManifestOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_file_, builder_op_connector_size_, builder_num_samples_,
builder_decode_, builder_labels_to_read_, std::move(builder_schema_), std::move(builder_sampler_), builder_usage_);
*ptr = std::make_shared<ManifestOp>(builder_num_workers_, builder_rows_per_buffer_, builder_file_,
builder_op_connector_size_, builder_decode_, builder_labels_to_read_,
std::move(builder_schema_), std::move(builder_sampler_), builder_usage_);
return Status::OK();
}
@ -59,9 +61,9 @@ Status ManifestOp::Builder::SanityCheck() {
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size,
int64_t num_samples, bool decode, const std::map<std::string, int32_t> &class_index,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler, std::string usage)
ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler, std::string usage)
: ParallelOp(num_works, queue_size),
rows_per_buffer_(rows_per_buffer),
io_block_pushed_(0),
@ -71,8 +73,6 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
file_(file),
class_index_(class_index),
sampler_(std::move(sampler)),
num_samples_(num_samples),
num_rows_(0),
decode_(decode),
usage_(usage),
buf_cnt_(0) {
@ -101,8 +101,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
RETURN_IF_NOT_OK((*sampler_buffer)->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
keys.push_back(*itr);
row_cnt_++;
if (row_cnt_ % rows_per_buffer_ == 0) {
@ -269,28 +268,6 @@ Status ManifestOp::InitSampler() {
return Status::OK();
}
// Derived from RandomAccessOp
Status ManifestOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_samples_;
return Status::OK();
}
// Derived from RandomAccessOp
Status ManifestOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_rows_;
return Status::OK();
}
// Derived from RandomAccessOp
Status ManifestOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
if (cls_ids == nullptr || !cls_ids->empty() || image_labelname_.empty()) {
@ -408,7 +385,6 @@ Status ManifestOp::CountDatasetInfo() {
}
num_rows_ = static_cast<int64_t>(image_labelname_.size());
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
@ -417,8 +393,8 @@ Status ManifestOp::CountDatasetInfo() {
return Status::OK();
}
Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict,
const std::string &usage, int64_t *count, int64_t *numClasses) {
Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
int64_t *count, int64_t *numClasses) {
// the logic of counting the number of samples is copied from ParseManifestFile()
std::map<std::string, int32_t> map;
for (auto p : dict) {
@ -428,17 +404,15 @@ Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, c
std::shared_ptr<ManifestOp> op;
*count = 0;
RETURN_IF_NOT_OK(
Builder().SetManifestFile(file).SetNumSamples(numSamples).SetClassIndex(map).SetUsage(usage).Build(&op));
RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op));
RETURN_IF_NOT_OK(op->ParseManifestFile());
*numClasses = static_cast<int64_t>(op->label_index_.size());
*count = static_cast<int64_t>(op->image_labelname_.size());
*count = (*count < numSamples || numSamples == 0) ? *count : numSamples;
return Status::OK();
}
Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict,
const std::string &usage, std::map<std::string, int32_t> *output_class_indexing) {
Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing) {
std::map<std::string, int32_t> input_class_indexing;
for (auto p : dict) {
(void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
@ -449,12 +423,7 @@ Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples,
*output_class_indexing = input_class_indexing;
} else {
std::shared_ptr<ManifestOp> op;
RETURN_IF_NOT_OK(Builder()
.SetManifestFile(file)
.SetNumSamples(numSamples)
.SetClassIndex(input_class_indexing)
.SetUsage(usage)
.Build(&op));
RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(input_class_indexing).SetUsage(usage).Build(&op));
RETURN_IF_NOT_OK(op->ParseManifestFile());
RETURN_IF_NOT_OK(op->CountDatasetInfo());
uint32_t count = 0;

View File

@ -86,14 +86,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Setter method
// @param int64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
@ -129,7 +121,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
private:
std::shared_ptr<Sampler> builder_sampler_;
int64_t builder_num_samples_;
bool builder_decode_;
std::string builder_file_;
@ -147,8 +138,8 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @param std::string - file list of Manifest
// @param int32_t queue_size - connector queue size
// @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, int64_t num_samples,
bool decode, const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler, std::string usage);
// Destructor.
~ManifestOp() = default;
@ -164,16 +155,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get total number of Rows in dataset
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumRowsInDataset(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
@ -184,12 +165,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// @param show_all
void Print(std::ostream &out, bool show_all) const override;
static Status CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict,
const std::string &usage, int64_t *count, int64_t *numClasses);
static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count,
int64_t *numClasses);
// Get str-to-int mapping from label name to index
static Status GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict,
const std::string &usage, std::map<std::string, int32_t> *output_class_indexing);
static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing);
private:
// Initialize Sampler, calls sampler->Init() within
@ -240,8 +221,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
std::string file_; // file that store the information of images
std::map<std::string, int32_t> class_index_;
std::shared_ptr<Sampler> sampler_;
int64_t num_samples_;
int64_t num_rows_;
bool decode_;
std::string usage_;
int64_t buf_cnt_;

View File

@ -91,7 +91,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
block_reader_(block_reader),
buffers_needed_(0),
buf_cnt_(0),
num_rows_(0),
ended_worker_(0),
buffer_water_mark_(0) {
io_blk_queues_.Init(num_workers_, op_connector_queue_size);

View File

@ -31,7 +31,7 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
const int32_t kMnistImageRows = 28;
const int32_t kMnistImageCols = 28;
MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr) {
MnistOp::Builder::Builder() : builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
@ -41,7 +41,9 @@ MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr)
Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0;
int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(
@ -49,9 +51,8 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
*ptr =
std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, builder_op_connector_size_,
builder_num_samples_, std::move(builder_schema_), std::move(builder_sampler_));
*ptr = std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_));
return Status::OK();
}
@ -60,17 +61,14 @@ Status MnistOp::Builder::SanityCheck() {
std::string err_msg;
err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : "";
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : "";
err_msg += builder_num_samples_ < 0 ? "Number of samples is set to negative\n" : "";
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, queue_size),
buf_cnt_(0),
row_cnt_(0),
num_rows_(0),
num_samples_(num_samples),
folder_path_(folder_path),
rows_per_buffer_(rows_per_buffer),
sampler_(std::move(sampler)),
@ -84,8 +82,7 @@ MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folde
Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) {
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
keys->push_back(*itr);
row_cnt_++;
if (row_cnt_ % rows_per_buffer_ == 0) {
@ -219,17 +216,6 @@ Status MnistOp::InitSampler() {
return Status::OK();
}
// Derived from RandomAccessOp
Status MnistOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_samples_;
return Status::OK();
}
// Derived from RandomAccessOp
Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
@ -364,7 +350,6 @@ Status MnistOp::ParseMnistData() {
}
image_label_pairs_.shrink_to_fit();
num_rows_ = image_label_pairs_.size();
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
return Status::OK();
}
@ -414,11 +399,11 @@ Status MnistOp::LaunchThreadsAndInitOp() {
return Status::OK();
}
Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count) {
Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
// the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader()
std::shared_ptr<MnistOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetNumSamples(numSamples).Build(&op));
RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op));
RETURN_IF_NOT_OK(op->WalkAllFiles());
@ -440,19 +425,6 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64
label_reader.close();
}
*count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
return Status::OK();
}
// Derived from RandomAccessOp
Status MnistOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_rows_;
return Status::OK();
}
} // namespace dataset

View File

@ -78,14 +78,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Setter method
// @param int64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
@ -114,7 +106,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
private:
std::string builder_dir_;
int32_t builder_num_workers_;
int64_t builder_num_samples_;
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::shared_ptr<Sampler> builder_sampler_;
@ -126,11 +117,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @param int32_t rows_per_buffer - number of images (rows) in each buffer
// @param std::string folder_path - dir directory of mnist
// @param int32_t queue_size - connector queue size
// @param int64_t num_samples - number of samples to read
// @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
// @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read
MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
// Destructor.
~MnistOp() = default;
@ -146,16 +136,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get numRows
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
// @param int64_t num - to return numRows
// @return Status - The error code return
Status GetNumRowsInDataset(int64_t *num) const override;
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
@ -167,11 +147,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
void Print(std::ostream &out, bool show_all) const override;
// Function to count the number of samples in the MNIST dataset
// @param dir path to the MNSIT directory
// @param numSamples maximum number of samples requested
// @param dir path to the MNIST directory
// @param count output arg that will hold the minimum of the actual dataset size and numSamples
// @return
static Status CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count);
static Status CountTotalRows(const std::string &dir, int64_t *count);
private:
// Initialize Sampler, calls sampler->Init() within
@ -244,9 +223,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
int64_t buf_cnt_;
int64_t row_cnt_;
int64_t num_rows_; // total number of images in Mnist
WaitPost wp_;
int64_t num_samples_;
std::string folder_path_; // directory of image folder
int32_t rows_per_buffer_;
std::shared_ptr<Sampler> sampler_;

View File

@ -8,6 +8,5 @@ add_library(engine-datasetops-source-sampler OBJECT
sampler.cc
sequential_sampler.cc
subset_random_sampler.cc
subset_sampler.cc
weighted_random_sampler.cc
)

View File

@ -23,8 +23,9 @@
namespace mindspore {
namespace dataset {
DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle, uint32_t seed)
: Sampler(),
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed)
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
cnt_(0),
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
device_id_(dev_id),
@ -32,6 +33,11 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu
shuffle_(shuffle) {}
Status DistributedSampler::InitSampler() {
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
num_samples_ = num_rows_;
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,

View File

@ -27,10 +27,11 @@ namespace mindspore {
namespace dataset {
class DistributedSampler : public Sampler {
public:
// @param int64_t numDev
// @param int64_t devId
// @param num_samples
// @param int64_t num_dev
// @param int64_t dev_id
// @param bool shuffle
DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle = true,
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
uint32_t seed = std::numeric_limits<uint32_t>::max());
// default destructor

View File

@ -20,12 +20,11 @@
namespace mindspore {
namespace dataset {
PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer)
: Sampler(samples_per_buffer),
PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer),
shuffle_(shuffle),
seed_(GetSeed()),
next_id_(0),
num_pk_samples_(0),
samples_per_class_(val) {}
Status PKSampler::InitSampler() {
@ -36,22 +35,34 @@ Status PKSampler::InitSampler() {
}
}
rnd_.seed(seed_++);
num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_;
num_samples_ = num_pk_samples_;
// The special handshake gives the list of classes and id's, but it did not set the num_rows_ to
// capture the total number of possible sample ids.
// Compute that here for this case to find the total number of samples that are available to return.
// (in this case, samples per class * total classes).
num_rows_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
// The user may have chosen to sample less than the total amount.
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
num_samples_ = num_rows_;
}
samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
if (shuffle_ == true) {
std::shuffle(labels_.begin(), labels_.end(), rnd_);
} else {
std::sort(labels_.begin(), labels_.end());
}
CHECK_FAIL_RETURN_UNEXPECTED(num_pk_samples_ > 0, "num_class or K (num samples per class) is not positive");
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_class or K (num samples per class) is not positive");
return Status::OK();
}
Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (next_id_ > num_pk_samples_ || num_pk_samples_ == 0) {
if (next_id_ > num_samples_ || num_samples_ == 0) {
RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler");
} else if (next_id_ == num_pk_samples_) {
} else if (next_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
@ -60,8 +71,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sample_ids;
int64_t last_id =
(samples_per_buffer_ + next_id_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_ + next_id_;
int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_));
int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer());
while (next_id_ < last_id) {
@ -85,7 +95,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status PKSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late");
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
rnd_.seed(seed_++);

View File

@ -28,10 +28,11 @@ namespace mindspore {
namespace dataset {
class PKSampler : public Sampler { // NOT YET FINISHED
public:
// @param int64_t kVal
// @param num_samples - the number of samples to draw. value of 0 means to take the full amount
// @param int64_t val
// @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit PKSampler(int64_t val, bool shuffle = false,
explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// default destructor
@ -42,8 +43,9 @@ class PKSampler : public Sampler { // NOT YET FINISHED
// @return - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
// first handshake between StorageOp and Sampler
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from.
// @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
// @return
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
@ -58,7 +60,6 @@ class PKSampler : public Sampler { // NOT YET FINISHED
bool shuffle_;
uint32_t seed_;
int64_t next_id_;
int64_t num_pk_samples_;
int64_t samples_per_class_;
std::mt19937 rnd_;
std::vector<int64_t> labels_;

View File

@ -20,8 +20,8 @@
namespace mindspore {
namespace dataset {
PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer)
: Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (need_to_reset_) {
@ -65,6 +65,11 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
Status PythonSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0");
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
num_samples_ = num_rows_;
}
{
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {

View File

@ -26,8 +26,11 @@ namespace dataset {
class PythonSampler : public Sampler {
public:
// Constructor
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit PythonSampler(py::object py_sampler_instance,
// @param num_samples - the number of samples to draw. Value of 0 means to sample all of the
// data from the dataset.
// @param py_sampler_instance - the python instance of the sampler
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor.

View File

@ -22,12 +22,11 @@
namespace mindspore {
namespace dataset {
RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples,
RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer)
: Sampler(samples_per_buffer),
: Sampler(num_samples, samples_per_buffer),
seed_(GetSeed()),
replacement_(replacement),
user_num_samples_(num_samples),
next_id_(0),
reshuffle_each_epoch_(reshuffle_each_epoch),
dist(nullptr) {}
@ -70,27 +69,25 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status RandomSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive.");
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
num_samples_ = num_rows_;
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
rnd_.seed(seed_);
if (replacement_ == false) {
num_samples_ = std::min(num_samples_, num_rows_);
num_samples_ = std::min(num_samples_, user_num_samples_);
shuffled_ids_.reserve(num_rows_);
for (int64_t i = 0; i < num_rows_; i++) {
shuffled_ids_.push_back(i);
}
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
} else {
num_samples_ = std::min(num_samples_, user_num_samples_);
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive.");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK();
}
@ -119,7 +116,6 @@ void RandomSampler::Print(std::ostream &out, bool show_all) const {
out << "(sampler): RandomSampler\n";
if (show_all) {
out << "user_num_samples_: " << user_num_samples_ << '\n';
out << "num_samples_: " << num_samples_ << '\n';
out << "next_id_: " << next_id_ << '\n';
}

View File

@ -27,11 +27,11 @@ namespace dataset {
class RandomSampler : public Sampler {
public:
// Constructor
// @param int64_t num_samples - number samples to draw
// @param bool replacement - put he id back / or not after a sample
// @param int64_t numSamples - number samples to draw
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true,
int64_t num_samples = std::numeric_limits<int64_t>::max(),
// @param reshuffle_each_epoch - T/F to reshuffle after epoch
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor.
@ -55,7 +55,6 @@ class RandomSampler : public Sampler {
private:
uint32_t seed_;
bool replacement_;
int64_t user_num_samples_;
std::vector<int64_t> shuffled_ids_; // only used for NO REPLACEMENT
int64_t next_id_;
std::mt19937 rnd_;

View File

@ -19,8 +19,25 @@
namespace mindspore {
namespace dataset {
Sampler::Sampler(int64_t samples_per_buffer)
: DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
// The sampler base class itself does not compute it's own num_rows_ value.
// Instead, this value is computed by the derived leaf op during it's own initialization
// after it has interacted with it's storage layers.
// Here, it is just a getter method to return the value. However, it is invalid if there is
// not a value set for this count, so generate a failure if that is the case.
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed it's num rows yet.");
}
(*num) = num_rows_;
return Status::OK();
}
Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer)
: DatasetOp(0),
num_rows_(0),
num_samples_(num_samples),
samples_per_buffer_(samples_per_buffer),
col_desc_(nullptr) {}
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
std::shared_ptr<Sampler> child_sampler;
@ -36,10 +53,10 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
}
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
// If there's a child sampler, set the row count to be it's sample count
if (HasChildSampler()) {
int64_t child_num_samples = child_sampler->num_samples();
num_rows_ = child_num_samples;
num_rows_ = child_sampler->num_samples_;
} else {
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
}
@ -105,7 +122,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
}
Status Sampler::SetNumSamples(int64_t num_samples) {
CHECK_FAIL_RETURN_UNEXPECTED(num_samples > 0, "num_samples is negative or 0");
CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "num_samples is negative");
num_samples_ = num_samples;
return Status::OK();
}
@ -116,6 +133,16 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
return Status::OK();
}
// inline op doesn't have it's own consumer, it's assigned from parent
int32_t Sampler::num_consumers() const {
if (parent_.empty() || parent_[0] == nullptr) {
MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0.";
return 0;
} else {
return parent_[0]->num_consumers();
}
}
Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
if (child == nullptr) {
return Status::OK();
@ -155,5 +182,14 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
return Status::OK();
}
// inline op doesn't have it's own producers, it's assigned from child
int32_t Sampler::num_producers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(WARNING) << "Sampler with no child, num_producers is 0.";
return 0;
} else {
return child_[0]->num_producers();
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -33,23 +33,10 @@ namespace dataset {
// must inherit from if those leaf operator wish to support sampling.
class RandomAccessOp {
public:
// Sampler get numRows from StorageOp
// @param int64_t num - return number of rows, normally num of samples
// @return - The error code return
virtual Status GetNumSamples(int64_t *num_samples) const {
// CI complains num_samples not used if the following line is not added
CHECK_FAIL_RETURN_UNEXPECTED(num_samples != nullptr, "num_samples == nullptr");
RETURN_STATUS_UNEXPECTED("function GetNumSamples needs to overridden to support this sampler");
}
// Sampler get number of rows in the dataset!
// Sampler get number of rows in the dataset
// @param int64_t num - return number of rows for this dataset
// @return - The error code return
virtual Status GetNumRowsInDataset(int64_t *num_rows) const {
// CI complains num_rows not used if the following line is not added
CHECK_FAIL_RETURN_UNEXPECTED(num_rows != nullptr, "num_rows == nullptr");
RETURN_STATUS_UNEXPECTED("function GetNumRowsInDataset needs to overridden to support this sampler");
}
Status GetNumRowsInDataset(int64_t *num_rows) const;
// sampler gets label , imageIds from storageOp, this function is unique to PK
// @param std::map<int64_t, std::vector<int64_t>> * map
@ -60,12 +47,20 @@ class RandomAccessOp {
// default destructor
virtual ~RandomAccessOp() = default;
protected:
// The amount of rows in the dataset itself. This is the before-sampling value, the
// total count of rows. A sampler may choose to sample less than this amount.
int64_t num_rows_;
};
class Sampler : public DatasetOp {
public:
// Constructor
// @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
// indicates that the sampler should produce the complete set of ids.
// @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit Sampler(int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);
// default destructor
~Sampler() = default;
@ -84,33 +79,36 @@ class Sampler : public DatasetOp {
// @return - The error code return
Status Reset() override = 0;
// setter function for num_rows_
Status SetNumRowsInDataset(int64_t num_rows);
// setter function for num_samples_
Status SetNumSamples(int64_t num_samples);
int64_t num_samples() { return num_samples_; }
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
// first handshake between leaf source op and Sampler. This func will determine the amount of data
// in the dataset that we can sample from.
// @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
// @return
virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
// initialize sampler and perform checks on certain vars
virtual Status InitSampler() { return Status::OK(); }
// Not meant to be called
// setter for num samples
// @param num_samples - the number of samples to assign.
// @return status error code
Status SetNumSamples(int64_t num_samples);
// setter for num or records in the dataset
// @param num_rows - the number of records
// @return status error code
Status SetNumRowsInDataset(int64_t num_rows);
// Sampler is an inlined op and has no workers. Producers and consumers are computed.
// @return
int32_t num_workers() const final { return 0; }
// Not meant to be called
// Identify num consumers (inlined op)
// @return
int32_t num_consumers() const final { return 0; }
int32_t num_consumers() const final;
// Not meant to be called
// Identify num producers (inlined op)
// @return
int32_t num_producers() const final { return 0; }
int32_t num_producers() const final;
// Not meant to be called!
// @return - The error code return
@ -151,10 +149,11 @@ class Sampler : public DatasetOp {
// output. Otherwise, num_rows_ is the number of rows in the dataset.
int64_t num_rows_;
// Number of ids this sampler will return.
// The user may want to sample less than the full amount of data. num_samples_ reduces the number
// of id's returned as request by the user. Derived classes will choose how to sample the smaller
// amount.
int64_t num_samples_;
// The max number of ids a DataBuffer returned by this sampler will contain.
int64_t samples_per_buffer_;
std::unique_ptr<ColDescriptor> col_desc_;
std::unique_ptr<DataBuffer> child_ids_;

View File

@ -20,34 +20,42 @@
namespace mindspore {
namespace dataset {
SequentialSampler::SequentialSampler(int64_t samples_per_buffer) : Sampler(samples_per_buffer), next_id_(0) {}
SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}
Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (next_id_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("Sequential Sampler Internal Error");
} else if (next_id_ == num_samples_) {
if (id_count_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error");
} else if (id_count_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
(*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampleIds;
int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_));
// Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for
// samples per buffer though.
int64_t remaining_ids = num_samples_ - id_count_;
int64_t num_elements = std::min(remaining_ids, samples_per_buffer_);
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements));
int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
while (next_id_ < lastId) {
int64_t sampled_id = next_id_;
for (int64_t i = 0; i < num_elements; i++) {
int64_t sampled_id = current_id_;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*idPtr = sampled_id;
next_id_++;
current_id_++; // Move the current id to the next one in the sequence
idPtr++;
}
id_count_ += num_elements; // Count the packed ids towards our overall sample count
TensorRow row(1, sampleIds);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
}
@ -55,19 +63,24 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
}
Status SequentialSampler::InitSampler() {
num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set
if (HasChildSampler()) {
num_samples_ = std::min(num_samples_, num_rows_);
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0, "num_samples < 0\n");
// Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample
// the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data.
int64_t available_row_count = num_rows_ - start_index_;
if (num_samples_ == 0 || num_samples_ > available_row_count) {
num_samples_ = available_row_count;
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK();
}
Status SequentialSampler::Reset() {
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
next_id_ = 0;
CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late");
current_id_ = start_index_;
id_count_ = 0;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());

View File

@ -26,8 +26,12 @@ namespace dataset {
class SequentialSampler : public Sampler {
public:
// Constructor
// @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the
// full amount of ids from the dataset
// @param start_index - The starting index value
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit SequentialSampler(int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
explicit SequentialSampler(int64_t num_samples, int64_t start_index,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor.
~SequentialSampler() = default;
@ -48,7 +52,9 @@ class SequentialSampler : public Sampler {
void Print(std::ostream &out, bool show_all) const override;
private:
int64_t next_id_;
int64_t current_id_; // The id sequencer. Each new id increments from this
int64_t start_index_; // The starting id. current_id_ begins from here.
int64_t id_count_; // An internal counter that tracks how many ids have been produced
};
} // namespace dataset
} // namespace mindspore

View File

@ -27,22 +27,28 @@
namespace mindspore {
namespace dataset {
// Constructor.
SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t samples_per_buffer)
: Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
int64_t samples_per_buffer)
: Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
// Initialized this Sampler.
Status SubsetRandomSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
num_samples_ = indices_.size();
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// In this case, the id's are provided by the user. Cap the num_samples on the number of id's given.
if (num_samples_ == 0 || num_samples_ > static_cast<int64_t>(indices_.size())) {
num_samples_ = static_cast<int64_t>(indices_.size());
}
// Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed());
if (static_cast<size_t>(samples_per_buffer_) > indices_.size()) {
samples_per_buffer_ = static_cast<int64_t>(indices_.size());
if (samples_per_buffer_ > num_samples_) {
samples_per_buffer_ = num_samples_;
}
// num_samples_ could be smaller than the total number of input id's.
// We will shuffle the full set of id's, but only select the first num_samples_ of them later.
std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
return Status::OK();
@ -68,7 +74,7 @@ Status SubsetRandomSampler::Reset() {
// Get the sample ids.
Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
// All samples have been drawn
if (sample_id_ == indices_.size()) {
if (sample_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
@ -80,8 +86,8 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
int64_t last_id = sample_id_ + samples_per_buffer_;
// Handling the return all samples at once, and when last draw is not a full batch.
if (static_cast<size_t>(last_id) > indices_.size()) {
last_id = indices_.size();
if (last_id > num_samples_) {
last_id = num_samples_;
}
// Allocate tensor

View File

@ -28,10 +28,11 @@ namespace dataset {
class SubsetRandomSampler : public Sampler {
public:
// Constructor.
// @param num_samples The number of samples to draw. 0 for the full amount.
// @param indices List of indices from where we will randomly draw samples.
// @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
// When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
explicit SubsetRandomSampler(const std::vector<int64_t> &indices,
explicit SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor.

View File

@ -1,85 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include <memory>
#include <string>
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
namespace mindspore {
namespace dataset {
// Constructor.
SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
: Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}
Status SubsetSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");
num_samples_ = subset_size_;
return Status::OK();
}
Status SubsetSampler::Reset() {
current_id_ = 0;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->Reset());
}
return Status::OK();
}
Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
if (current_id_ > subset_size_) {
RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error");
} else if (current_id_ == subset_size_) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
}
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
std::shared_ptr<Tensor> sampled_ids;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_));
int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer());
while (current_id_ < subset_size_) {
int64_t sampled_id = start_index_ + current_id_;
if (HasChildSampler()) {
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
}
*(sampled_ids_start_addr + current_id_) = sampled_id;
current_id_++;
}
TensorRow sampled_ids_row(1, sampled_ids);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, sampled_ids_row));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -1,58 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
#include <memory>
#include <vector>
#include "dataset/engine/datasetops/source/sampler/sampler.h"
namespace mindspore {
namespace dataset {
class SubsetSampler : public Sampler {
public:
// Constructor.
// @param start_index The index we start sampling from.
explicit SubsetSampler(int64_t start_index, int64_t subset_size);
// Destructor.
~SubsetSampler() = default;
// Initialize the sampler.
// @return Status
Status InitSampler() override;
// Reset the internal variable to the initial state and reshuffle the indices.
// @return Status
Status Reset() override;
// Get the sample ids.
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
// @note the sample ids (int64_t) will be placed in one Tensor.
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
private:
int64_t start_index_;
int64_t subset_size_;
int64_t current_id_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_

View File

@ -27,25 +27,28 @@
namespace mindspore {
namespace dataset {
// Constructor.
WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement,
WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
int64_t samples_per_buffer)
: Sampler(samples_per_buffer),
: Sampler(num_samples, samples_per_buffer),
weights_(weights),
replacement_(replacement),
sample_id_(0),
buffer_id_(0),
user_num_samples_(num_samples) {}
buffer_id_(0) {}
// Initialized this Sampler.
Status WeightedRandomSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive");
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
if (num_samples_ == 0 || num_samples_ > num_rows_) {
num_samples_ = num_rows_;
}
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive");
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
num_samples_ = user_num_samples_;
// Initialize random generator with seed from config manager
rand_gen_.seed(GetSeed());
samples_per_buffer_ = (samples_per_buffer_ > user_num_samples_) ? user_num_samples_ : samples_per_buffer_;
samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
if (!replacement_) {
exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
@ -67,8 +70,8 @@ void WeightedRandomSampler::InitOnePassSampling() {
}
// Partial sort the first `numSamples` elements.
std::partial_sort(val_idx.begin(), val_idx.begin() + user_num_samples_, val_idx.end());
for (int64_t i = 0; i < user_num_samples_; i++) {
std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end());
for (int64_t i = 0; i < num_samples_; i++) {
onepass_ids_.push_back(val_idx[i].second);
}
}
@ -98,11 +101,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
}
if (!replacement_ && (weights_.size() < static_cast<size_t>(user_num_samples_))) {
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples");
}
if (sample_id_ == user_num_samples_) {
if (sample_id_ == num_samples_) {
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
} else {
if (HasChildSampler()) {
@ -114,8 +117,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
int64_t last_id = sample_id_ + samples_per_buffer_;
// Handling the return all samples at once, and when last draw is not a full batch.
if (last_id > user_num_samples_) {
last_id = user_num_samples_;
if (last_id > num_samples_) {
last_id = num_samples_;
}
// Allocate tensor.

View File

@ -29,12 +29,12 @@ namespace dataset {
class WeightedRandomSampler : public Sampler {
public:
// Constructor.
// @param weights A lift of sample weights.
// @param num_samples Number of samples to be drawn.
// @param weights A lift of sample weights.
// @param replacement Determine if samples are drawn with/without replacement.
// @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
// When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement = true,
WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
// Destructor.
@ -69,9 +69,6 @@ class WeightedRandomSampler : public Sampler {
// Random engine and device
std::mt19937 rand_gen_;
// num_samples from user
int64_t user_num_samples_;
// Discrete distribution for generating weighted random numbers with replacement.
std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_;

View File

@ -33,7 +33,7 @@
namespace mindspore {
namespace dataset {
TextFileOp::Builder::Builder()
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
: builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@ -62,7 +62,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
builder_num_devices_, builder_device_id_);
RETURN_IF_NOT_OK(text_file_op->Init());
@ -71,14 +71,14 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
return Status::OK();
}
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
rows_per_buffer_(rows_per_buffer),
num_samples_(num_samples),
total_rows_(total_rows),
text_files_list_(std::move(text_files_list)),
shuffle_files_(shuffle_files),
data_schema_(std::move(schema)),
@ -104,9 +104,9 @@ void TextFileOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nText files list:\n";
out << "\nRows per buffer: " << rows_per_buffer_ << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_
<< "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no")
<< "\nText files list:\n";
for (int i = 0; i < text_files_list_.size(); ++i) {
out << " " << text_files_list_[i];
}
@ -404,9 +404,9 @@ Status TextFileOp::operator()() {
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
if (buffer->eoe()) {
workers_done++;
} else if (num_samples_ == 0 || rows_read < num_samples_) {
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
} else if (total_rows_ == 0 || rows_read < total_rows_) {
if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) {
int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read);
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
}
rows_read += buffer->NumRows();

View File

@ -107,8 +107,8 @@ class TextFileOp : public ParallelOp {
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
Builder &SetTotalRows(int64_t total_rows) {
builder_total_rows_ = total_rows;
return *this;
}
@ -118,7 +118,7 @@ class TextFileOp : public ParallelOp {
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int64_t builder_rows_per_buffer_;
int64_t builder_num_samples_;
int64_t builder_total_rows_;
int32_t builder_worker_connector_size_;
std::vector<std::string> builder_text_files_list_;
bool builder_shuffle_files_;
@ -136,7 +136,7 @@ class TextFileOp : public ParallelOp {
// @param columns_to_load - the names of the columns to load data from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id);
@ -246,7 +246,7 @@ class TextFileOp : public ParallelOp {
int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;
int64_t num_samples_;
int64_t total_rows_;
std::vector<std::string> text_files_list_;
bool shuffle_files_;
std::unique_ptr<DataSchema> data_schema_;

View File

@ -44,7 +44,7 @@ const char kSegmentationExtension[] = ".png";
const char kAnnotationExtension[] = ".xml";
const char kImageSetsExtension[] = ".txt";
VOCOp::Builder::Builder() : builder_decode_(false), builder_num_samples_(0), builder_sampler_(nullptr) {
VOCOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
@ -55,7 +55,9 @@ VOCOp::Builder::Builder() : builder_decode_(false), builder_num_samples_(0), bui
Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
if (builder_sampler_ == nullptr) {
builder_sampler_ = std::make_shared<SequentialSampler>();
int64_t num_samples = 0;
int64_t start_index = 0;
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
}
builder_schema_ = std::make_unique<DataSchema>();
if (builder_task_type_ == TaskType::Segmentation) {
@ -71,8 +73,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
}
*ptr = std::make_shared<VOCOp>(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_,
builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_,
builder_num_samples_, builder_decode_, std::move(builder_schema_),
std::move(builder_sampler_));
builder_decode_, std::move(builder_schema_), std::move(builder_sampler_));
return Status::OK();
}
@ -81,20 +82,16 @@ Status VOCOp::Builder::SanityCheck() {
std::string err_msg;
err_msg += dir.IsDirectory() == false ? "VOC path is invalid or not set\n" : "";
err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : "";
err_msg += builder_num_samples_ < 0 ? "num_samples is negative\n" : "";
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}
VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, queue_size),
decode_(decode),
row_cnt_(0),
buf_cnt_(0),
num_rows_(0),
num_samples_(num_samples),
task_type_(task_type),
task_mode_(task_mode),
folder_path_(folder_path),
@ -112,7 +109,6 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std:
Status VOCOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) {
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
if ((*itr) > num_rows_) continue;
if (row_cnt_ == num_samples_) break;
keys->push_back(*itr);
row_cnt_++;
if (row_cnt_ % rows_per_buffer_ == 0) {
@ -187,16 +183,6 @@ Status VOCOp::Reset() {
return Status::OK();
}
Status VOCOp::GetNumSamples(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API VOCDataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_samples_;
return Status::OK();
}
Status VOCOp::LoadTensorRow(const std::string &image_id, TensorRow *trow) {
if (task_type_ == TaskType::Segmentation) {
std::shared_ptr<Tensor> image, target;
@ -280,7 +266,6 @@ Status VOCOp::ParseImageIds() {
in_file.close();
image_ids_.shrink_to_fit();
num_rows_ = image_ids_.size();
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
return Status::OK();
}
@ -305,7 +290,6 @@ Status VOCOp::ParseAnnotationIds() {
}
num_rows_ = image_ids_.size();
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
return Status::OK();
}
@ -432,19 +416,8 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescripto
return Status::OK();
}
// Derived from RandomAccessOp
Status VOCOp::GetNumRowsInDataset(int64_t *num) const {
if (num == nullptr || num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API VOCDataset.Please check file path or dataset API "
"validation first.");
}
(*num) = num_rows_;
return Status::OK();
}
Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples, int64_t *count) {
const py::dict &dict, int64_t *count) {
if (task_type == "Detection") {
std::map<std::string, int32_t> input_class_indexing;
for (auto p : dict) {
@ -464,14 +437,12 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ
RETURN_IF_NOT_OK(op->ParseImageIds());
*count = static_cast<int64_t>(op->image_ids_.size());
}
*count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
return Status::OK();
}
Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples,
std::map<std::string, int32_t> *output_class_indexing) {
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing) {
std::map<std::string, int32_t> input_class_indexing;
for (auto p : dict) {
(void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),

View File

@ -116,14 +116,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Setter method.
// @param int64_t num_samples
// @return Builder setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method.
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
@ -157,7 +149,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int32_t builder_rows_per_buffer_;
int64_t builder_num_samples_;
std::shared_ptr<Sampler> builder_sampler_;
std::unique_ptr<DataSchema> builder_schema_;
std::map<std::string, int32_t> builder_labels_to_read_;
@ -171,14 +162,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param int32_t num_workers - number of workers reading images in parallel
// @param int32_t rows_per_buffer - number of images (rows) in each buffer
// @param int32_t queue_size - connector queue size
// @param int64_t num_samples - number of samples to read
// @param bool decode - whether to decode images
// @param std::unique_ptr<DataSchema> data_schema - the schema of the VOC dataset
// @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read
VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler);
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
// Destructor
~VOCOp() = default;
@ -194,15 +183,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @return Status - The error code return
Status operator()() override;
// Method derived from RandomAccessOp, enable Sampler to get numRows
// @param uint64_t num - to return numRows
// return Status - The error code return
Status GetNumSamples(int64_t *num) const override;
// Method derived from RandomAccessOp, enable Sampler to get total number of rows in dataset
// @param uint64_t num - to return numRows
Status GetNumRowsInDataset(int64_t *num) const override;
// A print method typically used for debugging
// @param out
// @param show_all
@ -212,10 +192,9 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param const std::string &task_type - task type of reading voc job
// @param const std::string &task_mode - task mode of reading voc job
// @param const py::dict &dict - input dict of class index
// @param int64_t numSamples - samples number of VOCDataset
// @param int64_t *count - output rows number of VOCDataset
static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples, int64_t *count);
const py::dict &dict, int64_t *count);
// @param const std::string &dir - VOC dir path
// @param const std::string &task_type - task type of reading voc job
@ -224,8 +203,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// @param int64_t numSamples - samples number of VOCDataset
// @param std::map<std::string, int32_t> *output_class_indexing - output class index of VOCDataset
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples,
std::map<std::string, int32_t> *output_class_indexing);
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing);
private:
// Initialize Sampler, calls sampler->Init() within
@ -283,8 +261,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
bool decode_;
int64_t row_cnt_;
int64_t buf_cnt_;
int64_t num_rows_;
int64_t num_samples_;
std::string folder_path_;
TaskType task_type_;
std::string task_mode_;

View File

@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, SubsetSampler, Sampler
WeightedRandomSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show
from .engine.graphdata import GraphData

View File

@ -1261,8 +1261,8 @@ class MappableDataset(SourceDataset):
def _get_sampler_dataset_size(self):
if self.sampler is not None:
if hasattr(self.sampler, 'get_dataset_size'):
return self.sampler.get_dataset_size()
if hasattr(self.sampler, 'get_num_samples'):
return self.sampler.get_num_samples()
if hasattr(self.sampler, '__len__'):
return len(self.sampler)
@ -1355,7 +1355,7 @@ class MappableDataset(SourceDataset):
random_sampler.reshuffle_each_epoch = False
ds.add_sampler(random_sampler)
subset_sampler = samplers.SubsetSampler(current_split_start_index, size)
subset_sampler = samplers.SequentialSampler(current_split_start_index, size)
ds.add_sampler(subset_sampler)
# add sequential sampler, so that if user calls use_sampler, we will
@ -2226,31 +2226,45 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
num_shards (int): Number of shard for sharding.
shard_id (int): Shard ID.
"""
if input_sampler is not None:
# If the user provided a sampler, then it doesn't matter what the other args are because
# we are being asked specifically to use the given sampler.
# That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
# be None. Consider this example:
# sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
# data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
# In this case, the user has given different sample-related arguments that contradict each other.
# To prevent this, only allow the user to manually specify the sampler if those arguments are all None
if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
samplers.RandomSampler, samplers.SubsetRandomSampler,
samplers.WeightedRandomSampler, samplers.Sampler)) and
(num_shards is not None or shard_id is not None or shuffle is not None or num_samples is not None)):
raise ValueError(
'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle))
return input_sampler
if shuffle is None:
if input_sampler is not None:
# If shuffle is not specified, user provided sampler, use user's sampler
return input_sampler
if num_shards is not None:
# If shuffle is not specified, sharding enabled, use distributed random sampler
shuffle = True
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle is not specified, sharding disabled, use random sampler
if num_samples is not None:
return samplers.RandomSampler(replacement=True, num_samples=num_samples)
return samplers.RandomSampler()
return samplers.RandomSampler(num_samples=num_samples)
if shuffle is True:
if num_shards is not None:
# If shuffle enabled, sharding enabled, use distributed random sampler
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle enabled, sharding disabled, use random sampler
if num_samples is not None:
return samplers.RandomSampler(replacement=True, num_samples=num_samples)
return samplers.RandomSampler()
return samplers.RandomSampler(num_samples=num_samples)
if num_shards is not None:
# If shuffle disabled, sharding enabled, use distributed sequential sampler
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle disabled, sharding disabled, use sequential sampler
return samplers.SequentialSampler()
return samplers.SequentialSampler(num_samples=num_samples)
class ImageFolderDatasetV2(MappableDataset):
@ -2370,11 +2384,7 @@ class ImageFolderDatasetV2(MappableDataset):
Return:
Number, number of batches.
"""
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
@ -2390,11 +2400,7 @@ class ImageFolderDatasetV2(MappableDataset):
Return:
Number, number of classes.
"""
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1]
return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[1]
def is_shuffled(self):
if self.shuffle_level is None:
@ -2503,12 +2509,7 @@ class MnistDataset(MappableDataset):
Return:
Number, number of batches.
"""
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
num_rows = MnistOp.get_num_rows(self.dataset_dir)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
@ -2956,11 +2957,8 @@ class GeneratorDataset(MappableDataset):
if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
samplers.RandomSampler, samplers.SubsetRandomSampler,
samplers.WeightedRandomSampler, samplers.Sampler)):
if num_samples is None:
num_samples = len(source)
sampler_instance = self.sampler.create()
sampler_instance.set_num_rows(len(source))
sampler_instance.set_num_samples(num_samples)
sampler_instance.initialize()
if num_parallel_workers > 1:
self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers))
@ -3304,17 +3302,12 @@ class ManifestDataset(MappableDataset):
Return:
Number, number of batches.
"""
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
if self.class_indexing is None:
class_indexing = dict()
else:
class_indexing = self.class_indexing
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
@ -3330,17 +3323,12 @@ class ManifestDataset(MappableDataset):
Return:
Number, number of classes.
"""
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
if self.class_indexing is None:
class_indexing = dict()
else:
class_indexing = self.class_indexing
return ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[1]
return ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[1]
def get_class_indexing(self):
"""
@ -3349,17 +3337,12 @@ class ManifestDataset(MappableDataset):
Return:
Dict, A str-to-int mapping from label name to index.
"""
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
if self.class_indexing is None:
class_indexing = dict()
else:
class_indexing = self.class_indexing
return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage)
return ManifestOp.get_class_indexing(self.dataset_file, class_indexing, self.usage)
def is_shuffled(self):
if self.shuffle_level is None:
@ -3473,12 +3456,8 @@ class Cifar10Dataset(MappableDataset):
Return:
Number, number of batches.
"""
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
@ -3597,12 +3576,8 @@ class Cifar100Dataset(MappableDataset):
Return:
Number, number of batches.
"""
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
@ -3631,7 +3606,7 @@ class RandomDataset(SourceDataset):
Args:
num_samples (int): number of samples to generate.
schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
If the schema is not provided, the meta data from the TFRecord file is considered the schema.
If the schema is not provided, the random dataset generates a random schema.
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
@ -3644,9 +3619,12 @@ class RandomDataset(SourceDataset):
schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
self.schema = schema
self.columns_list = columns_list
self.num_samples = num_samples
if schema_obj is not None and num_samples is None:
self.num_samples = schema_obj.num_rows
elif num_samples is None:
self.num_samples = 0
else:
self.num_samples = num_samples
def get_args(self):
args = super().get_args()
@ -4015,17 +3993,12 @@ class VOCDataset(MappableDataset):
if self.task != "Detection":
raise NotImplementedError()
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
if self.class_indexing is None:
class_indexing = dict()
else:
class_indexing = self.class_indexing
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing)
def is_shuffled(self):
if self.shuffle_level is None:
@ -4205,9 +4178,11 @@ class TextFileDataset(SourceDataset):
if self._dataset_size is None:
num_rows = TextFileOp.get_num_rows(self.dataset_files)
num_rows = get_num_rows(num_rows, self.num_shards)
if self.num_samples is None:
return num_rows
return min(self.num_samples, num_rows)
# If the user gave a num samples in the dataset, then the sampler will limit the rows returned
# to that amount. Account for that here in the row count
if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples:
num_rows = self.num_samples
return num_rows
return self._dataset_size
def is_shuffled(self):

View File

@ -22,7 +22,6 @@ User can also define custom sampler by extending from Sampler class.
import numpy as np
import mindspore._c_dataengine as cde
class Sampler:
"""
Base class for user defined sampler.
@ -44,10 +43,10 @@ class Sampler:
>>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler())
"""
def __init__(self):
def __init__(self, num_samples=None):
self.dataset_size = 0
self.num_samples = 0
self.child_sampler = None
self.num_samples = num_samples
def __iter__(self):
"""
@ -84,7 +83,8 @@ class Sampler:
# Instance fetcher
# Do not override this method!
def create(self):
c_sampler = cde.PythonSampler(self)
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.PythonSampler(num_samples, self)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -114,7 +114,7 @@ class Sampler:
return self.child_sampler.is_sharded()
def get_dataset_size(self):
def get_num_samples(self):
return self._get_indices().size
@ -124,8 +124,9 @@ class BuiltinSampler:
User should not extend this class.
"""
def __init__(self):
def __init__(self, num_samples=None):
self.child_sampler = None
self.num_samples = num_samples
def create(self):
pass
@ -149,11 +150,37 @@ class BuiltinSampler:
def is_sharded(self):
raise NotImplementedError("Sampler must implement is_sharded.")
def get_dataset_size(self):
if self.child_sampler is not None:
return self.child_sampler.get_dataset_size()
def get_num_samples(self):
"""
All samplers can contain a numeric num_samples value (or it could be set to None).
Child sampler can exist or be None.
if child sampler exists, then the child sampler count can be a numeric value or None.
Given these conditions, we need to output what the sampler count is for this sampler.
The following table shows the possible results from calling this function.
return None
child sampler num_samples child_samples result
------------- ----------- ------------- --------
T x y min(x, y)
T x None x
T None y y
T None None None
None x n/a x
None None n/a None
Returns:
int, The number of samples, or None
"""
if self.child_sampler is not None:
child_samples = self.child_sampler.get_num_samples()
if self.num_samples is not None:
if child_samples is not None:
return min(self.num_samples, child_samples)
return self.num_samples
return child_samples
return self.num_samples
class DistributedSampler(BuiltinSampler):
@ -164,6 +191,7 @@ class DistributedSampler(BuiltinSampler):
num_shards (int): Number of shards to divide the dataset into.
shard_id (int): Shard ID of the current shard within num_shards.
shuffle (bool, optional): If true, the indices are shuffled (default=True).
num_samples (int, optional): The number of samples to draw (default=None, all elements).
Examples:
>>> import mindspore.dataset as ds
@ -180,7 +208,7 @@ class DistributedSampler(BuiltinSampler):
ValueError: If shuffle is not a boolean value.
"""
def __init__(self, num_shards, shard_id, shuffle=True):
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None):
if num_shards <= 0:
raise ValueError("num_shards should be a positive integer value, but got num_shards={}".format(num_shards))
@ -194,12 +222,13 @@ class DistributedSampler(BuiltinSampler):
self.shard_id = shard_id
self.shuffle = shuffle
self.seed = 0
super().__init__()
super().__init__(num_samples)
def create(self):
num_samples = self.num_samples if self.num_samples is not None else 0
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
self.seed += 1
c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id, self.shuffle, self.seed)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -226,6 +255,7 @@ class PKSampler(BuiltinSampler):
num_class (int, optional): Number of classes to sample (default=None, all classes).
shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset.
num_samples (int, optional): The number of samples to draw (default=None, all elements).
Examples:
>>> import mindspore.dataset as ds
@ -242,7 +272,7 @@ class PKSampler(BuiltinSampler):
ValueError: If shuffle is not boolean.
"""
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'):
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
if num_val <= 0:
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
@ -255,10 +285,11 @@ class PKSampler(BuiltinSampler):
self.num_val = num_val
self.shuffle = shuffle
self.class_column = class_column # work for minddataset
super().__init__()
super().__init__(num_samples)
def create(self):
c_sampler = cde.PKSampler(self.num_val, self.shuffle)
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -309,23 +340,18 @@ class RandomSampler(BuiltinSampler):
raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
if num_samples is not None:
if num_samples <= 0:
if num_samples < 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
self.deterministic = False
self.replacement = replacement
self.num_samples = num_samples
self.reshuffle_each_epoch = True
super().__init__()
super().__init__(num_samples)
def create(self):
c_sampler = None
if self.num_samples is None:
c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch)
else:
c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples)
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -339,14 +365,15 @@ class RandomSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.num_samples
class SequentialSampler(BuiltinSampler):
"""
Samples the dataset elements sequentially, same as not having a sampler.
Args:
start_index (int, optional): Index to start sampling at. (dafault=None starts at first id)
num_samples (int, optional): Number of elements to sample (default=None, all elements).
Examples:
>>> import mindspore.dataset as ds
>>>
@ -357,66 +384,14 @@ class SequentialSampler(BuiltinSampler):
>>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
"""
def create(self):
c_sampler = cde.SequentialSampler()
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
def is_shuffled(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_shuffled()
def is_sharded(self):
if self.child_sampler is None:
return False
return self.child_sampler.is_sharded()
class SubsetSampler(BuiltinSampler):
"""
Samples a subset of elements consecutively from a given index.
Args:
start_index (int): Index to start sampling at.
subset_size (int): How many samples to include in this subset.
Examples:
>>> import mindspore.dataset as ds
>>>
>>> dataset_dir = "path/to/imagefolder_directory"
>>>
>>> # creates a SubsetSampler, will sample the next 5 images from the 100th image.
>>> sampler = ds.SubsetSampler(100, 5)
>>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
Raises:
ValueError: If start_index is not a positive int.
ValueError: If subset_size is not a positive int.
"""
def __init__(self, start_index, subset_size):
if not isinstance(start_index, int):
raise ValueError("start_index should be an int.")
if start_index < 0:
raise ValueError("start_index should not be negative.")
if not isinstance(subset_size, int):
raise ValueError("start_index should be an int")
if subset_size < 0:
raise ValueError("subset_size should not be negative.")
def __init__(self, start_index=None, num_samples=None):
self.start_index = start_index
self.subset_size = subset_size
super().__init__()
super().__init__(num_samples)
def create(self):
c_sampler = cde.SubsetSampler(self.start_index, self.subset_size)
start_index = self.start_index if self.start_index is not None else 0
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.SequentialSampler(num_samples, start_index)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -433,9 +408,6 @@ class SubsetSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.subset_size
class SubsetRandomSampler(BuiltinSampler):
"""
@ -443,6 +415,7 @@ class SubsetRandomSampler(BuiltinSampler):
Args:
indices (list[int]): A sequence of indices.
num_samples (int, optional): Number of elements to sample (default=None, all elements).
Examples:
>>> import mindspore.dataset as ds
@ -456,15 +429,16 @@ class SubsetRandomSampler(BuiltinSampler):
>>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
"""
def __init__(self, indices):
def __init__(self, indices, num_samples=None):
if not isinstance(indices, list):
indices = [indices]
self.indices = indices
super().__init__()
super().__init__(num_samples)
def create(self):
c_sampler = cde.SubsetRandomSampler(self.indices)
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.SubsetRandomSampler(num_samples, self.indices)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -481,9 +455,9 @@ class SubsetRandomSampler(BuiltinSampler):
def _create_for_minddataset(self):
return cde.MindrecordSubsetRandomSampler(self.indices)
def get_dataset_size(self):
return len(self.indices)
def get_num_samples(self):
num_samples = super().get_num_samples()
return min(len(self.indices), num_samples)
class WeightedRandomSampler(BuiltinSampler):
@ -492,7 +466,7 @@ class WeightedRandomSampler(BuiltinSampler):
Args:
weights (list[float]): A sequence of weights, not necessarily summing up to 1.
num_samples (int): Number of elements to sample.
num_samples (int): Number of elements to sample (default=None, all elements).
replacement (bool, optional): If True, put the sample ID back for the next draw (default=True).
Examples:
@ -511,24 +485,25 @@ class WeightedRandomSampler(BuiltinSampler):
ValueError: If replacement is not boolean.
"""
def __init__(self, weights, num_samples, replacement=True):
def __init__(self, weights, num_samples=None, replacement=True):
if not isinstance(weights, list):
weights = [weights]
if num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
if num_samples is not None:
if num_samples < 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
self.weights = weights
self.num_samples = num_samples
self.replacement = replacement
super().__init__()
super().__init__(num_samples)
def create(self):
c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement)
c_child_sampler = self.create_child()
c_sampler.add_child(c_child_sampler)
return c_sampler
@ -541,6 +516,3 @@ class WeightedRandomSampler(BuiltinSampler):
return False
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.num_samples

View File

@ -161,6 +161,20 @@ def traverse(node):
else:
node_repr[k] = v
# If a sampler exists in this node, then the following 4 arguments must be set to None:
# num_samples, shard_id, num_shards, shuffle
# These arguments get moved into the sampler itself, so they are no longer needed to
# be set at the dataset level.
if 'sampler' in node_args.keys():
if 'num_samples' in node_repr.keys():
node_repr['num_samples'] = None
if 'shuffle' in node_repr.keys():
node_repr['shuffle'] = None
if 'num_shards' in node_repr.keys():
node_repr['num_shards'] = None
if 'shard_id' in node_repr.keys():
node_repr['shard_id'] = None
# Leaf node doesn't have input attribute.
if not node.input:
return node_repr

View File

@ -283,8 +283,8 @@ def check_num_parallel_workers(value):
def check_num_samples(value):
check_type(value, 'num_samples', int)
if value <= 0:
raise ValueError("num_samples must be greater than 0!")
if value < 0:
raise ValueError("num_samples cannot be less than 0!")
def check_dataset_dir(dataset_dir):

View File

@ -39,14 +39,13 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
std::shared_ptr<CelebAOp> Celeba(int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size,
const std::string &dir, int64_t num_samples = 0,
std::unique_ptr<Sampler> sampler = nullptr, bool decode = false,
const std::string &dataset_type="all") {
const std::string &dir, std::shared_ptr<Sampler> sampler = nullptr,
bool decode = false, const std::string &dataset_type="all") {
std::shared_ptr<CelebAOp> so;
CelebAOp::Builder builder;
Status rc = builder.SetNumWorkers(num_workers).SetCelebADir(dir).SetRowsPerBuffer(rows_per_buffer)
.SetOpConnectorSize(queue_size).SetSampler(std::move(sampler)).SetDecode(decode)
.SetNumSamples(num_samples).SetDatasetType(dataset_type).Build(&so);
.SetDatasetType(dataset_type).Build(&so);
return so;
}
@ -116,11 +115,12 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) {
std::vector<int64_t> indices({1});
std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices);
int64_t num_samples = 0;
std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices);
uint32_t expect_labels[1][40] = {{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}};
std::string dir = datasets_root_path_ + "/testCelebAData/";
uint32_t count = 0;
auto tree = Build({Celeba(16, 2, 32, dir, 0, std::move(sampler))});
auto tree = Build({Celeba(16, 2, 32, dir, std::move(sampler))});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
@ -143,25 +143,3 @@ TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) {
EXPECT_TRUE(count == 1);
}
}
TEST_F(MindDataTestCelebaDataset, TestCelebaNumSamples) {
std::string dir = datasets_root_path_ + "/testCelebAData/";
uint32_t count = 0;
auto tree = Build({Celeba(16, 2, 32, dir, 1)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
EXPECT_TRUE(false);
} else {
DatasetIterator di(tree);
TensorMap tersor_map;
di.GetNextAsMap(&tersor_map);
EXPECT_TRUE(rc.IsOk());
while (tersor_map.size() != 0) {
count++;
di.GetNextAsMap(&tersor_map);
}
EXPECT_TRUE(count == 1);
}
}

View File

@ -45,13 +45,12 @@ std::shared_ptr<RepeatOp> Repeat(int repeatCnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
std::shared_ptr<CifarOp> Cifarop(uint64_t num_works, uint64_t rows, uint64_t conns, std::string path,
std::unique_ptr<Sampler> sampler = nullptr,
uint64_t num_samples = 0, bool cifar10 = true) {
std::shared_ptr<Sampler> sampler = nullptr, bool cifar10 = true) {
std::shared_ptr<CifarOp> so;
CifarOp::Builder builder;
Status rc = builder.SetNumWorkers(num_works).SetCifarDir(path).SetRowsPerBuffer(rows)
.SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetCifarType(cifar10)
.SetNumSamples(num_samples).Build(&so);
.Build(&so);
return so;
}
@ -66,7 +65,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
//appear in this dataset
//Example: python tests/dataset/data/prep_data.py
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100)});
auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
@ -79,7 +78,8 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
EXPECT_TRUE(rc.IsOk());
uint64_t i = 0;
uint32_t label = 0;
while (tensor_map.size() != 0) {
// Note: only iterating first 100 rows then break out.
while (tensor_map.size() != 0 && i < 100) {
tensor_map["label"]->GetItemAt<uint32_t>(&label, {});
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
i++;
@ -92,9 +92,9 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
uint32_t original_seed = GlobalContext::config_manager()->seed();
GlobalContext::config_manager()->set_seed(0);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
std::shared_ptr<Sampler> sampler = std::make_unique<RandomSampler>(12, true, true);
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)});
auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler))});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
@ -118,34 +118,9 @@ TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
GlobalContext::config_manager()->set_seed(original_seed);
}
TEST_F(MindDataTestCifarOp, TestCifar10NumSample) {
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
MS_LOG(ERROR) << "Return code error detected during tree launch: " << common::SafeCStr(rc.ToString()) << ".";
EXPECT_TRUE(false);
} else {
DatasetIterator di(tree);
TensorMap tensor_map;
di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
uint64_t i = 0;
uint32_t label = 0;
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<uint32_t>(&label, {});
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
i++;
di.GetNextAsMap(&tensor_map);
}
EXPECT_TRUE(i == 100);
}
}
TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar100) {
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100, false)});
auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, false)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
@ -159,7 +134,8 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar100) {
uint64_t i = 0;
uint32_t coarse = 0;
uint32_t fine = 0;
while (tensor_map.size() != 0) {
// only iterate to 100 then break out of loop
while (tensor_map.size() != 0 && i < 100) {
tensor_map["coarse_label"]->GetItemAt<uint32_t>(&coarse, {});
tensor_map["fine_label"]->GetItemAt<uint32_t>(&fine, {});
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << " coarse:"

View File

@ -50,9 +50,8 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
std::map<std::string, int32_t> map = {}, int64_t num_samples = 0,
bool decode = false) {
bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
std::map<std::string, int32_t> map = {}, bool decode = false) {
std::shared_ptr<ImageFolderOp> so;
ImageFolderOp::Builder builder;
Status rc = builder.SetNumWorkers(num_works)
@ -63,7 +62,6 @@ std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int6
.SetSampler(std::move(sampler))
.SetClassIndex(map)
.SetDecode(decode)
.SetNumSamples(num_samples)
.Build(&so);
return so;
}
@ -138,7 +136,8 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) {
TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
int32_t original_seed = GlobalContext::config_manager()->seed();
GlobalContext::config_manager()->set_seed(0);
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
int64_t num_samples = 12;
std::shared_ptr<Sampler> sampler = std::make_unique<RandomSampler>(num_samples, true, true);
int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1}; // ground truth label
std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
@ -200,7 +199,8 @@ TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeatBatch)
TEST_F(MindDataTestImageFolderSampler, TestSubsetRandomSamplerImageFolder) {
// id range 0 - 10 is label 0, and id range 11 - 21 is label 1
std::vector<int64_t> indices({0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11});
std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices);
int64_t num_samples = 0;
std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices);
std::string folder_path = datasets_root_path_ + "/testPK/data";
// Expect 6 samples for label 0 and 1
int res[2] = {6, 6};
@ -237,8 +237,8 @@ TEST_F(MindDataTestImageFolderSampler, TestWeightedRandomSamplerImageFolder) {
std::vector<double> weights(total_samples, std::rand() % 100);
// create sampler with replacement = replacement
std::unique_ptr<Sampler> sampler =
std::make_unique<WeightedRandomSampler>(weights, num_samples, true, samples_per_buffer);
std::shared_ptr<Sampler> sampler =
std::make_shared<WeightedRandomSampler>(num_samples, weights, true, samples_per_buffer);
std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
@ -295,7 +295,8 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderClassIndex) {
}
TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) {
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(11, 10, false);
int64_t num_samples = 0;
std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 11, 10, false);
std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)), Repeat(4)});
tree->Prepare();
@ -322,7 +323,8 @@ TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) {
}
TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) {
std::unique_ptr<Sampler> sampler = std::make_unique<PKSampler>(3, false, 4);
int64_t num_samples = 0;
std::shared_ptr<Sampler> sampler = std::make_shared<PKSampler>(num_samples, 3, false, 4);
int32_t res[] = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3}; // ground truth label
std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
@ -349,39 +351,16 @@ TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) {
}
}
TEST_F(MindDataTestImageFolderSampler, TestImageFolderNumSamples) {
std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, nullptr, {}, 11), Repeat(2)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
MS_LOG(ERROR) << "Return code error detected during tree launch: " << common::SafeCStr(rc.ToString()) << ".";
EXPECT_TRUE(false);
} else {
DatasetIterator di(tree);
TensorMap tensor_map;
di.GetNextAsMap(&tensor_map);
EXPECT_TRUE(rc.IsOk());
uint64_t i = 0;
int32_t label = 0;
while (tensor_map.size() != 0) {
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
EXPECT_TRUE(0 == label);
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
i++;
di.GetNextAsMap(&tensor_map);
}
EXPECT_TRUE(i == 22);
}
}
TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) {
std::string folder_path = datasets_root_path_ + "/testPK/data";
std::map<std::string, int32_t> map;
map["class3"] = 333;
map["class1"] = 111;
map["wrong folder name"] = 1234; // this is skipped
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, nullptr, map, 20, true)});
int64_t num_samples = 20;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(seq_sampler), map, true)});
int64_t res[2] = {111, 333};
tree->Prepare();
Status rc = tree->Launch();
@ -408,33 +387,12 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) {
}
}
TEST_F(MindDataTestImageFolderSampler, TestImageFolderDatasetSize) {
std::string folder_path = datasets_root_path_ + "/testPK/data";
int64_t num_rows = 0;
int64_t num_classes = 0;
ImageFolderOp::CountRowsAndClasses(folder_path, 15, {}, &num_rows, &num_classes);
EXPECT_TRUE(num_rows == 15 && num_classes == 4);
ImageFolderOp::CountRowsAndClasses(folder_path, 44, {}, &num_rows, &num_classes);
EXPECT_TRUE(num_rows == 44 && num_classes == 4);
ImageFolderOp::CountRowsAndClasses(folder_path, 0, {}, &num_rows, &num_classes);
EXPECT_TRUE(num_rows == 44 && num_classes == 4);
ImageFolderOp::CountRowsAndClasses(folder_path, 55, {}, &num_rows, &num_classes);
EXPECT_TRUE(num_rows == 44 && num_classes == 4);
ImageFolderOp::CountRowsAndClasses(folder_path, 44, {}, &num_rows, &num_classes, 2, 3);
EXPECT_TRUE(num_rows == 15 && num_classes == 4);
ImageFolderOp::CountRowsAndClasses(folder_path, 33, {}, &num_rows, &num_classes, 0, 3);
EXPECT_TRUE(num_rows == 15 && num_classes == 4);
ImageFolderOp::CountRowsAndClasses(folder_path, 13, {}, &num_rows, &num_classes, 0, 11);
EXPECT_TRUE(num_rows == 4 && num_classes == 4);
ImageFolderOp::CountRowsAndClasses(folder_path, 3, {}, &num_rows, &num_classes, 0, 11);
EXPECT_TRUE(num_rows == 3 && num_classes == 4);
}
TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) {
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(4, 0, false);
int64_t num_samples = 5;
std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 4, 0, false);
std::string folder_path = datasets_root_path_ + "/testPK/data";
// numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {}, 5)});
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {})});
tree->Prepare();
Status rc = tree->Launch();
int32_t labels[5] = {0, 0, 0, 1, 1};
@ -460,10 +418,11 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) {
}
TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding2) {
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(4, 3, false);
int64_t num_samples = 12;
std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 4, 3, false);
std::string folder_path = datasets_root_path_ + "/testPK/data";
// numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode
auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {}, 12)});
auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {})});
tree->Prepare();
Status rc = tree->Launch();
uint32_t labels[11] = {0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3};

View File

@ -23,6 +23,7 @@
#include "dataset/core/client.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/util/de_error.h"
#include "dataset/util/status.h"
@ -42,14 +43,13 @@ std::shared_ptr<RepeatOp> Repeat(int repeatCnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
std::shared_ptr<ManifestOp> Manifest(int32_t num_works, int32_t rows, int32_t conns, const std::string &file,
std::string usage = "train", std::unique_ptr<Sampler> sampler = nullptr,
std::map<std::string, int32_t> map = {}, uint64_t num_samples = 0,
bool decode = false) {
std::string usage = "train", std::shared_ptr<Sampler> sampler = nullptr,
std::map<std::string, int32_t> map = {}, bool decode = false) {
std::shared_ptr<ManifestOp> so;
ManifestOp::Builder builder;
Status rc = builder.SetNumWorkers(num_works).SetManifestFile(file).SetRowsPerBuffer(
rows).SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetClassIndex(map).SetDecode(decode)
.SetNumSamples(num_samples).SetUsage(usage).Build(&so);
.SetUsage(usage).Build(&so);
return so;
}
@ -86,7 +86,8 @@ TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) {
TEST_F(MindDataTestManifest, TestSubsetRandomSamplerManifest) {
std::vector<int64_t> indices({1});
std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices);
int64_t num_samples = 0;
std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices);
std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
// Expect 6 samples for label 0 and 1
auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(sampler))});
@ -145,7 +146,10 @@ TEST_F(MindDataTestManifest, MindDataTestManifestClassIndex) {
TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) {
std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
auto tree = Build({Manifest(16, 2, 32, file, "train", nullptr, {}, 1), Repeat(4)});
int64_t num_samples = 1;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}), Repeat(4)});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
@ -171,7 +175,10 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) {
TEST_F(MindDataTestManifest, MindDataTestManifestEval) {
std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
auto tree = Build({Manifest(16, 2, 32, file, "eval", nullptr, {}, 1)});
int64_t num_samples = 1;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
auto tree = Build({Manifest(16, 2, 32, file, "eval", std::move(seq_sampler), {})});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {

View File

@ -120,9 +120,8 @@ class MindDataTestMapOp : public UT::DatasetOpTesting {
};
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
std::map<std::string, int32_t> map = {}, int64_t num_samples = 0,
bool decode = false);
bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
std::map<std::string, int32_t> map = {}, bool decode = false);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);

View File

@ -53,13 +53,11 @@ Status Create1DTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements,
DataType::Type data_type = DataType::DE_UINT32);
std::shared_ptr<MnistOp> CreateMnist(int64_t num_wrks, int64_t rows, int64_t conns, std::string path,
bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
int64_t num_samples = 0) {
bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr) {
std::shared_ptr<MnistOp> so;
MnistOp::Builder builder;
Status rc = builder.SetNumWorkers(num_wrks).SetDir(path).SetRowsPerBuffer(rows)
.SetOpConnectorSize(conns).SetSampler(std::move(sampler))
.SetNumSamples(num_samples).Build(&so);
.SetOpConnectorSize(conns).SetSampler(std::move(sampler)).Build(&so);
return so;
}
@ -74,7 +72,10 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
// appear in this dataset
// Example: python tests/dataset/data/prep_data.py
std::string folder_path = datasets_root_path_ + "/testMnistData/";
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, nullptr, 10), Repeat(2)});
int64_t num_samples = 10;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2)});
tree->Prepare();
uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
Status rc = tree->Launch();
@ -101,7 +102,10 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) {
std::string folder_path = datasets_root_path_ + "/testMnistData/";
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, nullptr, 10), Repeat(2), Batch(5)});
int64_t num_samples = 10;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2), Batch(5)});
tree->Prepare();
uint32_t res[4][5] = { {0, 0, 0, 0, 0 },
{0, 0, 0, 0, 0 },

View File

@ -43,20 +43,11 @@ class MindDataTestStandAloneSampler : public UT::DatasetOpTesting {
protected:
class MockStorageOp : public RandomAccessOp {
public:
MockStorageOp(int64_t val) : m_val_(val) {}
Status GetNumSamples(int64_t *ptr) const override {
(*ptr) = m_val_;
return Status::OK();
MockStorageOp(int64_t val){
// row count is in base class as protected member
// GetNumRowsInDataset does not need an override, the default from base class is fine.
num_rows_ = val;
}
Status GetNumRowsInDataset(int64_t *ptr) const override {
(*ptr) = m_val_;
return Status::OK();
}
private:
int64_t m_val_;
};
};
@ -73,8 +64,9 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
MockStorageOp mock(20);
std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> tensor;
int64_t num_samples = 0;
for (int i = 0; i < 6; i++) {
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true));
std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 3, i % 3, (i < 3 ? false : true));
sampler->HandshakeRandomAccessOp(&mock);
sampler->GetNextBuffer(&db);
db->GetTensor(&tensor, 0, 0);
@ -92,7 +84,9 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
std::shared_ptr<Tensor> label1, label2;
CreateINT64Tensor(&label1, 3, reinterpret_cast<unsigned char *>(res));
CreateINT64Tensor(&label2, 2, reinterpret_cast<unsigned char *>(res + 3));
std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3);
int64_t num_samples = 0;
int64_t start_index = 0;
std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(num_samples, start_index, 3);
std::unique_ptr<DataBuffer> db;
std::shared_ptr<Tensor> tensor;
sampler->HandshakeRandomAccessOp(&mock);

View File

@ -31,26 +31,17 @@ class MindDataTestSubsetRandomSampler : public UT::Common {
public:
class DummyRandomAccessOp : public RandomAccessOp {
public:
DummyRandomAccessOp(int64_t num_rows) : num_rows_(num_rows) {};
Status GetNumSamples(int64_t *num) const {
*num = num_rows_;
return Status::OK();
}
Status GetNumRowsInDataset(int64_t *num) const {
*num = num_rows_;
return Status::OK();
}
private:
int64_t num_rows_;
DummyRandomAccessOp(int64_t num_rows) {
num_rows_ = num_rows; // base class
};
};
};
TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
std::vector<int64_t> in({0, 1, 2, 3, 4});
std::unordered_set<int64_t> in_set(in.begin(), in.end());
SubsetRandomSampler sampler(in);
int64_t num_samples = 0;
SubsetRandomSampler sampler(num_samples, in);
DummyRandomAccessOp dummyRandomAccessOp(5);
sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@ -77,8 +68,9 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
int64_t total_samples = 100000 - 5;
int64_t samples_per_buffer = 10;
int64_t num_samples = 0;
std::vector<int64_t> input(total_samples, 1);
SubsetRandomSampler sampler(input, samples_per_buffer);
SubsetRandomSampler sampler(num_samples, input, samples_per_buffer);
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@ -109,7 +101,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
std::vector<int64_t> in({0, 1, 2, 3, 4});
std::unordered_set<int64_t> in_set(in.begin(), in.end());
SubsetRandomSampler sampler(in);
int64_t num_samples = 0;
SubsetRandomSampler sampler(num_samples, in);
DummyRandomAccessOp dummyRandomAccessOp(5);
sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

View File

@ -35,19 +35,11 @@ class MindDataTestWeightedRandomSampler : public UT::Common {
public:
class DummyRandomAccessOp : public RandomAccessOp {
public:
DummyRandomAccessOp(uint64_t num_rows) : num_rows_(num_rows) {};
Status GetNumSamples(int64_t *num) const {
*num = num_rows_;
return Status::OK();
DummyRandomAccessOp(uint64_t num_rows) {
// row count is in base class as protected member
// GetNumRowsInDataset does not need an override, the default from base class is fine.
num_rows_ = num_rows;
}
Status GetNumRowsInDataset(int64_t *num) const {
*num = num_rows_;
return Status::OK();
}
private:
uint64_t num_rows_;
};
};
@ -59,7 +51,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
std::vector<uint64_t> freq(total_samples, 0);
// create sampler with replacement = true
WeightedRandomSampler m_sampler(weights, num_samples, true);
WeightedRandomSampler m_sampler(num_samples, weights, true);
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@ -89,7 +81,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
std::vector<uint64_t> freq(total_samples, 0);
// create sampler with replacement = replacement
WeightedRandomSampler m_sampler(weights, num_samples, false);
WeightedRandomSampler m_sampler(num_samples, weights, false);
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@ -125,7 +117,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
std::vector<double> weights(total_samples, std::rand() % 100);
// create sampler with replacement = replacement
WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer);
WeightedRandomSampler m_sampler(num_samples, weights, true, samples_per_buffer);
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@ -161,7 +153,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
std::vector<uint64_t> freq(total_samples, 0);
// create sampler with replacement = replacement
WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer);
WeightedRandomSampler m_sampler(num_samples, weights, false, samples_per_buffer);
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@ -202,7 +194,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
std::vector<uint64_t> freq(total_samples, 0);
// create sampler with replacement = true
WeightedRandomSampler m_sampler(weights, num_samples, true);
WeightedRandomSampler m_sampler(num_samples, weights, true);
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
@ -247,7 +239,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
std::vector<uint64_t> freq(total_samples, 0);
// create sampler with replacement = true
WeightedRandomSampler m_sampler(weights, num_samples, false);
WeightedRandomSampler m_sampler(num_samples, weights, false);
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

View File

@ -58,7 +58,7 @@ def test_imagefolder_numsamples():
assert num_iter == 10
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
num_iter = 0
for item in data1.create_dict_iterator():
@ -67,7 +67,7 @@ def test_imagefolder_numsamples():
assert num_iter == 3
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
num_iter = 0
for item in data1.create_dict_iterator():

View File

@ -162,8 +162,8 @@ def test_voc_shardings(print_res=False):
voc_dir = "../data/dataset/testVOC2012"
def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_samples=num_samples)
sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler)
data1 = data1.repeat(repeat_cnt)
res = []
for item in data1.create_dict_iterator(): # each data is a dictionary

View File

@ -35,18 +35,13 @@ def test_exception_01():
def test_exception_02():
"""
Test multiple exceptions with invalid input
Test exceptions with invalid input, and test valid input
"""
logger.info("test_exception_02")
num_samples = 0
with pytest.raises(ValueError) as info:
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
assert "num_samples must be greater than 0" in str(info.value)
num_samples = -1
with pytest.raises(ValueError) as info:
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
assert "num_samples must be greater than 0" in str(info.value)
assert "num_samples cannot be less than 0" in str(info.value)
num_samples = 1
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)

View File

@ -544,7 +544,7 @@ def test_distributed_sampler():
def test_num_samples():
source = [(np.array([x]),) for x in range(64)]
num_samples = 32
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples=num_samples)
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples))
ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples)
ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
@ -660,4 +660,6 @@ if __name__ == "__main__":
test_sequential_sampler()
test_distributed_sampler()
test_random_sampler()
test_num_samples()
test_num_samples_underflow()
test_schema()

View File

@ -28,8 +28,8 @@ def test_sequential_sampler(print_res=False):
map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_samples, num_repeats=None):
sampler = ds.SequentialSampler()
data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
sampler = ds.SequentialSampler(num_samples=num_samples)
data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
if num_repeats is not None:
data1 = data1.repeat(num_repeats)
res = []
@ -43,6 +43,7 @@ def test_sequential_sampler(print_res=False):
assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2]
assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2
assert test_config(num_samples=0, num_repeats=2) == [0, 1, 2, 3, 4] * 2
assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2
@ -119,8 +120,8 @@ def test_python_sampler():
return iter([i for i in range(self.dataset_size)])
class Sp2(ds.Sampler):
def __init__(self):
super(Sp2, self).__init__()
def __init__(self, num_samples=None):
super(Sp2, self).__init__(num_samples)
# at this stage, self.dataset_size and self.num_samples are not yet known
self.cnt = 0
@ -130,8 +131,8 @@ def test_python_sampler():
def reset(self):
self.cnt = (self.cnt + 1) % self.dataset_size
def test_config(num_samples, num_repeats, sampler):
data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
def test_config(num_repeats, sampler):
data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
if num_repeats is not None:
data1 = data1.repeat(num_repeats)
res = []
@ -154,8 +155,8 @@ def test_python_sampler():
assert data[0] == (np.array(i),)
i = i - 1
assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
test_generator()
sp1 = Sp1().create()
@ -169,9 +170,8 @@ def test_subset_sampler():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_samples, start_index, subset_size):
_ = num_samples
sampler = ds.SubsetSampler(start_index, subset_size)
def test_config(start_index, num_samples):
sampler = ds.SequentialSampler(start_index, num_samples)
d = ds.ManifestDataset(manifest_file, sampler=sampler)
res = []
@ -180,19 +180,15 @@ def test_subset_sampler():
return res
with pytest.raises(RuntimeError) as info:
test_config(5, 0, 0)
assert "subset_size <= 0" in str(info.value)
assert test_config(5, 0, 1) == [0]
assert test_config(5, 0, 2) == [0, 1]
assert test_config(5, 0, 3) == [0, 1, 2]
assert test_config(5, 0, 4) == [0, 1, 2, 3]
assert test_config(5, 0, 5) == [0, 1, 2, 3, 4]
assert test_config(5, 1, 1) == [1]
assert test_config(5, 2, 3) == [2, 3, 4]
assert test_config(5, 3, 2) == [3, 4]
assert test_config(5, 4, 1) == [4]
assert test_config(0, 1) == [0]
assert test_config(0, 2) == [0, 1]
assert test_config(0, 3) == [0, 1, 2]
assert test_config(0, 4) == [0, 1, 2, 3]
assert test_config(0, 5) == [0, 1, 2, 3, 4]
assert test_config(1, 1) == [1]
assert test_config(2, 3) == [2, 3, 4]
assert test_config(3, 2) == [3, 4]
assert test_config(4, 1) == [4]
def test_sampler_chain():
@ -200,11 +196,11 @@ def test_sampler_chain():
map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_shards, shard_id):
sampler = ds.DistributedSampler(num_shards, shard_id, False)
sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=False, num_samples=5)
child_sampler = ds.SequentialSampler()
sampler.add_child(child_sampler)
data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler)
data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
res = []
for item in data1.create_dict_iterator():
@ -234,6 +230,11 @@ def test_add_sampler_invalid_input():
data1.use_sampler("sampler")
assert "not an instance of a sampler" in str(info.value)
sampler = ds.SequentialSampler()
with pytest.raises(ValueError) as info:
data2 = ds.ManifestDataset(manifest_file, sampler=sampler, num_samples=20)
assert "Conflicting arguments during sampler assignments" in str(info.value)
if __name__ == '__main__':
test_sequential_sampler(True)