consistent design for num_samples
updates more update more work more fixin post rebase updates clang formatting code review recovery ci fixes updates update update update
This commit is contained in:
parent
bc7a3a1bef
commit
51bc0c0460
|
@ -856,9 +856,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -893,9 +891,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -930,9 +926,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -966,9 +960,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -1001,9 +993,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -1039,10 +1029,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
|
|||
(void)builder.SetNumWorkers(ToInt(value));
|
||||
} else if (key == "schema_file_path" || key == "schema_json_string") {
|
||||
schema_exists = true;
|
||||
} else if (key == "num_samples") {
|
||||
(void)builder.SetTotalRows(ToInt(value));
|
||||
} else if (key == "columns_list") {
|
||||
columns_to_load = ToStringVector(value);
|
||||
} else if (key == "num_samples") {
|
||||
// This is not sampling here. The random data op needs to know how much data to
|
||||
// generate. It does not currently support sampling.
|
||||
(void)builder.SetTotalRows(ToInt(value));
|
||||
}
|
||||
}
|
||||
if (schema_exists) {
|
||||
|
@ -1077,9 +1069,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -1121,8 +1111,6 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
|
|||
(void)builder->SetDecode(ToBool(value));
|
||||
} else if (key == "extensions") {
|
||||
(void)builder->SetExtensions(ToStringSet(value));
|
||||
} else if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "dataset_type") {
|
||||
(void)builder->SetDatasetType(ToString(value));
|
||||
}
|
||||
|
@ -1153,7 +1141,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
} else if (key == "shuffle_files") {
|
||||
(void)builder->SetShuffleFiles(ToBool(value));
|
||||
} else if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
(void)builder->SetTotalRows(ToInt(value));
|
||||
} else if (key == "num_shards") {
|
||||
(void)builder->SetNumDevices(ToInt(value));
|
||||
} else if (key == "shard_id") {
|
||||
|
|
|
@ -49,7 +49,6 @@
|
|||
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
|
@ -143,17 +142,16 @@ void bindDatasetOps(py::module *m) {
|
|||
});
|
||||
|
||||
(void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
|
||||
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
|
||||
.def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
|
||||
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
|
||||
return count;
|
||||
});
|
||||
|
||||
(void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
|
||||
.def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
|
||||
.def_static("get_num_rows_and_classes", [](const std::string &path) {
|
||||
int64_t count = 0, num_classes = 0;
|
||||
THROW_IF_ERROR(
|
||||
ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
|
||||
THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
|
||||
return py::make_tuple(count, num_classes);
|
||||
});
|
||||
|
||||
|
@ -172,22 +170,21 @@ void bindDatasetOps(py::module *m) {
|
|||
|
||||
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
|
||||
.def_static("get_num_rows_and_classes",
|
||||
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
|
||||
[](const std::string &file, const py::dict &dict, const std::string &usage) {
|
||||
int64_t count = 0, num_classes = 0;
|
||||
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
|
||||
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
|
||||
return py::make_tuple(count, num_classes);
|
||||
})
|
||||
.def_static("get_class_indexing",
|
||||
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
|
||||
std::map<std::string, int32_t> output_class_indexing;
|
||||
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
|
||||
return output_class_indexing;
|
||||
});
|
||||
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) {
|
||||
std::map<std::string, int32_t> output_class_indexing;
|
||||
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
|
||||
return output_class_indexing;
|
||||
});
|
||||
|
||||
(void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
|
||||
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
|
||||
.def_static("get_num_rows", [](const std::string &dir) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
|
||||
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
|
||||
return count;
|
||||
});
|
||||
|
||||
|
@ -206,13 +203,13 @@ void bindDatasetOps(py::module *m) {
|
|||
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
|
||||
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
|
||||
return count;
|
||||
})
|
||||
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
|
||||
const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
|
||||
const std::string &task_mode, const py::dict &dict) {
|
||||
std::map<std::string, int32_t> output_class_indexing;
|
||||
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing));
|
||||
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
|
||||
return output_class_indexing;
|
||||
});
|
||||
}
|
||||
|
@ -452,25 +449,19 @@ void bindSamplerOps(py::module *m) {
|
|||
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
|
||||
|
||||
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
|
||||
py::arg("seed"));
|
||||
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
|
||||
|
||||
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
|
||||
.def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
|
||||
.def(py::init<int64_t, int64_t, bool>());
|
||||
|
||||
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
|
||||
.def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
|
||||
py::arg("num_samples"))
|
||||
.def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
|
||||
.def(py::init<int64_t, bool, bool>());
|
||||
|
||||
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
|
||||
.def(py::init<>());
|
||||
|
||||
(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
|
||||
.def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
|
||||
.def(py::init<int64_t, int64_t>());
|
||||
|
||||
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
|
||||
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
|
||||
.def(py::init<int64_t, std::vector<int64_t>>());
|
||||
|
||||
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
|
||||
*m, "MindrecordSubsetRandomSampler")
|
||||
|
@ -487,11 +478,10 @@ void bindSamplerOps(py::module *m) {
|
|||
}));
|
||||
|
||||
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
|
||||
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
|
||||
py::arg("replacement"));
|
||||
.def(py::init<int64_t, std::vector<double>, bool>());
|
||||
|
||||
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
|
||||
.def(py::init<py::object>(), py::arg("pySampler"));
|
||||
.def(py::init<int64_t, py::object>());
|
||||
}
|
||||
|
||||
void bindInfoObjects(py::module *m) {
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr), builder_num_samples_(0) {
|
||||
CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
|
@ -38,7 +38,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
|
|||
MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << ".";
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
if (builder_sampler_ == nullptr) {
|
||||
builder_sampler_ = std::make_shared<SequentialSampler>();
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
|
||||
}
|
||||
|
||||
builder_schema_ = std::make_unique<DataSchema>();
|
||||
|
@ -47,10 +49,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
|
|||
// label is like this:0 1 0 0 1......
|
||||
RETURN_IF_NOT_OK(
|
||||
builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
*op =
|
||||
std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, builder_op_connector_size_,
|
||||
builder_decode_, builder_dataset_type_, builder_extensions_, std::move(builder_schema_),
|
||||
std::move(builder_sampler_), builder_num_samples_);
|
||||
*op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
|
||||
builder_op_connector_size_, builder_decode_, builder_dataset_type_,
|
||||
builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
|
||||
if (*op == nullptr) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null");
|
||||
}
|
||||
|
@ -68,7 +69,7 @@ Status CelebAOp::Builder::SanityCheck() {
|
|||
|
||||
CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
|
||||
bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
|
||||
std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler, int64_t num_samples)
|
||||
std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
|
||||
: ParallelOp(num_workers, queue_size),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
folder_path_(dir),
|
||||
|
@ -77,8 +78,6 @@ CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::stri
|
|||
data_schema_(std::move(schema)),
|
||||
sampler_(std::move(sampler)),
|
||||
num_rows_in_attr_file_(0),
|
||||
num_rows_exact_(0),
|
||||
num_samples_(num_samples),
|
||||
dataset_type_(dataset_type) {
|
||||
// Set the column name map (base class field)
|
||||
for (int32_t index = 0; index < data_schema_->NumColumns(); index++) {
|
||||
|
@ -202,13 +201,6 @@ Status CelebAOp::ParseImageAttrInfo() {
|
|||
RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
|
||||
while (!image_infos.empty() && needMoreData) {
|
||||
for (uint32_t index = 0; index < image_infos.size(); index++) {
|
||||
if (num_samples_ != 0 && image_labels_vec_.size() >= num_samples_) {
|
||||
MS_LOG(WARNING) << "Image number(" << image_labels_vec_.size() << " is more than"
|
||||
<< " rows num eval attr file(" << num_rows_in_attr_file_ << ") or num samples(" << num_samples_
|
||||
<< ").";
|
||||
needMoreData = false;
|
||||
break;
|
||||
}
|
||||
std::string image_info = image_infos[index];
|
||||
std::vector<std::string> split = Split(image_info);
|
||||
std::pair<std::string, std::vector<int32_t>> image_labels;
|
||||
|
@ -239,14 +231,13 @@ Status CelebAOp::ParseImageAttrInfo() {
|
|||
RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
|
||||
}
|
||||
|
||||
num_rows_exact_ = image_labels_vec_.size();
|
||||
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_exact_) ? num_rows_exact_ : num_samples_;
|
||||
if (num_rows_exact_ == 0) {
|
||||
num_rows_ = image_labels_vec_.size();
|
||||
if (num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_exact_ << ".";
|
||||
MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << ".";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -268,28 +259,6 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) {
|
|||
return split;
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status CelebAOp::GetNumSamples(int64_t *num) const {
|
||||
if (num == nullptr || num_samples_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
(*num) = num_samples_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CelebAOp::GetNumRowsInDataset(int64_t *num) const {
|
||||
if (num == nullptr || num_rows_exact_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
|
||||
*num = num_rows_exact_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
|
||||
Status CelebAOp::operator()() {
|
||||
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
|
||||
|
@ -310,9 +279,8 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
|
|||
RETURN_IF_NOT_OK((*data_buffer)->PopRow(&sample_row));
|
||||
std::shared_ptr<Tensor> sample_ids = sample_row[0];
|
||||
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
|
||||
if ((*itr) >= num_rows_exact_) {
|
||||
MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_exact_
|
||||
<< ".";
|
||||
if ((*itr) >= num_rows_) {
|
||||
MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_ << ".";
|
||||
continue;
|
||||
}
|
||||
keys.push_back(*itr);
|
||||
|
@ -446,7 +414,7 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nNumber of rows:" << num_rows_exact_ << "\nceleba dir: " << folder_path_ << "\n\n";
|
||||
out << "\nNumber of rows:" << num_rows_ << "\nceleba dir: " << folder_path_ << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -108,14 +108,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param int64_t num_samples
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumSamples(int64_t num_samples) {
|
||||
builder_num_samples_ = num_samples;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param const std::string dataset_type: type to be read
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
|
@ -141,7 +133,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
std::set<std::string> builder_extensions_;
|
||||
std::shared_ptr<Sampler> builder_sampler_;
|
||||
std::unique_ptr<DataSchema> builder_schema_;
|
||||
int64_t builder_num_samples_;
|
||||
std::string builder_dataset_type_;
|
||||
};
|
||||
|
||||
|
@ -153,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
// @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
|
||||
CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
|
||||
const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
|
||||
std::shared_ptr<Sampler> sampler, int64_t num_samples);
|
||||
std::shared_ptr<Sampler> sampler);
|
||||
|
||||
~CelebAOp() override = default;
|
||||
|
||||
|
@ -163,16 +154,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get numRows
|
||||
// @param int64_t num - to return numRows
|
||||
// @return Status - The error code return
|
||||
Status GetNumSamples(int64_t *num) const override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get numRows
|
||||
// @param int64_t num - to return numRows
|
||||
// @return Status - The error code return
|
||||
Status GetNumRowsInDataset(int64_t *num) const override;
|
||||
|
||||
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
// @param int32_t worker_id - id of each worker
|
||||
// @return Status - The error code return
|
||||
|
@ -233,11 +214,9 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
std::shared_ptr<Sampler> sampler_;
|
||||
std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_;
|
||||
int64_t num_rows_in_attr_file_; // rows number specified in attr file
|
||||
int64_t num_rows_exact_; // exact rows number,maybe is less than rows_num_in_attr_file_
|
||||
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
|
||||
WaitPost wp_;
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
|
||||
int64_t num_samples_;
|
||||
std::string dataset_type_;
|
||||
std::ifstream partition_file_;
|
||||
};
|
||||
|
|
|
@ -35,7 +35,7 @@ constexpr uint32_t kCifarImageChannel = 3;
|
|||
constexpr uint32_t kCifarBlockImageNum = 5;
|
||||
constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel;
|
||||
|
||||
CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) {
|
||||
CifarOp::Builder::Builder() : sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
num_workers_ = cfg->num_parallel_workers();
|
||||
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
|
@ -46,7 +46,9 @@ CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) {
|
|||
Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
if (sampler_ == nullptr) {
|
||||
sampler_ = std::make_shared<SequentialSampler>();
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
|
||||
}
|
||||
schema_ = std::make_unique<DataSchema>();
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
|
@ -62,7 +64,7 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
|
|||
ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
|
||||
}
|
||||
|
||||
*ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_, num_samples_,
|
||||
*ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
|
||||
std::move(schema_), std::move(sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -76,16 +78,13 @@ Status CifarOp::Builder::SanityCheck() {
|
|||
}
|
||||
|
||||
CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
|
||||
int32_t queue_size, int64_t num_samples, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<Sampler> sampler)
|
||||
int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
|
||||
: ParallelOp(num_works, queue_size),
|
||||
cifar_type_(type),
|
||||
rows_per_buffer_(rows_per_buf),
|
||||
folder_path_(file_dir),
|
||||
num_samples_(num_samples),
|
||||
data_schema_(std::move(data_schema)),
|
||||
sampler_(std::move(sampler)),
|
||||
num_rows_(0),
|
||||
row_cnt_(0),
|
||||
buf_cnt_(0) {
|
||||
// set the column name map (base class field)
|
||||
|
@ -112,8 +111,7 @@ Status CifarOp::operator()() {
|
|||
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
|
||||
keys.push_back(*itr);
|
||||
row_cnt_++;
|
||||
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
|
||||
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
|
||||
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
|
||||
if (row_cnt_ % rows_per_buffer_ == 0) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
|
||||
std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
|
||||
|
@ -255,30 +253,6 @@ Status CifarOp::InitSampler() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status CifarOp::GetNumSamples(int64_t *num) const {
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
|
||||
std::string err_msg = "There is no valid data matching the dataset API " + api +
|
||||
".Please check file path or dataset API validation first.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
(*num) = num_samples_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status CifarOp::GetNumRowsInDataset(int64_t *num) const {
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
|
||||
std::string err_msg = "There is no valid data matching the dataset API " + api +
|
||||
".Please check file path or dataset API validation first.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
(*num) = num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CifarOp::ReadCifarBlockDataAsync() {
|
||||
TaskManager::FindMe()->Post();
|
||||
RETURN_IF_NOT_OK(GetCifarFiles());
|
||||
|
@ -404,7 +378,6 @@ Status CifarOp::ParseCifarData() {
|
|||
}
|
||||
cifar_image_label_pairs_.shrink_to_fit();
|
||||
num_rows_ = cifar_image_label_pairs_.size();
|
||||
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
|
||||
if (num_rows_ == 0) {
|
||||
std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
|
||||
std::string err_msg = "There is no valid data matching the dataset API " + api +
|
||||
|
@ -432,11 +405,11 @@ Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count) {
|
||||
Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) {
|
||||
// the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block()
|
||||
std::shared_ptr<CifarOp> op;
|
||||
*count = 0;
|
||||
RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetNumSamples(numSamples).SetCifarType(isCIFAR10).Build(&op));
|
||||
RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->GetCifarFiles());
|
||||
if (op->cifar_type_ == kCifar10) {
|
||||
constexpr int64_t num_cifar10_records = 10000;
|
||||
|
@ -448,7 +421,6 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool
|
|||
}
|
||||
*count = *count + num_cifar10_records;
|
||||
}
|
||||
*count = *count < numSamples || numSamples == 0 ? *count : numSamples;
|
||||
return Status::OK();
|
||||
} else {
|
||||
int64_t num_cifar100_records = 0;
|
||||
|
@ -470,7 +442,7 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool
|
|||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
*count = num_cifar100_records < numSamples || numSamples == 0 ? num_cifar100_records : numSamples;
|
||||
*count = num_cifar100_records;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -73,14 +73,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param uint64_t num_samples
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumSamples(uint64_t num_samples) {
|
||||
num_samples_ = num_samples;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
|
@ -121,7 +113,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
private:
|
||||
std::string dir_;
|
||||
int32_t num_workers_;
|
||||
uint64_t num_samples_;
|
||||
int32_t rows_per_buffer_;
|
||||
int32_t op_connect_size_;
|
||||
std::shared_ptr<Sampler> sampler_;
|
||||
|
@ -137,7 +128,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param uint32_t - queueSize - connector queue size
|
||||
// @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
|
||||
CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size,
|
||||
int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
|
||||
// Destructor.
|
||||
~CifarOp() = default;
|
||||
|
||||
|
@ -152,16 +143,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get numRows
|
||||
// @param uint64_t num - to return numRows
|
||||
// @return Status - The error code return
|
||||
Status GetNumSamples(int64_t *num) const override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
|
||||
// @param uint64_t num - to return numRows
|
||||
// @return Status - The error code return
|
||||
Status GetNumRowsInDataset(int64_t *num) const override;
|
||||
|
||||
// A print method typically used for debugging
|
||||
// @param out
|
||||
// @param show_all
|
||||
|
@ -169,11 +150,10 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
// Function to count the number of samples in the CIFAR dataset
|
||||
// @param dir path to the CIFAR directory
|
||||
// @param numSamples maximum number of samples requested
|
||||
// @param isCIFAR10 true if CIFAR10 and false if CIFAR100
|
||||
// @param count output arg that will hold the minimum of the actual dataset size and numSamples
|
||||
// @param count output arg that will hold the actual dataset size
|
||||
// @return
|
||||
static Status CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count);
|
||||
static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
|
@ -227,10 +207,8 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
CifarType cifar_type_;
|
||||
int32_t rows_per_buffer_;
|
||||
std::string folder_path_;
|
||||
int64_t num_samples_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
std::shared_ptr<Sampler> sampler_;
|
||||
int64_t num_rows_;
|
||||
int64_t row_cnt_;
|
||||
int64_t buf_cnt_;
|
||||
|
||||
|
|
|
@ -26,8 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
ImageFolderOp::Builder::Builder()
|
||||
: builder_decode_(false), builder_recursive_(false), builder_num_samples_(0), builder_sampler_(nullptr) {
|
||||
ImageFolderOp::Builder::Builder() : builder_decode_(false), builder_recursive_(false), builder_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
|
@ -37,7 +36,9 @@ ImageFolderOp::Builder::Builder()
|
|||
Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
if (builder_sampler_ == nullptr) {
|
||||
builder_sampler_ = std::make_shared<SequentialSampler>();
|
||||
int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data
|
||||
int64_t start_index = 0;
|
||||
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
|
||||
}
|
||||
builder_schema_ = std::make_unique<DataSchema>();
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
|
@ -46,9 +47,9 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
|
|||
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
|
||||
ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
*ptr = std::make_shared<ImageFolderOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
|
||||
builder_op_connector_size_, builder_num_samples_, builder_recursive_,
|
||||
builder_decode_, builder_extensions_, builder_labels_to_read_,
|
||||
std::move(builder_schema_), std::move(builder_sampler_));
|
||||
builder_op_connector_size_, builder_recursive_, builder_decode_,
|
||||
builder_extensions_, builder_labels_to_read_, std::move(builder_schema_),
|
||||
std::move(builder_sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -61,20 +62,18 @@ Status ImageFolderOp::Builder::SanityCheck() {
|
|||
}
|
||||
|
||||
ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
|
||||
int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts,
|
||||
bool recursive, bool do_decode, const std::set<std::string> &exts,
|
||||
const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<Sampler> sampler)
|
||||
: ParallelOp(num_wkrs, queue_size),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
folder_path_(file_dir),
|
||||
num_samples_(num_samples),
|
||||
recursive_(recursive),
|
||||
decode_(do_decode),
|
||||
extensions_(exts),
|
||||
class_index_(map),
|
||||
data_schema_(std::move(data_schema)),
|
||||
sampler_(std::move(sampler)),
|
||||
num_rows_(0),
|
||||
row_cnt_(0),
|
||||
buf_cnt_(0),
|
||||
sampler_ind_(0),
|
||||
|
@ -117,7 +116,6 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
|
|||
}
|
||||
image_label_pairs_.shrink_to_fit();
|
||||
num_rows_ = image_label_pairs_.size();
|
||||
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
|
||||
// free memory of two queues used for pre-scan
|
||||
folder_name_queue_->Reset();
|
||||
image_name_queue_->Reset();
|
||||
|
@ -138,8 +136,7 @@ Status ImageFolderOp::operator()() {
|
|||
std::shared_ptr<Tensor> sample_ids = sample_row[0];
|
||||
if (sample_ids->type() != DataType(DataType::DE_INT64)) RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64");
|
||||
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
|
||||
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
|
||||
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
|
||||
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
|
||||
keys.push_back(*itr);
|
||||
row_cnt_++;
|
||||
if (row_cnt_ % rows_per_buffer_ == 0) {
|
||||
|
@ -272,28 +269,6 @@ Status ImageFolderOp::InitSampler() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status ImageFolderOp::GetNumSamples(int64_t *num) const {
|
||||
if (num == nullptr || num_samples_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
(*num) = num_samples_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status ImageFolderOp::GetNumRowsInDataset(int64_t *num) const {
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
(*num) = num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status ImageFolderOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
||||
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
|
||||
|
@ -413,16 +388,14 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t &num_samples,
|
||||
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
|
||||
int64_t dev_id, int64_t num_dev) {
|
||||
Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
|
||||
int64_t *num_classes, int64_t dev_id, int64_t num_dev) {
|
||||
Path dir(path);
|
||||
std::string err_msg = "";
|
||||
int64_t row_cnt = 0;
|
||||
err_msg += (dir.Exists() == false || dir.IsDirectory() == false) ? "unable to open dir " + path : "";
|
||||
err_msg += (num_classes == nullptr || num_rows == nullptr) ? "num_class/num_rows is null\n" : "";
|
||||
err_msg += (dev_id >= num_dev || num_dev <= 0) ? "invalid sharding config\n" : "";
|
||||
err_msg += num_samples < 0 ? "num_samples can't be negative! set it to 0 to use all samples\n" : "";
|
||||
if (err_msg.empty() == false) {
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
@ -441,10 +414,6 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t
|
|||
while (dir_itr->hasNext()) {
|
||||
if (exts.empty() || exts.find(subdir.Extension()) != exts.end()) {
|
||||
++row_cnt;
|
||||
if (row_cnt == num_samples * num_dev) {
|
||||
(*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
foldernames.pop();
|
||||
|
|
|
@ -107,14 +107,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param int64_t num_samples
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumSamples(int64_t num_samples) {
|
||||
builder_num_samples_ = num_samples;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
|
@ -153,7 +145,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
bool builder_recursive_;
|
||||
std::string builder_dir_;
|
||||
int32_t builder_num_workers_;
|
||||
int64_t builder_num_samples_;
|
||||
int32_t builder_rows_per_buffer_;
|
||||
int32_t builder_op_connector_size_;
|
||||
std::set<std::string> builder_extensions_;
|
||||
|
@ -169,10 +160,9 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param int32_t queue_size - connector queue size
|
||||
// @param std::set<std::string> exts - set of file extensions to read, if empty, read everything under the dir
|
||||
// @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
|
||||
ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
|
||||
int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts,
|
||||
const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema>,
|
||||
std::shared_ptr<Sampler> sampler);
|
||||
ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive,
|
||||
bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
|
||||
std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler);
|
||||
|
||||
// Destructor.
|
||||
~ImageFolderOp() = default;
|
||||
|
@ -198,16 +188,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get numRows
|
||||
// @param int64_t num - to return numRows
|
||||
// @return Status - The error code return
|
||||
Status GetNumSamples(int64_t *num) const override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
|
||||
// @param int64_t num - to return numRows
|
||||
// @return Status - The error code return
|
||||
Status GetNumRowsInDataset(int64_t *num) const override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
|
||||
// @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
|
||||
// @return Status - The error code return
|
||||
|
@ -221,9 +201,8 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
// This function is a hack! It is to return the num_class and num_rows the old storageOp does. The result
|
||||
// returned by this function may not be consistent with what image_folder_op is going to return
|
||||
// user this at your own risk!
|
||||
static Status CountRowsAndClasses(const std::string &path, const int64_t &num_samples,
|
||||
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
|
||||
int64_t dev_id = 0, int64_t num_dev = 1);
|
||||
static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
|
||||
int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1);
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
// @param p - Pointer to the NodePass to be accepted.
|
||||
|
@ -266,14 +245,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
int32_t rows_per_buffer_;
|
||||
std::string folder_path_; // directory of image folder
|
||||
int64_t num_samples_;
|
||||
bool recursive_;
|
||||
bool decode_;
|
||||
std::set<std::string> extensions_; // extensions allowed
|
||||
std::map<std::string, int32_t> class_index_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
std::shared_ptr<Sampler> sampler_;
|
||||
int64_t num_rows_; // total number of images in ImageFolder
|
||||
int64_t row_cnt_;
|
||||
int64_t buf_cnt_;
|
||||
int64_t sampler_ind_;
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_(0), builder_decode_(false) {
|
||||
ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_decode_(false) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
|
@ -39,16 +39,18 @@ ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_
|
|||
Status ManifestOp::Builder::Build(std::shared_ptr<ManifestOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
if (builder_sampler_ == nullptr) {
|
||||
builder_sampler_ = std::make_shared<SequentialSampler>();
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
|
||||
}
|
||||
builder_schema_ = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(
|
||||
builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
RETURN_IF_NOT_OK(
|
||||
builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
*ptr = std::make_shared<ManifestOp>(
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_file_, builder_op_connector_size_, builder_num_samples_,
|
||||
builder_decode_, builder_labels_to_read_, std::move(builder_schema_), std::move(builder_sampler_), builder_usage_);
|
||||
*ptr = std::make_shared<ManifestOp>(builder_num_workers_, builder_rows_per_buffer_, builder_file_,
|
||||
builder_op_connector_size_, builder_decode_, builder_labels_to_read_,
|
||||
std::move(builder_schema_), std::move(builder_sampler_), builder_usage_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -59,9 +61,9 @@ Status ManifestOp::Builder::SanityCheck() {
|
|||
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size,
|
||||
int64_t num_samples, bool decode, const std::map<std::string, int32_t> &class_index,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler, std::string usage)
|
||||
ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
|
||||
const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<Sampler> sampler, std::string usage)
|
||||
: ParallelOp(num_works, queue_size),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
io_block_pushed_(0),
|
||||
|
@ -71,8 +73,6 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
|
|||
file_(file),
|
||||
class_index_(class_index),
|
||||
sampler_(std::move(sampler)),
|
||||
num_samples_(num_samples),
|
||||
num_rows_(0),
|
||||
decode_(decode),
|
||||
usage_(usage),
|
||||
buf_cnt_(0) {
|
||||
|
@ -101,8 +101,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
|
|||
RETURN_IF_NOT_OK((*sampler_buffer)->PopRow(&sample_row));
|
||||
std::shared_ptr<Tensor> sample_ids = sample_row[0];
|
||||
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
|
||||
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
|
||||
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
|
||||
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
|
||||
keys.push_back(*itr);
|
||||
row_cnt_++;
|
||||
if (row_cnt_ % rows_per_buffer_ == 0) {
|
||||
|
@ -269,28 +268,6 @@ Status ManifestOp::InitSampler() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status ManifestOp::GetNumSamples(int64_t *num) const {
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
(*num) = num_samples_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status ManifestOp::GetNumRowsInDataset(int64_t *num) const {
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
(*num) = num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status ManifestOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
||||
if (cls_ids == nullptr || !cls_ids->empty() || image_labelname_.empty()) {
|
||||
|
@ -408,7 +385,6 @@ Status ManifestOp::CountDatasetInfo() {
|
|||
}
|
||||
|
||||
num_rows_ = static_cast<int64_t>(image_labelname_.size());
|
||||
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
|
||||
if (num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
|
||||
|
@ -417,8 +393,8 @@ Status ManifestOp::CountDatasetInfo() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict,
|
||||
const std::string &usage, int64_t *count, int64_t *numClasses) {
|
||||
Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
|
||||
int64_t *count, int64_t *numClasses) {
|
||||
// the logic of counting the number of samples is copied from ParseManifestFile()
|
||||
std::map<std::string, int32_t> map;
|
||||
for (auto p : dict) {
|
||||
|
@ -428,17 +404,15 @@ Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, c
|
|||
|
||||
std::shared_ptr<ManifestOp> op;
|
||||
*count = 0;
|
||||
RETURN_IF_NOT_OK(
|
||||
Builder().SetManifestFile(file).SetNumSamples(numSamples).SetClassIndex(map).SetUsage(usage).Build(&op));
|
||||
RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseManifestFile());
|
||||
*numClasses = static_cast<int64_t>(op->label_index_.size());
|
||||
*count = static_cast<int64_t>(op->image_labelname_.size());
|
||||
*count = (*count < numSamples || numSamples == 0) ? *count : numSamples;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict,
|
||||
const std::string &usage, std::map<std::string, int32_t> *output_class_indexing) {
|
||||
Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
|
||||
std::map<std::string, int32_t> *output_class_indexing) {
|
||||
std::map<std::string, int32_t> input_class_indexing;
|
||||
for (auto p : dict) {
|
||||
(void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
|
||||
|
@ -449,12 +423,7 @@ Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples,
|
|||
*output_class_indexing = input_class_indexing;
|
||||
} else {
|
||||
std::shared_ptr<ManifestOp> op;
|
||||
RETURN_IF_NOT_OK(Builder()
|
||||
.SetManifestFile(file)
|
||||
.SetNumSamples(numSamples)
|
||||
.SetClassIndex(input_class_indexing)
|
||||
.SetUsage(usage)
|
||||
.Build(&op));
|
||||
RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(input_class_indexing).SetUsage(usage).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseManifestFile());
|
||||
RETURN_IF_NOT_OK(op->CountDatasetInfo());
|
||||
uint32_t count = 0;
|
||||
|
|
|
@ -86,14 +86,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param int64_t num_samples
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumSamples(int64_t num_samples) {
|
||||
builder_num_samples_ = num_samples;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
|
@ -129,7 +121,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
private:
|
||||
std::shared_ptr<Sampler> builder_sampler_;
|
||||
int64_t builder_num_samples_;
|
||||
bool builder_decode_;
|
||||
|
||||
std::string builder_file_;
|
||||
|
@ -147,8 +138,8 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param std::string - file list of Manifest
|
||||
// @param int32_t queue_size - connector queue size
|
||||
// @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
|
||||
ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, int64_t num_samples,
|
||||
bool decode, const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
|
||||
ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
|
||||
const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<Sampler> sampler, std::string usage);
|
||||
// Destructor.
|
||||
~ManifestOp() = default;
|
||||
|
@ -164,16 +155,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get numRows
|
||||
// @param int64_t num - to return numRows
|
||||
// @return Status - The error code return
|
||||
Status GetNumSamples(int64_t *num) const override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get total number of Rows in dataset
|
||||
// @param int64_t num - to return numRows
|
||||
// @return Status - The error code return
|
||||
Status GetNumRowsInDataset(int64_t *num) const override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
|
||||
// @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
|
||||
// @return Status - The error code return
|
||||
|
@ -184,12 +165,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param show_all
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
static Status CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict,
|
||||
const std::string &usage, int64_t *count, int64_t *numClasses);
|
||||
static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count,
|
||||
int64_t *numClasses);
|
||||
|
||||
// Get str-to-int mapping from label name to index
|
||||
static Status GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict,
|
||||
const std::string &usage, std::map<std::string, int32_t> *output_class_indexing);
|
||||
static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
|
||||
std::map<std::string, int32_t> *output_class_indexing);
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
|
@ -240,8 +221,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
std::string file_; // file that store the information of images
|
||||
std::map<std::string, int32_t> class_index_;
|
||||
std::shared_ptr<Sampler> sampler_;
|
||||
int64_t num_samples_;
|
||||
int64_t num_rows_;
|
||||
bool decode_;
|
||||
std::string usage_;
|
||||
int64_t buf_cnt_;
|
||||
|
|
|
@ -91,7 +91,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
|
|||
block_reader_(block_reader),
|
||||
buffers_needed_(0),
|
||||
buf_cnt_(0),
|
||||
num_rows_(0),
|
||||
ended_worker_(0),
|
||||
buffer_water_mark_(0) {
|
||||
io_blk_queues_.Init(num_workers_, op_connector_queue_size);
|
||||
|
|
|
@ -31,7 +31,7 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
|
|||
const int32_t kMnistImageRows = 28;
|
||||
const int32_t kMnistImageCols = 28;
|
||||
|
||||
MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr) {
|
||||
MnistOp::Builder::Builder() : builder_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
|
@ -41,7 +41,9 @@ MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr)
|
|||
Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
if (builder_sampler_ == nullptr) {
|
||||
builder_sampler_ = std::make_shared<SequentialSampler>();
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
|
||||
}
|
||||
builder_schema_ = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(
|
||||
|
@ -49,9 +51,8 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) {
|
|||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
|
||||
ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
*ptr =
|
||||
std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, builder_op_connector_size_,
|
||||
builder_num_samples_, std::move(builder_schema_), std::move(builder_sampler_));
|
||||
*ptr = std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
|
||||
builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -60,17 +61,14 @@ Status MnistOp::Builder::SanityCheck() {
|
|||
std::string err_msg;
|
||||
err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : "";
|
||||
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : "";
|
||||
err_msg += builder_num_samples_ < 0 ? "Number of samples is set to negative\n" : "";
|
||||
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
|
||||
int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
|
||||
: ParallelOp(num_workers, queue_size),
|
||||
buf_cnt_(0),
|
||||
row_cnt_(0),
|
||||
num_rows_(0),
|
||||
num_samples_(num_samples),
|
||||
folder_path_(folder_path),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
sampler_(std::move(sampler)),
|
||||
|
@ -84,8 +82,7 @@ MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folde
|
|||
|
||||
Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) {
|
||||
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
|
||||
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
|
||||
if (row_cnt_ >= num_samples_) break; // enough row read, break for loop
|
||||
if ((*itr) >= num_rows_) continue; // index out of bound, skipping
|
||||
keys->push_back(*itr);
|
||||
row_cnt_++;
|
||||
if (row_cnt_ % rows_per_buffer_ == 0) {
|
||||
|
@ -219,17 +216,6 @@ Status MnistOp::InitSampler() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status MnistOp::GetNumSamples(int64_t *num) const {
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
(*num) = num_samples_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
||||
if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
|
||||
|
@ -364,7 +350,6 @@ Status MnistOp::ParseMnistData() {
|
|||
}
|
||||
image_label_pairs_.shrink_to_fit();
|
||||
num_rows_ = image_label_pairs_.size();
|
||||
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -414,11 +399,11 @@ Status MnistOp::LaunchThreadsAndInitOp() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count) {
|
||||
Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
|
||||
// the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader()
|
||||
std::shared_ptr<MnistOp> op;
|
||||
*count = 0;
|
||||
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetNumSamples(numSamples).Build(&op));
|
||||
RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op));
|
||||
|
||||
RETURN_IF_NOT_OK(op->WalkAllFiles());
|
||||
|
||||
|
@ -440,19 +425,6 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64
|
|||
label_reader.close();
|
||||
}
|
||||
|
||||
*count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status MnistOp::GetNumRowsInDataset(int64_t *num) const {
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
(*num) = num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -78,14 +78,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param int64_t num_samples
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumSamples(int64_t num_samples) {
|
||||
builder_num_samples_ = num_samples;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
|
@ -114,7 +106,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
private:
|
||||
std::string builder_dir_;
|
||||
int32_t builder_num_workers_;
|
||||
int64_t builder_num_samples_;
|
||||
int32_t builder_rows_per_buffer_;
|
||||
int32_t builder_op_connector_size_;
|
||||
std::shared_ptr<Sampler> builder_sampler_;
|
||||
|
@ -126,11 +117,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param int32_t rows_per_buffer - number of images (rows) in each buffer
|
||||
// @param std::string folder_path - dir directory of mnist
|
||||
// @param int32_t queue_size - connector queue size
|
||||
// @param int64_t num_samples - number of samples to read
|
||||
// @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
|
||||
// @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read
|
||||
MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
|
||||
int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
|
||||
|
||||
// Destructor.
|
||||
~MnistOp() = default;
|
||||
|
@ -146,16 +136,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get numRows
|
||||
// @param int64_t num - to return numRows
|
||||
// @return Status - The error code return
|
||||
Status GetNumSamples(int64_t *num) const override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
|
||||
// @param int64_t num - to return numRows
|
||||
// @return Status - The error code return
|
||||
Status GetNumRowsInDataset(int64_t *num) const override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
|
||||
// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
|
||||
// @return Status - The error code return
|
||||
|
@ -167,11 +147,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// Function to count the number of samples in the MNIST dataset
|
||||
// @param dir path to the MNSIT directory
|
||||
// @param numSamples maximum number of samples requested
|
||||
// @param dir path to the MNIST directory
|
||||
// @param count output arg that will hold the minimum of the actual dataset size and numSamples
|
||||
// @return
|
||||
static Status CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count);
|
||||
static Status CountTotalRows(const std::string &dir, int64_t *count);
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
|
@ -244,9 +223,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
int64_t buf_cnt_;
|
||||
int64_t row_cnt_;
|
||||
int64_t num_rows_; // total number of images in Mnist
|
||||
WaitPost wp_;
|
||||
int64_t num_samples_;
|
||||
std::string folder_path_; // directory of image folder
|
||||
int32_t rows_per_buffer_;
|
||||
std::shared_ptr<Sampler> sampler_;
|
||||
|
|
|
@ -8,6 +8,5 @@ add_library(engine-datasetops-source-sampler OBJECT
|
|||
sampler.cc
|
||||
sequential_sampler.cc
|
||||
subset_random_sampler.cc
|
||||
subset_sampler.cc
|
||||
weighted_random_sampler.cc
|
||||
)
|
||||
|
|
|
@ -23,8 +23,9 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle, uint32_t seed)
|
||||
: Sampler(),
|
||||
DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
||||
uint32_t seed)
|
||||
: Sampler(num_samples, std::numeric_limits<int64_t>::max()),
|
||||
cnt_(0),
|
||||
seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
|
||||
device_id_(dev_id),
|
||||
|
@ -32,6 +33,11 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu
|
|||
shuffle_(shuffle) {}
|
||||
|
||||
Status DistributedSampler::InitSampler() {
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
|
||||
if (num_samples_ == 0 || num_samples_ > num_rows_) {
|
||||
num_samples_ = num_rows_;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
|
||||
|
|
|
@ -27,10 +27,11 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
class DistributedSampler : public Sampler {
|
||||
public:
|
||||
// @param int64_t numDev
|
||||
// @param int64_t devId
|
||||
// @param num_samples
|
||||
// @param int64_t num_dev
|
||||
// @param int64_t dev_id
|
||||
// @param bool shuffle
|
||||
DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle = true,
|
||||
DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
|
||||
uint32_t seed = std::numeric_limits<uint32_t>::max());
|
||||
|
||||
// default destructor
|
||||
|
|
|
@ -20,12 +20,11 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer)
|
||||
: Sampler(samples_per_buffer),
|
||||
PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer)
|
||||
: Sampler(num_samples, samples_per_buffer),
|
||||
shuffle_(shuffle),
|
||||
seed_(GetSeed()),
|
||||
next_id_(0),
|
||||
num_pk_samples_(0),
|
||||
samples_per_class_(val) {}
|
||||
|
||||
Status PKSampler::InitSampler() {
|
||||
|
@ -36,22 +35,34 @@ Status PKSampler::InitSampler() {
|
|||
}
|
||||
}
|
||||
rnd_.seed(seed_++);
|
||||
num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
|
||||
samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_;
|
||||
num_samples_ = num_pk_samples_;
|
||||
|
||||
// The special handshake gives the list of classes and id's, but it did not set the num_rows_ to
|
||||
// capture the total number of possible sample ids.
|
||||
// Compute that here for this case to find the total number of samples that are available to return.
|
||||
// (in this case, samples per class * total classes).
|
||||
num_rows_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
|
||||
|
||||
// The user may have chosen to sample less than the total amount.
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
|
||||
if (num_samples_ == 0 || num_samples_ > num_rows_) {
|
||||
num_samples_ = num_rows_;
|
||||
}
|
||||
|
||||
samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
|
||||
if (shuffle_ == true) {
|
||||
std::shuffle(labels_.begin(), labels_.end(), rnd_);
|
||||
} else {
|
||||
std::sort(labels_.begin(), labels_.end());
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_pk_samples_ > 0, "num_class or K (num samples per class) is not positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_class or K (num samples per class) is not positive");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (next_id_ > num_pk_samples_ || num_pk_samples_ == 0) {
|
||||
if (next_id_ > num_samples_ || num_samples_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler");
|
||||
} else if (next_id_ == num_pk_samples_) {
|
||||
} else if (next_id_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
|
@ -60,8 +71,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> sample_ids;
|
||||
int64_t last_id =
|
||||
(samples_per_buffer_ + next_id_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_ + next_id_;
|
||||
int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_));
|
||||
int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer());
|
||||
while (next_id_ < last_id) {
|
||||
|
@ -85,7 +95,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
}
|
||||
|
||||
Status PKSampler::Reset() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
|
||||
next_id_ = 0;
|
||||
rnd_.seed(seed_++);
|
||||
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
class PKSampler : public Sampler { // NOT YET FINISHED
|
||||
public:
|
||||
// @param int64_t kVal
|
||||
// @param num_samples - the number of samples to draw. value of 0 means to take the full amount
|
||||
// @param int64_t val
|
||||
// @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2
|
||||
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit PKSampler(int64_t val, bool shuffle = false,
|
||||
explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// default destructor
|
||||
|
@ -42,8 +43,9 @@ class PKSampler : public Sampler { // NOT YET FINISHED
|
|||
// @return - The error code return
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
// first handshake between StorageOp and Sampler
|
||||
// @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds()
|
||||
// first handshake between leaf source op and Sampler. This func will determine the amount of data
|
||||
// in the dataset that we can sample from.
|
||||
// @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
|
||||
// @return
|
||||
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
|
||||
|
||||
|
@ -58,7 +60,6 @@ class PKSampler : public Sampler { // NOT YET FINISHED
|
|||
bool shuffle_;
|
||||
uint32_t seed_;
|
||||
int64_t next_id_;
|
||||
int64_t num_pk_samples_;
|
||||
int64_t samples_per_class_;
|
||||
std::mt19937 rnd_;
|
||||
std::vector<int64_t> labels_;
|
||||
|
|
|
@ -20,8 +20,8 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer)
|
||||
: Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
|
||||
PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
|
||||
: Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
|
||||
|
||||
Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (need_to_reset_) {
|
||||
|
@ -65,6 +65,11 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
|
||||
Status PythonSampler::InitSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0");
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
|
||||
if (num_samples_ == 0 || num_samples_ > num_rows_) {
|
||||
num_samples_ = num_rows_;
|
||||
}
|
||||
{
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
if (Py_IsInitialized() == 0) {
|
||||
|
|
|
@ -26,8 +26,11 @@ namespace dataset {
|
|||
class PythonSampler : public Sampler {
|
||||
public:
|
||||
// Constructor
|
||||
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit PythonSampler(py::object py_sampler_instance,
|
||||
// @param num_samples - the number of samples to draw. Value of 0 means to sample all of the
|
||||
// data from the dataset.
|
||||
// @param py_sampler_instance - the python instance of the sampler
|
||||
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// Destructor.
|
||||
|
|
|
@ -22,12 +22,11 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples,
|
||||
RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
|
||||
int64_t samples_per_buffer)
|
||||
: Sampler(samples_per_buffer),
|
||||
: Sampler(num_samples, samples_per_buffer),
|
||||
seed_(GetSeed()),
|
||||
replacement_(replacement),
|
||||
user_num_samples_(num_samples),
|
||||
next_id_(0),
|
||||
reshuffle_each_epoch_(reshuffle_each_epoch),
|
||||
dist(nullptr) {}
|
||||
|
@ -70,27 +69,25 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
}
|
||||
|
||||
Status RandomSampler::InitSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive.");
|
||||
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
|
||||
if (num_samples_ == 0 || num_samples_ > num_rows_) {
|
||||
num_samples_ = num_rows_;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
|
||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||
rnd_.seed(seed_);
|
||||
|
||||
if (replacement_ == false) {
|
||||
num_samples_ = std::min(num_samples_, num_rows_);
|
||||
num_samples_ = std::min(num_samples_, user_num_samples_);
|
||||
|
||||
shuffled_ids_.reserve(num_rows_);
|
||||
for (int64_t i = 0; i < num_rows_; i++) {
|
||||
shuffled_ids_.push_back(i);
|
||||
}
|
||||
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
|
||||
} else {
|
||||
num_samples_ = std::min(num_samples_, user_num_samples_);
|
||||
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive.");
|
||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -119,7 +116,6 @@ void RandomSampler::Print(std::ostream &out, bool show_all) const {
|
|||
out << "(sampler): RandomSampler\n";
|
||||
|
||||
if (show_all) {
|
||||
out << "user_num_samples_: " << user_num_samples_ << '\n';
|
||||
out << "num_samples_: " << num_samples_ << '\n';
|
||||
out << "next_id_: " << next_id_ << '\n';
|
||||
}
|
||||
|
|
|
@ -27,11 +27,11 @@ namespace dataset {
|
|||
class RandomSampler : public Sampler {
|
||||
public:
|
||||
// Constructor
|
||||
// @param int64_t num_samples - number samples to draw
|
||||
// @param bool replacement - put he id back / or not after a sample
|
||||
// @param int64_t numSamples - number samples to draw
|
||||
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true,
|
||||
int64_t num_samples = std::numeric_limits<int64_t>::max(),
|
||||
// @param reshuffle_each_epoch - T/F to reshuffle after epoch
|
||||
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// Destructor.
|
||||
|
@ -55,7 +55,6 @@ class RandomSampler : public Sampler {
|
|||
private:
|
||||
uint32_t seed_;
|
||||
bool replacement_;
|
||||
int64_t user_num_samples_;
|
||||
std::vector<int64_t> shuffled_ids_; // only used for NO REPLACEMENT
|
||||
int64_t next_id_;
|
||||
std::mt19937 rnd_;
|
||||
|
|
|
@ -19,8 +19,25 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Sampler::Sampler(int64_t samples_per_buffer)
|
||||
: DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
|
||||
Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
|
||||
// The sampler base class itself does not compute it's own num_rows_ value.
|
||||
// Instead, this value is computed by the derived leaf op during it's own initialization
|
||||
// after it has interacted with it's storage layers.
|
||||
// Here, it is just a getter method to return the value. However, it is invalid if there is
|
||||
// not a value set for this count, so generate a failure if that is the case.
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed it's num rows yet.");
|
||||
}
|
||||
(*num) = num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer)
|
||||
: DatasetOp(0),
|
||||
num_rows_(0),
|
||||
num_samples_(num_samples),
|
||||
samples_per_buffer_(samples_per_buffer),
|
||||
col_desc_(nullptr) {}
|
||||
|
||||
Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
|
||||
std::shared_ptr<Sampler> child_sampler;
|
||||
|
@ -36,10 +53,10 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
|
|||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
|
||||
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
|
||||
|
||||
// If there's a child sampler, set the row count to be it's sample count
|
||||
if (HasChildSampler()) {
|
||||
int64_t child_num_samples = child_sampler->num_samples();
|
||||
num_rows_ = child_num_samples;
|
||||
num_rows_ = child_sampler->num_samples_;
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
|
||||
}
|
||||
|
@ -105,7 +122,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
|
|||
}
|
||||
|
||||
Status Sampler::SetNumSamples(int64_t num_samples) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples > 0, "num_samples is negative or 0");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "num_samples is negative");
|
||||
num_samples_ = num_samples;
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -116,6 +133,16 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// inline op doesn't have it's own consumer, it's assigned from parent
|
||||
int32_t Sampler::num_consumers() const {
|
||||
if (parent_.empty() || parent_[0] == nullptr) {
|
||||
MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return parent_[0]->num_consumers();
|
||||
}
|
||||
}
|
||||
|
||||
Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) {
|
||||
if (child == nullptr) {
|
||||
return Status::OK();
|
||||
|
@ -155,5 +182,14 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// inline op doesn't have it's own producers, it's assigned from child
|
||||
int32_t Sampler::num_producers() const {
|
||||
if (child_.empty() || child_[0] == nullptr) {
|
||||
MS_LOG(WARNING) << "Sampler with no child, num_producers is 0.";
|
||||
return 0;
|
||||
} else {
|
||||
return child_[0]->num_producers();
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,23 +33,10 @@ namespace dataset {
|
|||
// must inherit from if those leaf operator wish to support sampling.
|
||||
class RandomAccessOp {
|
||||
public:
|
||||
// Sampler get numRows from StorageOp
|
||||
// @param int64_t num - return number of rows, normally num of samples
|
||||
// @return - The error code return
|
||||
virtual Status GetNumSamples(int64_t *num_samples) const {
|
||||
// CI complains num_samples not used if the following line is not added
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples != nullptr, "num_samples == nullptr");
|
||||
RETURN_STATUS_UNEXPECTED("function GetNumSamples needs to overridden to support this sampler");
|
||||
}
|
||||
|
||||
// Sampler get number of rows in the dataset!
|
||||
// Sampler get number of rows in the dataset
|
||||
// @param int64_t num - return number of rows for this dataset
|
||||
// @return - The error code return
|
||||
virtual Status GetNumRowsInDataset(int64_t *num_rows) const {
|
||||
// CI complains num_rows not used if the following line is not added
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows != nullptr, "num_rows == nullptr");
|
||||
RETURN_STATUS_UNEXPECTED("function GetNumRowsInDataset needs to overridden to support this sampler");
|
||||
}
|
||||
Status GetNumRowsInDataset(int64_t *num_rows) const;
|
||||
|
||||
// sampler gets label , imageIds from storageOp, this function is unique to PK
|
||||
// @param std::map<int64_t, std::vector<int64_t>> * map
|
||||
|
@ -60,12 +47,20 @@ class RandomAccessOp {
|
|||
|
||||
// default destructor
|
||||
virtual ~RandomAccessOp() = default;
|
||||
|
||||
protected:
|
||||
// The amount of rows in the dataset itself. This is the before-sampling value, the
|
||||
// total count of rows. A sampler may choose to sample less than this amount.
|
||||
int64_t num_rows_;
|
||||
};
|
||||
|
||||
class Sampler : public DatasetOp {
|
||||
public:
|
||||
// Constructor
|
||||
// @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
|
||||
// indicates that the sampler should produce the complete set of ids.
|
||||
// @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit Sampler(int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);
|
||||
|
||||
// default destructor
|
||||
~Sampler() = default;
|
||||
|
@ -84,33 +79,36 @@ class Sampler : public DatasetOp {
|
|||
// @return - The error code return
|
||||
Status Reset() override = 0;
|
||||
|
||||
// setter function for num_rows_
|
||||
Status SetNumRowsInDataset(int64_t num_rows);
|
||||
|
||||
// setter function for num_samples_
|
||||
Status SetNumSamples(int64_t num_samples);
|
||||
|
||||
int64_t num_samples() { return num_samples_; }
|
||||
|
||||
// first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples
|
||||
// @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds()
|
||||
// first handshake between leaf source op and Sampler. This func will determine the amount of data
|
||||
// in the dataset that we can sample from.
|
||||
// @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
|
||||
// @return
|
||||
virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
|
||||
|
||||
// initialize sampler and perform checks on certain vars
|
||||
virtual Status InitSampler() { return Status::OK(); }
|
||||
|
||||
// Not meant to be called
|
||||
// setter for num samples
|
||||
// @param num_samples - the number of samples to assign.
|
||||
// @return status error code
|
||||
Status SetNumSamples(int64_t num_samples);
|
||||
|
||||
// setter for num or records in the dataset
|
||||
// @param num_rows - the number of records
|
||||
// @return status error code
|
||||
Status SetNumRowsInDataset(int64_t num_rows);
|
||||
|
||||
// Sampler is an inlined op and has no workers. Producers and consumers are computed.
|
||||
// @return
|
||||
int32_t num_workers() const final { return 0; }
|
||||
|
||||
// Not meant to be called
|
||||
// Identify num consumers (inlined op)
|
||||
// @return
|
||||
int32_t num_consumers() const final { return 0; }
|
||||
int32_t num_consumers() const final;
|
||||
|
||||
// Not meant to be called
|
||||
// Identify num producers (inlined op)
|
||||
// @return
|
||||
int32_t num_producers() const final { return 0; }
|
||||
int32_t num_producers() const final;
|
||||
|
||||
// Not meant to be called!
|
||||
// @return - The error code return
|
||||
|
@ -151,10 +149,11 @@ class Sampler : public DatasetOp {
|
|||
// output. Otherwise, num_rows_ is the number of rows in the dataset.
|
||||
int64_t num_rows_;
|
||||
|
||||
// Number of ids this sampler will return.
|
||||
// The user may want to sample less than the full amount of data. num_samples_ reduces the number
|
||||
// of id's returned as request by the user. Derived classes will choose how to sample the smaller
|
||||
// amount.
|
||||
int64_t num_samples_;
|
||||
|
||||
// The max number of ids a DataBuffer returned by this sampler will contain.
|
||||
int64_t samples_per_buffer_;
|
||||
std::unique_ptr<ColDescriptor> col_desc_;
|
||||
std::unique_ptr<DataBuffer> child_ids_;
|
||||
|
|
|
@ -20,34 +20,42 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
SequentialSampler::SequentialSampler(int64_t samples_per_buffer) : Sampler(samples_per_buffer), next_id_(0) {}
|
||||
SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
|
||||
: Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}
|
||||
|
||||
Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (next_id_ > num_samples_) {
|
||||
RETURN_STATUS_UNEXPECTED("Sequential Sampler Internal Error");
|
||||
} else if (next_id_ == num_samples_) {
|
||||
if (id_count_ > num_samples_) {
|
||||
RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error");
|
||||
} else if (id_count_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> sampleIds;
|
||||
int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_));
|
||||
|
||||
// Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for
|
||||
// samples per buffer though.
|
||||
int64_t remaining_ids = num_samples_ - id_count_;
|
||||
int64_t num_elements = std::min(remaining_ids, samples_per_buffer_);
|
||||
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements));
|
||||
int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer());
|
||||
while (next_id_ < lastId) {
|
||||
int64_t sampled_id = next_id_;
|
||||
for (int64_t i = 0; i < num_elements; i++) {
|
||||
int64_t sampled_id = current_id_;
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
|
||||
}
|
||||
|
||||
*idPtr = sampled_id;
|
||||
next_id_++;
|
||||
current_id_++; // Move the current id to the next one in the sequence
|
||||
idPtr++;
|
||||
}
|
||||
|
||||
id_count_ += num_elements; // Count the packed ids towards our overall sample count
|
||||
|
||||
TensorRow row(1, sampleIds);
|
||||
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
|
||||
}
|
||||
|
@ -55,19 +63,24 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
|
|||
}
|
||||
|
||||
Status SequentialSampler::InitSampler() {
|
||||
num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set
|
||||
if (HasChildSampler()) {
|
||||
num_samples_ = std::min(num_samples_, num_rows_);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0, "num_samples < 0\n");
|
||||
// Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample
|
||||
// the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data.
|
||||
int64_t available_row_count = num_rows_ - start_index_;
|
||||
if (num_samples_ == 0 || num_samples_ > available_row_count) {
|
||||
num_samples_ = available_row_count;
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler");
|
||||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SequentialSampler::Reset() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
|
||||
next_id_ = 0;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late");
|
||||
current_id_ = start_index_;
|
||||
id_count_ = 0;
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
|
|
|
@ -26,8 +26,12 @@ namespace dataset {
|
|||
class SequentialSampler : public Sampler {
|
||||
public:
|
||||
// Constructor
|
||||
// @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the
|
||||
// full amount of ids from the dataset
|
||||
// @param start_index - The starting index value
|
||||
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit SequentialSampler(int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
explicit SequentialSampler(int64_t num_samples, int64_t start_index,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// Destructor.
|
||||
~SequentialSampler() = default;
|
||||
|
@ -48,7 +52,9 @@ class SequentialSampler : public Sampler {
|
|||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
private:
|
||||
int64_t next_id_;
|
||||
int64_t current_id_; // The id sequencer. Each new id increments from this
|
||||
int64_t start_index_; // The starting id. current_id_ begins from here.
|
||||
int64_t id_count_; // An internal counter that tracks how many ids have been produced
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,22 +27,28 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor.
|
||||
SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t samples_per_buffer)
|
||||
: Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
|
||||
SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
|
||||
int64_t samples_per_buffer)
|
||||
: Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
|
||||
|
||||
// Initialized this Sampler.
|
||||
Status SubsetRandomSampler::InitSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
|
||||
|
||||
num_samples_ = indices_.size();
|
||||
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
// In this case, the id's are provided by the user. Cap the num_samples on the number of id's given.
|
||||
if (num_samples_ == 0 || num_samples_ > static_cast<int64_t>(indices_.size())) {
|
||||
num_samples_ = static_cast<int64_t>(indices_.size());
|
||||
}
|
||||
// Initialize random generator with seed from config manager
|
||||
rand_gen_.seed(GetSeed());
|
||||
|
||||
if (static_cast<size_t>(samples_per_buffer_) > indices_.size()) {
|
||||
samples_per_buffer_ = static_cast<int64_t>(indices_.size());
|
||||
if (samples_per_buffer_ > num_samples_) {
|
||||
samples_per_buffer_ = num_samples_;
|
||||
}
|
||||
|
||||
// num_samples_ could be smaller than the total number of input id's.
|
||||
// We will shuffle the full set of id's, but only select the first num_samples_ of them later.
|
||||
std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
|
||||
|
||||
return Status::OK();
|
||||
|
@ -68,7 +74,7 @@ Status SubsetRandomSampler::Reset() {
|
|||
// Get the sample ids.
|
||||
Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
// All samples have been drawn
|
||||
if (sample_id_ == indices_.size()) {
|
||||
if (sample_id_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
|
@ -80,8 +86,8 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe
|
|||
|
||||
int64_t last_id = sample_id_ + samples_per_buffer_;
|
||||
// Handling the return all samples at once, and when last draw is not a full batch.
|
||||
if (static_cast<size_t>(last_id) > indices_.size()) {
|
||||
last_id = indices_.size();
|
||||
if (last_id > num_samples_) {
|
||||
last_id = num_samples_;
|
||||
}
|
||||
|
||||
// Allocate tensor
|
||||
|
|
|
@ -28,10 +28,11 @@ namespace dataset {
|
|||
class SubsetRandomSampler : public Sampler {
|
||||
public:
|
||||
// Constructor.
|
||||
// @param num_samples The number of samples to draw. 0 for the full amount.
|
||||
// @param indices List of indices from where we will randomly draw samples.
|
||||
// @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
|
||||
// When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
|
||||
explicit SubsetRandomSampler(const std::vector<int64_t> &indices,
|
||||
explicit SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
|
||||
std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// Destructor.
|
||||
|
|
|
@ -1,85 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor.
|
||||
SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
|
||||
: Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}
|
||||
|
||||
Status SubsetSampler::InitSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");
|
||||
|
||||
num_samples_ = subset_size_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SubsetSampler::Reset() {
|
||||
current_id_ = 0;
|
||||
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->Reset());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (current_id_ > subset_size_) {
|
||||
RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error");
|
||||
} else if (current_id_ == subset_size_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_));
|
||||
}
|
||||
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
|
||||
std::shared_ptr<Tensor> sampled_ids;
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_));
|
||||
|
||||
int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer());
|
||||
|
||||
while (current_id_ < subset_size_) {
|
||||
int64_t sampled_id = start_index_ + current_id_;
|
||||
if (HasChildSampler()) {
|
||||
RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
|
||||
}
|
||||
|
||||
*(sampled_ids_start_addr + current_id_) = sampled_id;
|
||||
current_id_++;
|
||||
}
|
||||
|
||||
TensorRow sampled_ids_row(1, sampled_ids);
|
||||
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, sampled_ids_row));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -1,58 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class SubsetSampler : public Sampler {
|
||||
public:
|
||||
// Constructor.
|
||||
// @param start_index The index we start sampling from.
|
||||
explicit SubsetSampler(int64_t start_index, int64_t subset_size);
|
||||
|
||||
// Destructor.
|
||||
~SubsetSampler() = default;
|
||||
|
||||
// Initialize the sampler.
|
||||
// @return Status
|
||||
Status InitSampler() override;
|
||||
|
||||
// Reset the internal variable to the initial state and reshuffle the indices.
|
||||
// @return Status
|
||||
Status Reset() override;
|
||||
|
||||
// Get the sample ids.
|
||||
// @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
|
||||
// @note the sample ids (int64_t) will be placed in one Tensor.
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
private:
|
||||
int64_t start_index_;
|
||||
int64_t subset_size_;
|
||||
int64_t current_id_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_
|
|
@ -27,25 +27,28 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Constructor.
|
||||
WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement,
|
||||
WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
|
||||
int64_t samples_per_buffer)
|
||||
: Sampler(samples_per_buffer),
|
||||
: Sampler(num_samples, samples_per_buffer),
|
||||
weights_(weights),
|
||||
replacement_(replacement),
|
||||
sample_id_(0),
|
||||
buffer_id_(0),
|
||||
user_num_samples_(num_samples) {}
|
||||
buffer_id_(0) {}
|
||||
|
||||
// Initialized this Sampler.
|
||||
Status WeightedRandomSampler::InitSampler() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive");
|
||||
// Special value of 0 for num_samples means that the user wants to sample the entire set of data.
|
||||
// If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly.
|
||||
if (num_samples_ == 0 || num_samples_ > num_rows_) {
|
||||
num_samples_ = num_rows_;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n");
|
||||
num_samples_ = user_num_samples_;
|
||||
|
||||
// Initialize random generator with seed from config manager
|
||||
rand_gen_.seed(GetSeed());
|
||||
|
||||
samples_per_buffer_ = (samples_per_buffer_ > user_num_samples_) ? user_num_samples_ : samples_per_buffer_;
|
||||
samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
|
||||
|
||||
if (!replacement_) {
|
||||
exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
|
||||
|
@ -67,8 +70,8 @@ void WeightedRandomSampler::InitOnePassSampling() {
|
|||
}
|
||||
|
||||
// Partial sort the first `numSamples` elements.
|
||||
std::partial_sort(val_idx.begin(), val_idx.begin() + user_num_samples_, val_idx.end());
|
||||
for (int64_t i = 0; i < user_num_samples_; i++) {
|
||||
std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end());
|
||||
for (int64_t i = 0; i < num_samples_; i++) {
|
||||
onepass_ids_.push_back(val_idx[i].second);
|
||||
}
|
||||
}
|
||||
|
@ -98,11 +101,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
|
|||
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
|
||||
}
|
||||
|
||||
if (!replacement_ && (weights_.size() < static_cast<size_t>(user_num_samples_))) {
|
||||
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
|
||||
RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples");
|
||||
}
|
||||
|
||||
if (sample_id_ == user_num_samples_) {
|
||||
if (sample_id_ == num_samples_) {
|
||||
(*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
|
||||
} else {
|
||||
if (HasChildSampler()) {
|
||||
|
@ -114,8 +117,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf
|
|||
|
||||
int64_t last_id = sample_id_ + samples_per_buffer_;
|
||||
// Handling the return all samples at once, and when last draw is not a full batch.
|
||||
if (last_id > user_num_samples_) {
|
||||
last_id = user_num_samples_;
|
||||
if (last_id > num_samples_) {
|
||||
last_id = num_samples_;
|
||||
}
|
||||
|
||||
// Allocate tensor.
|
||||
|
|
|
@ -29,12 +29,12 @@ namespace dataset {
|
|||
class WeightedRandomSampler : public Sampler {
|
||||
public:
|
||||
// Constructor.
|
||||
// @param weights A lift of sample weights.
|
||||
// @param num_samples Number of samples to be drawn.
|
||||
// @param weights A lift of sample weights.
|
||||
// @param replacement Determine if samples are drawn with/without replacement.
|
||||
// @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
|
||||
// When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
|
||||
WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement = true,
|
||||
WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// Destructor.
|
||||
|
@ -69,9 +69,6 @@ class WeightedRandomSampler : public Sampler {
|
|||
// Random engine and device
|
||||
std::mt19937 rand_gen_;
|
||||
|
||||
// num_samples from user
|
||||
int64_t user_num_samples_;
|
||||
|
||||
// Discrete distribution for generating weighted random numbers with replacement.
|
||||
std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_;
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
TextFileOp::Builder::Builder()
|
||||
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
|
||||
: builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
|
||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||
builder_num_workers_ = config_manager->num_parallel_workers();
|
||||
builder_op_connector_size_ = config_manager->op_connector_size();
|
||||
|
@ -62,7 +62,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
|
|||
builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
|
||||
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
|
||||
std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
|
||||
builder_num_devices_, builder_device_id_);
|
||||
RETURN_IF_NOT_OK(text_file_op->Init());
|
||||
|
@ -71,14 +71,14 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
|
||||
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
|
||||
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
device_id_(device_id),
|
||||
num_devices_(num_device),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
num_samples_(num_samples),
|
||||
total_rows_(total_rows),
|
||||
text_files_list_(std::move(text_files_list)),
|
||||
shuffle_files_(shuffle_files),
|
||||
data_schema_(std::move(schema)),
|
||||
|
@ -104,9 +104,9 @@ void TextFileOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
|
||||
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
||||
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nText files list:\n";
|
||||
out << "\nRows per buffer: " << rows_per_buffer_ << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_
|
||||
<< "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no")
|
||||
<< "\nText files list:\n";
|
||||
for (int i = 0; i < text_files_list_.size(); ++i) {
|
||||
out << " " << text_files_list_[i];
|
||||
}
|
||||
|
@ -404,9 +404,9 @@ Status TextFileOp::operator()() {
|
|||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
|
||||
if (buffer->eoe()) {
|
||||
workers_done++;
|
||||
} else if (num_samples_ == 0 || rows_read < num_samples_) {
|
||||
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
|
||||
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
|
||||
} else if (total_rows_ == 0 || rows_read < total_rows_) {
|
||||
if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) {
|
||||
int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read);
|
||||
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
|
||||
}
|
||||
rows_read += buffer->NumRows();
|
||||
|
|
|
@ -107,8 +107,8 @@ class TextFileOp : public ParallelOp {
|
|||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetNumSamples(int64_t num_samples) {
|
||||
builder_num_samples_ = num_samples;
|
||||
Builder &SetTotalRows(int64_t total_rows) {
|
||||
builder_total_rows_ = total_rows;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -118,7 +118,7 @@ class TextFileOp : public ParallelOp {
|
|||
int32_t builder_num_workers_;
|
||||
int32_t builder_op_connector_size_;
|
||||
int64_t builder_rows_per_buffer_;
|
||||
int64_t builder_num_samples_;
|
||||
int64_t builder_total_rows_;
|
||||
int32_t builder_worker_connector_size_;
|
||||
std::vector<std::string> builder_text_files_list_;
|
||||
bool builder_shuffle_files_;
|
||||
|
@ -136,7 +136,7 @@ class TextFileOp : public ParallelOp {
|
|||
// @param columns_to_load - the names of the columns to load data from.
|
||||
// @param shuffle_files - whether or not to shuffle the files before reading data.
|
||||
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
|
||||
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
|
||||
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
|
||||
std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
|
||||
bool shuffle_files, int32_t num_devices, int32_t device_id);
|
||||
|
||||
|
@ -246,7 +246,7 @@ class TextFileOp : public ParallelOp {
|
|||
int32_t device_id_;
|
||||
int32_t num_devices_;
|
||||
int64_t rows_per_buffer_;
|
||||
int64_t num_samples_;
|
||||
int64_t total_rows_;
|
||||
std::vector<std::string> text_files_list_;
|
||||
bool shuffle_files_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
|
|
|
@ -44,7 +44,7 @@ const char kSegmentationExtension[] = ".png";
|
|||
const char kAnnotationExtension[] = ".xml";
|
||||
const char kImageSetsExtension[] = ".txt";
|
||||
|
||||
VOCOp::Builder::Builder() : builder_decode_(false), builder_num_samples_(0), builder_sampler_(nullptr) {
|
||||
VOCOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
|
@ -55,7 +55,9 @@ VOCOp::Builder::Builder() : builder_decode_(false), builder_num_samples_(0), bui
|
|||
Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
if (builder_sampler_ == nullptr) {
|
||||
builder_sampler_ = std::make_shared<SequentialSampler>();
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
|
||||
}
|
||||
builder_schema_ = std::make_unique<DataSchema>();
|
||||
if (builder_task_type_ == TaskType::Segmentation) {
|
||||
|
@ -71,8 +73,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
|
|||
}
|
||||
*ptr = std::make_shared<VOCOp>(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_,
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_,
|
||||
builder_num_samples_, builder_decode_, std::move(builder_schema_),
|
||||
std::move(builder_sampler_));
|
||||
builder_decode_, std::move(builder_schema_), std::move(builder_sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -81,20 +82,16 @@ Status VOCOp::Builder::SanityCheck() {
|
|||
std::string err_msg;
|
||||
err_msg += dir.IsDirectory() == false ? "VOC path is invalid or not set\n" : "";
|
||||
err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : "";
|
||||
err_msg += builder_num_samples_ < 0 ? "num_samples is negative\n" : "";
|
||||
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
|
||||
const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
|
||||
int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<Sampler> sampler)
|
||||
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
|
||||
: ParallelOp(num_workers, queue_size),
|
||||
decode_(decode),
|
||||
row_cnt_(0),
|
||||
buf_cnt_(0),
|
||||
num_rows_(0),
|
||||
num_samples_(num_samples),
|
||||
task_type_(task_type),
|
||||
task_mode_(task_mode),
|
||||
folder_path_(folder_path),
|
||||
|
@ -112,7 +109,6 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std:
|
|||
Status VOCOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) {
|
||||
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
|
||||
if ((*itr) > num_rows_) continue;
|
||||
if (row_cnt_ == num_samples_) break;
|
||||
keys->push_back(*itr);
|
||||
row_cnt_++;
|
||||
if (row_cnt_ % rows_per_buffer_ == 0) {
|
||||
|
@ -187,16 +183,6 @@ Status VOCOp::Reset() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VOCOp::GetNumSamples(int64_t *num) const {
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API VOCDataset.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
(*num) = num_samples_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VOCOp::LoadTensorRow(const std::string &image_id, TensorRow *trow) {
|
||||
if (task_type_ == TaskType::Segmentation) {
|
||||
std::shared_ptr<Tensor> image, target;
|
||||
|
@ -280,7 +266,6 @@ Status VOCOp::ParseImageIds() {
|
|||
in_file.close();
|
||||
image_ids_.shrink_to_fit();
|
||||
num_rows_ = image_ids_.size();
|
||||
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -305,7 +290,6 @@ Status VOCOp::ParseAnnotationIds() {
|
|||
}
|
||||
|
||||
num_rows_ = image_ids_.size();
|
||||
num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -432,19 +416,8 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescripto
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp
|
||||
Status VOCOp::GetNumRowsInDataset(int64_t *num) const {
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API VOCDataset.Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
(*num) = num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples, int64_t *count) {
|
||||
const py::dict &dict, int64_t *count) {
|
||||
if (task_type == "Detection") {
|
||||
std::map<std::string, int32_t> input_class_indexing;
|
||||
for (auto p : dict) {
|
||||
|
@ -464,14 +437,12 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ
|
|||
RETURN_IF_NOT_OK(op->ParseImageIds());
|
||||
*count = static_cast<int64_t>(op->image_ids_.size());
|
||||
}
|
||||
*count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples,
|
||||
std::map<std::string, int32_t> *output_class_indexing) {
|
||||
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing) {
|
||||
std::map<std::string, int32_t> input_class_indexing;
|
||||
for (auto p : dict) {
|
||||
(void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
|
||||
|
|
|
@ -116,14 +116,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @param int64_t num_samples
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumSamples(int64_t num_samples) {
|
||||
builder_num_samples_ = num_samples;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @param std::shared_ptr<Sampler> sampler
|
||||
// @return Builder setter method returns reference to the builder.
|
||||
|
@ -157,7 +149,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
int32_t builder_num_workers_;
|
||||
int32_t builder_op_connector_size_;
|
||||
int32_t builder_rows_per_buffer_;
|
||||
int64_t builder_num_samples_;
|
||||
std::shared_ptr<Sampler> builder_sampler_;
|
||||
std::unique_ptr<DataSchema> builder_schema_;
|
||||
std::map<std::string, int32_t> builder_labels_to_read_;
|
||||
|
@ -171,14 +162,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param int32_t num_workers - number of workers reading images in parallel
|
||||
// @param int32_t rows_per_buffer - number of images (rows) in each buffer
|
||||
// @param int32_t queue_size - connector queue size
|
||||
// @param int64_t num_samples - number of samples to read
|
||||
// @param bool decode - whether to decode images
|
||||
// @param std::unique_ptr<DataSchema> data_schema - the schema of the VOC dataset
|
||||
// @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read
|
||||
VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
|
||||
const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
|
||||
int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<Sampler> sampler);
|
||||
int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
|
||||
|
||||
// Destructor
|
||||
~VOCOp() = default;
|
||||
|
@ -194,15 +183,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// Method derived from RandomAccessOp, enable Sampler to get numRows
|
||||
// @param uint64_t num - to return numRows
|
||||
// return Status - The error code return
|
||||
Status GetNumSamples(int64_t *num) const override;
|
||||
|
||||
// Method derived from RandomAccessOp, enable Sampler to get total number of rows in dataset
|
||||
// @param uint64_t num - to return numRows
|
||||
Status GetNumRowsInDataset(int64_t *num) const override;
|
||||
|
||||
// A print method typically used for debugging
|
||||
// @param out
|
||||
// @param show_all
|
||||
|
@ -212,10 +192,9 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param const std::string &task_type - task type of reading voc job
|
||||
// @param const std::string &task_mode - task mode of reading voc job
|
||||
// @param const py::dict &dict - input dict of class index
|
||||
// @param int64_t numSamples - samples number of VOCDataset
|
||||
// @param int64_t *count - output rows number of VOCDataset
|
||||
static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples, int64_t *count);
|
||||
const py::dict &dict, int64_t *count);
|
||||
|
||||
// @param const std::string &dir - VOC dir path
|
||||
// @param const std::string &task_type - task type of reading voc job
|
||||
|
@ -224,8 +203,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param int64_t numSamples - samples number of VOCDataset
|
||||
// @param std::map<std::string, int32_t> *output_class_indexing - output class index of VOCDataset
|
||||
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples,
|
||||
std::map<std::string, int32_t> *output_class_indexing);
|
||||
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing);
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
|
@ -283,8 +261,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
bool decode_;
|
||||
int64_t row_cnt_;
|
||||
int64_t buf_cnt_;
|
||||
int64_t num_rows_;
|
||||
int64_t num_samples_;
|
||||
std::string folder_path_;
|
||||
TaskType task_type_;
|
||||
std::string task_mode_;
|
||||
|
|
|
@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
|
|||
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
|
||||
Schema, Shuffle, zip, RandomDataset
|
||||
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
|
||||
WeightedRandomSampler, SubsetSampler, Sampler
|
||||
WeightedRandomSampler, Sampler
|
||||
from .engine.serializer_deserializer import serialize, deserialize, show
|
||||
from .engine.graphdata import GraphData
|
||||
|
||||
|
|
|
@ -1261,8 +1261,8 @@ class MappableDataset(SourceDataset):
|
|||
|
||||
def _get_sampler_dataset_size(self):
|
||||
if self.sampler is not None:
|
||||
if hasattr(self.sampler, 'get_dataset_size'):
|
||||
return self.sampler.get_dataset_size()
|
||||
if hasattr(self.sampler, 'get_num_samples'):
|
||||
return self.sampler.get_num_samples()
|
||||
if hasattr(self.sampler, '__len__'):
|
||||
return len(self.sampler)
|
||||
|
||||
|
@ -1355,7 +1355,7 @@ class MappableDataset(SourceDataset):
|
|||
random_sampler.reshuffle_each_epoch = False
|
||||
ds.add_sampler(random_sampler)
|
||||
|
||||
subset_sampler = samplers.SubsetSampler(current_split_start_index, size)
|
||||
subset_sampler = samplers.SequentialSampler(current_split_start_index, size)
|
||||
ds.add_sampler(subset_sampler)
|
||||
|
||||
# add sequential sampler, so that if user calls use_sampler, we will
|
||||
|
@ -2226,31 +2226,45 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
|||
num_shards (int): Number of shard for sharding.
|
||||
shard_id (int): Shard ID.
|
||||
"""
|
||||
if input_sampler is not None:
|
||||
# If the user provided a sampler, then it doesn't matter what the other args are because
|
||||
# we are being asked specifically to use the given sampler.
|
||||
# That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
|
||||
# be None. Consider this example:
|
||||
# sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
|
||||
# data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
|
||||
# In this case, the user has given different sample-related arguments that contradict each other.
|
||||
# To prevent this, only allow the user to manually specify the sampler if those arguments are all None
|
||||
if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
||||
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
||||
samplers.WeightedRandomSampler, samplers.Sampler)) and
|
||||
(num_shards is not None or shard_id is not None or shuffle is not None or num_samples is not None)):
|
||||
raise ValueError(
|
||||
'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
|
||||
' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle))
|
||||
return input_sampler
|
||||
if shuffle is None:
|
||||
if input_sampler is not None:
|
||||
# If shuffle is not specified, user provided sampler, use user's sampler
|
||||
return input_sampler
|
||||
if num_shards is not None:
|
||||
# If shuffle is not specified, sharding enabled, use distributed random sampler
|
||||
shuffle = True
|
||||
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
|
||||
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
|
||||
# If shuffle is not specified, sharding disabled, use random sampler
|
||||
if num_samples is not None:
|
||||
return samplers.RandomSampler(replacement=True, num_samples=num_samples)
|
||||
return samplers.RandomSampler()
|
||||
return samplers.RandomSampler(num_samples=num_samples)
|
||||
if shuffle is True:
|
||||
if num_shards is not None:
|
||||
# If shuffle enabled, sharding enabled, use distributed random sampler
|
||||
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
|
||||
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
|
||||
# If shuffle enabled, sharding disabled, use random sampler
|
||||
if num_samples is not None:
|
||||
return samplers.RandomSampler(replacement=True, num_samples=num_samples)
|
||||
return samplers.RandomSampler()
|
||||
return samplers.RandomSampler(num_samples=num_samples)
|
||||
if num_shards is not None:
|
||||
# If shuffle disabled, sharding enabled, use distributed sequential sampler
|
||||
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
|
||||
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
|
||||
# If shuffle disabled, sharding disabled, use sequential sampler
|
||||
return samplers.SequentialSampler()
|
||||
return samplers.SequentialSampler(num_samples=num_samples)
|
||||
|
||||
|
||||
class ImageFolderDatasetV2(MappableDataset):
|
||||
|
@ -2370,11 +2384,7 @@ class ImageFolderDatasetV2(MappableDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
num_samples = self.num_samples
|
||||
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
|
||||
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
|
@ -2390,11 +2400,7 @@ class ImageFolderDatasetV2(MappableDataset):
|
|||
Return:
|
||||
Number, number of classes.
|
||||
"""
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
num_samples = self.num_samples
|
||||
return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1]
|
||||
return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[1]
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
|
@ -2503,12 +2509,7 @@ class MnistDataset(MappableDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
num_samples = self.num_samples
|
||||
|
||||
num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
|
||||
num_rows = MnistOp.get_num_rows(self.dataset_dir)
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
|
@ -2956,11 +2957,8 @@ class GeneratorDataset(MappableDataset):
|
|||
if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
|
||||
samplers.RandomSampler, samplers.SubsetRandomSampler,
|
||||
samplers.WeightedRandomSampler, samplers.Sampler)):
|
||||
if num_samples is None:
|
||||
num_samples = len(source)
|
||||
sampler_instance = self.sampler.create()
|
||||
sampler_instance.set_num_rows(len(source))
|
||||
sampler_instance.set_num_samples(num_samples)
|
||||
sampler_instance.initialize()
|
||||
if num_parallel_workers > 1:
|
||||
self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers))
|
||||
|
@ -3304,17 +3302,12 @@ class ManifestDataset(MappableDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
num_samples = self.num_samples
|
||||
|
||||
if self.class_indexing is None:
|
||||
class_indexing = dict()
|
||||
else:
|
||||
class_indexing = self.class_indexing
|
||||
|
||||
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
|
||||
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
|
@ -3330,17 +3323,12 @@ class ManifestDataset(MappableDataset):
|
|||
Return:
|
||||
Number, number of classes.
|
||||
"""
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
num_samples = self.num_samples
|
||||
|
||||
if self.class_indexing is None:
|
||||
class_indexing = dict()
|
||||
else:
|
||||
class_indexing = self.class_indexing
|
||||
|
||||
return ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[1]
|
||||
return ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[1]
|
||||
|
||||
def get_class_indexing(self):
|
||||
"""
|
||||
|
@ -3349,17 +3337,12 @@ class ManifestDataset(MappableDataset):
|
|||
Return:
|
||||
Dict, A str-to-int mapping from label name to index.
|
||||
"""
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
num_samples = self.num_samples
|
||||
|
||||
if self.class_indexing is None:
|
||||
class_indexing = dict()
|
||||
else:
|
||||
class_indexing = self.class_indexing
|
||||
|
||||
return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage)
|
||||
return ManifestOp.get_class_indexing(self.dataset_file, class_indexing, self.usage)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
|
@ -3473,12 +3456,8 @@ class Cifar10Dataset(MappableDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
num_samples = self.num_samples
|
||||
|
||||
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
|
||||
num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
|
@ -3597,12 +3576,8 @@ class Cifar100Dataset(MappableDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
num_samples = self.num_samples
|
||||
|
||||
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
|
||||
num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
|
@ -3631,7 +3606,7 @@ class RandomDataset(SourceDataset):
|
|||
Args:
|
||||
num_samples (int): number of samples to generate.
|
||||
schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
|
||||
If the schema is not provided, the meta data from the TFRecord file is considered the schema.
|
||||
If the schema is not provided, the random dataset generates a random schema.
|
||||
columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
|
||||
num_parallel_workers (int, optional): number of workers to read the data
|
||||
(default=None, number set in the config).
|
||||
|
@ -3644,9 +3619,12 @@ class RandomDataset(SourceDataset):
|
|||
schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
|
||||
self.schema = schema
|
||||
self.columns_list = columns_list
|
||||
self.num_samples = num_samples
|
||||
if schema_obj is not None and num_samples is None:
|
||||
self.num_samples = schema_obj.num_rows
|
||||
elif num_samples is None:
|
||||
self.num_samples = 0
|
||||
else:
|
||||
self.num_samples = num_samples
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
@ -4015,17 +3993,12 @@ class VOCDataset(MappableDataset):
|
|||
if self.task != "Detection":
|
||||
raise NotImplementedError()
|
||||
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
num_samples = self.num_samples
|
||||
|
||||
if self.class_indexing is None:
|
||||
class_indexing = dict()
|
||||
else:
|
||||
class_indexing = self.class_indexing
|
||||
|
||||
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
|
||||
return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing)
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
|
@ -4205,9 +4178,11 @@ class TextFileDataset(SourceDataset):
|
|||
if self._dataset_size is None:
|
||||
num_rows = TextFileOp.get_num_rows(self.dataset_files)
|
||||
num_rows = get_num_rows(num_rows, self.num_shards)
|
||||
if self.num_samples is None:
|
||||
return num_rows
|
||||
return min(self.num_samples, num_rows)
|
||||
# If the user gave a num samples in the dataset, then the sampler will limit the rows returned
|
||||
# to that amount. Account for that here in the row count
|
||||
if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples:
|
||||
num_rows = self.num_samples
|
||||
return num_rows
|
||||
return self._dataset_size
|
||||
|
||||
def is_shuffled(self):
|
||||
|
|
|
@ -22,7 +22,6 @@ User can also define custom sampler by extending from Sampler class.
|
|||
import numpy as np
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
|
||||
class Sampler:
|
||||
"""
|
||||
Base class for user defined sampler.
|
||||
|
@ -44,10 +43,10 @@ class Sampler:
|
|||
>>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler())
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, num_samples=None):
|
||||
self.dataset_size = 0
|
||||
self.num_samples = 0
|
||||
self.child_sampler = None
|
||||
self.num_samples = num_samples
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
|
@ -84,7 +83,8 @@ class Sampler:
|
|||
# Instance fetcher
|
||||
# Do not override this method!
|
||||
def create(self):
|
||||
c_sampler = cde.PythonSampler(self)
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
c_sampler = cde.PythonSampler(num_samples, self)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
@ -114,7 +114,7 @@ class Sampler:
|
|||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def get_dataset_size(self):
|
||||
def get_num_samples(self):
|
||||
return self._get_indices().size
|
||||
|
||||
|
||||
|
@ -124,8 +124,9 @@ class BuiltinSampler:
|
|||
|
||||
User should not extend this class.
|
||||
"""
|
||||
def __init__(self):
|
||||
def __init__(self, num_samples=None):
|
||||
self.child_sampler = None
|
||||
self.num_samples = num_samples
|
||||
|
||||
def create(self):
|
||||
pass
|
||||
|
@ -149,11 +150,37 @@ class BuiltinSampler:
|
|||
def is_sharded(self):
|
||||
raise NotImplementedError("Sampler must implement is_sharded.")
|
||||
|
||||
def get_dataset_size(self):
|
||||
if self.child_sampler is not None:
|
||||
return self.child_sampler.get_dataset_size()
|
||||
def get_num_samples(self):
|
||||
"""
|
||||
All samplers can contain a numeric num_samples value (or it could be set to None).
|
||||
Child sampler can exist or be None.
|
||||
if child sampler exists, then the child sampler count can be a numeric value or None.
|
||||
Given these conditions, we need to output what the sampler count is for this sampler.
|
||||
The following table shows the possible results from calling this function.
|
||||
|
||||
return None
|
||||
child sampler num_samples child_samples result
|
||||
------------- ----------- ------------- --------
|
||||
T x y min(x, y)
|
||||
T x None x
|
||||
T None y y
|
||||
T None None None
|
||||
None x n/a x
|
||||
None None n/a None
|
||||
|
||||
Returns:
|
||||
int, The number of samples, or None
|
||||
"""
|
||||
if self.child_sampler is not None:
|
||||
child_samples = self.child_sampler.get_num_samples()
|
||||
if self.num_samples is not None:
|
||||
if child_samples is not None:
|
||||
return min(self.num_samples, child_samples)
|
||||
|
||||
return self.num_samples
|
||||
|
||||
return child_samples
|
||||
|
||||
return self.num_samples
|
||||
|
||||
|
||||
class DistributedSampler(BuiltinSampler):
|
||||
|
@ -164,6 +191,7 @@ class DistributedSampler(BuiltinSampler):
|
|||
num_shards (int): Number of shards to divide the dataset into.
|
||||
shard_id (int): Shard ID of the current shard within num_shards.
|
||||
shuffle (bool, optional): If true, the indices are shuffled (default=True).
|
||||
num_samples (int, optional): The number of samples to draw (default=None, all elements).
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -180,7 +208,7 @@ class DistributedSampler(BuiltinSampler):
|
|||
ValueError: If shuffle is not a boolean value.
|
||||
"""
|
||||
|
||||
def __init__(self, num_shards, shard_id, shuffle=True):
|
||||
def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None):
|
||||
if num_shards <= 0:
|
||||
raise ValueError("num_shards should be a positive integer value, but got num_shards={}".format(num_shards))
|
||||
|
||||
|
@ -194,12 +222,13 @@ class DistributedSampler(BuiltinSampler):
|
|||
self.shard_id = shard_id
|
||||
self.shuffle = shuffle
|
||||
self.seed = 0
|
||||
super().__init__()
|
||||
super().__init__(num_samples)
|
||||
|
||||
def create(self):
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
|
||||
self.seed += 1
|
||||
c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
|
||||
c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id, self.shuffle, self.seed)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
@ -226,6 +255,7 @@ class PKSampler(BuiltinSampler):
|
|||
num_class (int, optional): Number of classes to sample (default=None, all classes).
|
||||
shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
|
||||
class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset.
|
||||
num_samples (int, optional): The number of samples to draw (default=None, all elements).
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -242,7 +272,7 @@ class PKSampler(BuiltinSampler):
|
|||
ValueError: If shuffle is not boolean.
|
||||
"""
|
||||
|
||||
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'):
|
||||
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
|
||||
if num_val <= 0:
|
||||
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
|
||||
|
||||
|
@ -255,10 +285,11 @@ class PKSampler(BuiltinSampler):
|
|||
self.num_val = num_val
|
||||
self.shuffle = shuffle
|
||||
self.class_column = class_column # work for minddataset
|
||||
super().__init__()
|
||||
super().__init__(num_samples)
|
||||
|
||||
def create(self):
|
||||
c_sampler = cde.PKSampler(self.num_val, self.shuffle)
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
@ -309,23 +340,18 @@ class RandomSampler(BuiltinSampler):
|
|||
raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
|
||||
|
||||
if num_samples is not None:
|
||||
if num_samples <= 0:
|
||||
if num_samples < 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
"value, but got num_samples={}".format(num_samples))
|
||||
|
||||
self.deterministic = False
|
||||
self.replacement = replacement
|
||||
self.num_samples = num_samples
|
||||
self.reshuffle_each_epoch = True
|
||||
super().__init__()
|
||||
super().__init__(num_samples)
|
||||
|
||||
def create(self):
|
||||
c_sampler = None
|
||||
if self.num_samples is None:
|
||||
c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch)
|
||||
else:
|
||||
c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples)
|
||||
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
@ -339,14 +365,15 @@ class RandomSampler(BuiltinSampler):
|
|||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def get_dataset_size(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
class SequentialSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples the dataset elements sequentially, same as not having a sampler.
|
||||
|
||||
Args:
|
||||
start_index (int, optional): Index to start sampling at. (dafault=None starts at first id)
|
||||
num_samples (int, optional): Number of elements to sample (default=None, all elements).
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
|
@ -357,66 +384,14 @@ class SequentialSampler(BuiltinSampler):
|
|||
>>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
|
||||
"""
|
||||
|
||||
def create(self):
|
||||
c_sampler = cde.SequentialSampler()
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_shuffled()
|
||||
|
||||
def is_sharded(self):
|
||||
if self.child_sampler is None:
|
||||
return False
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
|
||||
class SubsetSampler(BuiltinSampler):
|
||||
"""
|
||||
Samples a subset of elements consecutively from a given index.
|
||||
|
||||
Args:
|
||||
start_index (int): Index to start sampling at.
|
||||
subset_size (int): How many samples to include in this subset.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
>>> dataset_dir = "path/to/imagefolder_directory"
|
||||
>>>
|
||||
>>> # creates a SubsetSampler, will sample the next 5 images from the 100th image.
|
||||
>>> sampler = ds.SubsetSampler(100, 5)
|
||||
>>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
|
||||
|
||||
Raises:
|
||||
ValueError: If start_index is not a positive int.
|
||||
ValueError: If subset_size is not a positive int.
|
||||
"""
|
||||
|
||||
def __init__(self, start_index, subset_size):
|
||||
if not isinstance(start_index, int):
|
||||
raise ValueError("start_index should be an int.")
|
||||
|
||||
if start_index < 0:
|
||||
raise ValueError("start_index should not be negative.")
|
||||
|
||||
if not isinstance(subset_size, int):
|
||||
raise ValueError("start_index should be an int")
|
||||
|
||||
if subset_size < 0:
|
||||
raise ValueError("subset_size should not be negative.")
|
||||
|
||||
def __init__(self, start_index=None, num_samples=None):
|
||||
self.start_index = start_index
|
||||
self.subset_size = subset_size
|
||||
super().__init__()
|
||||
super().__init__(num_samples)
|
||||
|
||||
def create(self):
|
||||
c_sampler = cde.SubsetSampler(self.start_index, self.subset_size)
|
||||
start_index = self.start_index if self.start_index is not None else 0
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
c_sampler = cde.SequentialSampler(num_samples, start_index)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
@ -433,9 +408,6 @@ class SubsetSampler(BuiltinSampler):
|
|||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def get_dataset_size(self):
|
||||
return self.subset_size
|
||||
|
||||
|
||||
class SubsetRandomSampler(BuiltinSampler):
|
||||
"""
|
||||
|
@ -443,6 +415,7 @@ class SubsetRandomSampler(BuiltinSampler):
|
|||
|
||||
Args:
|
||||
indices (list[int]): A sequence of indices.
|
||||
num_samples (int, optional): Number of elements to sample (default=None, all elements).
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -456,15 +429,16 @@ class SubsetRandomSampler(BuiltinSampler):
|
|||
>>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
|
||||
"""
|
||||
|
||||
def __init__(self, indices):
|
||||
def __init__(self, indices, num_samples=None):
|
||||
if not isinstance(indices, list):
|
||||
indices = [indices]
|
||||
|
||||
self.indices = indices
|
||||
super().__init__()
|
||||
super().__init__(num_samples)
|
||||
|
||||
def create(self):
|
||||
c_sampler = cde.SubsetRandomSampler(self.indices)
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
c_sampler = cde.SubsetRandomSampler(num_samples, self.indices)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
@ -481,9 +455,9 @@ class SubsetRandomSampler(BuiltinSampler):
|
|||
def _create_for_minddataset(self):
|
||||
return cde.MindrecordSubsetRandomSampler(self.indices)
|
||||
|
||||
|
||||
def get_dataset_size(self):
|
||||
return len(self.indices)
|
||||
def get_num_samples(self):
|
||||
num_samples = super().get_num_samples()
|
||||
return min(len(self.indices), num_samples)
|
||||
|
||||
|
||||
class WeightedRandomSampler(BuiltinSampler):
|
||||
|
@ -492,7 +466,7 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
|
||||
Args:
|
||||
weights (list[float]): A sequence of weights, not necessarily summing up to 1.
|
||||
num_samples (int): Number of elements to sample.
|
||||
num_samples (int): Number of elements to sample (default=None, all elements).
|
||||
replacement (bool, optional): If True, put the sample ID back for the next draw (default=True).
|
||||
|
||||
Examples:
|
||||
|
@ -511,24 +485,25 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
ValueError: If replacement is not boolean.
|
||||
"""
|
||||
|
||||
def __init__(self, weights, num_samples, replacement=True):
|
||||
def __init__(self, weights, num_samples=None, replacement=True):
|
||||
if not isinstance(weights, list):
|
||||
weights = [weights]
|
||||
|
||||
if num_samples <= 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
"value, but got num_samples={}".format(num_samples))
|
||||
if num_samples is not None:
|
||||
if num_samples < 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
"value, but got num_samples={}".format(num_samples))
|
||||
|
||||
if not isinstance(replacement, bool):
|
||||
raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
|
||||
|
||||
self.weights = weights
|
||||
self.num_samples = num_samples
|
||||
self.replacement = replacement
|
||||
super().__init__()
|
||||
super().__init__(num_samples)
|
||||
|
||||
def create(self):
|
||||
c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement)
|
||||
c_child_sampler = self.create_child()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
@ -541,6 +516,3 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
return False
|
||||
|
||||
return self.child_sampler.is_sharded()
|
||||
|
||||
def get_dataset_size(self):
|
||||
return self.num_samples
|
||||
|
|
|
@ -161,6 +161,20 @@ def traverse(node):
|
|||
else:
|
||||
node_repr[k] = v
|
||||
|
||||
# If a sampler exists in this node, then the following 4 arguments must be set to None:
|
||||
# num_samples, shard_id, num_shards, shuffle
|
||||
# These arguments get moved into the sampler itself, so they are no longer needed to
|
||||
# be set at the dataset level.
|
||||
if 'sampler' in node_args.keys():
|
||||
if 'num_samples' in node_repr.keys():
|
||||
node_repr['num_samples'] = None
|
||||
if 'shuffle' in node_repr.keys():
|
||||
node_repr['shuffle'] = None
|
||||
if 'num_shards' in node_repr.keys():
|
||||
node_repr['num_shards'] = None
|
||||
if 'shard_id' in node_repr.keys():
|
||||
node_repr['shard_id'] = None
|
||||
|
||||
# Leaf node doesn't have input attribute.
|
||||
if not node.input:
|
||||
return node_repr
|
||||
|
|
|
@ -283,8 +283,8 @@ def check_num_parallel_workers(value):
|
|||
|
||||
def check_num_samples(value):
|
||||
check_type(value, 'num_samples', int)
|
||||
if value <= 0:
|
||||
raise ValueError("num_samples must be greater than 0!")
|
||||
if value < 0:
|
||||
raise ValueError("num_samples cannot be less than 0!")
|
||||
|
||||
|
||||
def check_dataset_dir(dataset_dir):
|
||||
|
|
|
@ -39,14 +39,13 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
|
|||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
std::shared_ptr<CelebAOp> Celeba(int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size,
|
||||
const std::string &dir, int64_t num_samples = 0,
|
||||
std::unique_ptr<Sampler> sampler = nullptr, bool decode = false,
|
||||
const std::string &dataset_type="all") {
|
||||
const std::string &dir, std::shared_ptr<Sampler> sampler = nullptr,
|
||||
bool decode = false, const std::string &dataset_type="all") {
|
||||
std::shared_ptr<CelebAOp> so;
|
||||
CelebAOp::Builder builder;
|
||||
Status rc = builder.SetNumWorkers(num_workers).SetCelebADir(dir).SetRowsPerBuffer(rows_per_buffer)
|
||||
.SetOpConnectorSize(queue_size).SetSampler(std::move(sampler)).SetDecode(decode)
|
||||
.SetNumSamples(num_samples).SetDatasetType(dataset_type).Build(&so);
|
||||
.SetDatasetType(dataset_type).Build(&so);
|
||||
return so;
|
||||
}
|
||||
|
||||
|
@ -116,11 +115,12 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
|
|||
|
||||
TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) {
|
||||
std::vector<int64_t> indices({1});
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices);
|
||||
int64_t num_samples = 0;
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices);
|
||||
uint32_t expect_labels[1][40] = {{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}};
|
||||
std::string dir = datasets_root_path_ + "/testCelebAData/";
|
||||
uint32_t count = 0;
|
||||
auto tree = Build({Celeba(16, 2, 32, dir, 0, std::move(sampler))});
|
||||
auto tree = Build({Celeba(16, 2, 32, dir, std::move(sampler))});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -143,25 +143,3 @@ TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) {
|
|||
EXPECT_TRUE(count == 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestCelebaDataset, TestCelebaNumSamples) {
|
||||
std::string dir = datasets_root_path_ + "/testCelebAData/";
|
||||
uint32_t count = 0;
|
||||
auto tree = Build({Celeba(16, 2, 32, dir, 1)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << ".";
|
||||
EXPECT_TRUE(false);
|
||||
} else {
|
||||
DatasetIterator di(tree);
|
||||
TensorMap tersor_map;
|
||||
di.GetNextAsMap(&tersor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
while (tersor_map.size() != 0) {
|
||||
count++;
|
||||
di.GetNextAsMap(&tersor_map);
|
||||
}
|
||||
EXPECT_TRUE(count == 1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,13 +45,12 @@ std::shared_ptr<RepeatOp> Repeat(int repeatCnt);
|
|||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
std::shared_ptr<CifarOp> Cifarop(uint64_t num_works, uint64_t rows, uint64_t conns, std::string path,
|
||||
std::unique_ptr<Sampler> sampler = nullptr,
|
||||
uint64_t num_samples = 0, bool cifar10 = true) {
|
||||
std::shared_ptr<Sampler> sampler = nullptr, bool cifar10 = true) {
|
||||
std::shared_ptr<CifarOp> so;
|
||||
CifarOp::Builder builder;
|
||||
Status rc = builder.SetNumWorkers(num_works).SetCifarDir(path).SetRowsPerBuffer(rows)
|
||||
.SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetCifarType(cifar10)
|
||||
.SetNumSamples(num_samples).Build(&so);
|
||||
.Build(&so);
|
||||
return so;
|
||||
}
|
||||
|
||||
|
@ -66,7 +65,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
|
|||
//appear in this dataset
|
||||
//Example: python tests/dataset/data/prep_data.py
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100)});
|
||||
auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -79,7 +78,8 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
|
|||
EXPECT_TRUE(rc.IsOk());
|
||||
uint64_t i = 0;
|
||||
uint32_t label = 0;
|
||||
while (tensor_map.size() != 0) {
|
||||
// Note: only iterating first 100 rows then break out.
|
||||
while (tensor_map.size() != 0 && i < 100) {
|
||||
tensor_map["label"]->GetItemAt<uint32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
|
||||
i++;
|
||||
|
@ -92,9 +92,9 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) {
|
|||
TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
GlobalContext::config_manager()->set_seed(0);
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
|
||||
std::shared_ptr<Sampler> sampler = std::make_unique<RandomSampler>(12, true, true);
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)});
|
||||
auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler))});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -118,34 +118,9 @@ TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) {
|
|||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestCifarOp, TestCifar10NumSample) {
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
|
||||
auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Return code error detected during tree launch: " << common::SafeCStr(rc.ToString()) << ".";
|
||||
EXPECT_TRUE(false);
|
||||
} else {
|
||||
DatasetIterator di(tree);
|
||||
TensorMap tensor_map;
|
||||
di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
uint64_t i = 0;
|
||||
uint32_t label = 0;
|
||||
while (tensor_map.size() != 0) {
|
||||
tensor_map["label"]->GetItemAt<uint32_t>(&label, {});
|
||||
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
|
||||
i++;
|
||||
di.GetNextAsMap(&tensor_map);
|
||||
}
|
||||
EXPECT_TRUE(i == 100);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar100) {
|
||||
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
|
||||
auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100, false)});
|
||||
auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, false)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -159,7 +134,8 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar100) {
|
|||
uint64_t i = 0;
|
||||
uint32_t coarse = 0;
|
||||
uint32_t fine = 0;
|
||||
while (tensor_map.size() != 0) {
|
||||
// only iterate to 100 then break out of loop
|
||||
while (tensor_map.size() != 0 && i < 100) {
|
||||
tensor_map["coarse_label"]->GetItemAt<uint32_t>(&coarse, {});
|
||||
tensor_map["fine_label"]->GetItemAt<uint32_t>(&fine, {});
|
||||
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << " coarse:"
|
||||
|
|
|
@ -50,9 +50,8 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
|
|||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
|
||||
bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
|
||||
std::map<std::string, int32_t> map = {}, int64_t num_samples = 0,
|
||||
bool decode = false) {
|
||||
bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
|
||||
std::map<std::string, int32_t> map = {}, bool decode = false) {
|
||||
std::shared_ptr<ImageFolderOp> so;
|
||||
ImageFolderOp::Builder builder;
|
||||
Status rc = builder.SetNumWorkers(num_works)
|
||||
|
@ -63,7 +62,6 @@ std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int6
|
|||
.SetSampler(std::move(sampler))
|
||||
.SetClassIndex(map)
|
||||
.SetDecode(decode)
|
||||
.SetNumSamples(num_samples)
|
||||
.Build(&so);
|
||||
return so;
|
||||
}
|
||||
|
@ -138,7 +136,8 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) {
|
|||
TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
|
||||
int32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
GlobalContext::config_manager()->set_seed(0);
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12);
|
||||
int64_t num_samples = 12;
|
||||
std::shared_ptr<Sampler> sampler = std::make_unique<RandomSampler>(num_samples, true, true);
|
||||
int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1}; // ground truth label
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
|
||||
|
@ -200,7 +199,8 @@ TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeatBatch)
|
|||
TEST_F(MindDataTestImageFolderSampler, TestSubsetRandomSamplerImageFolder) {
|
||||
// id range 0 - 10 is label 0, and id range 11 - 21 is label 1
|
||||
std::vector<int64_t> indices({0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11});
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices);
|
||||
int64_t num_samples = 0;
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices);
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
// Expect 6 samples for label 0 and 1
|
||||
int res[2] = {6, 6};
|
||||
|
@ -237,8 +237,8 @@ TEST_F(MindDataTestImageFolderSampler, TestWeightedRandomSamplerImageFolder) {
|
|||
std::vector<double> weights(total_samples, std::rand() % 100);
|
||||
|
||||
// create sampler with replacement = replacement
|
||||
std::unique_ptr<Sampler> sampler =
|
||||
std::make_unique<WeightedRandomSampler>(weights, num_samples, true, samples_per_buffer);
|
||||
std::shared_ptr<Sampler> sampler =
|
||||
std::make_shared<WeightedRandomSampler>(num_samples, weights, true, samples_per_buffer);
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
|
||||
|
@ -295,7 +295,8 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderClassIndex) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) {
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(11, 10, false);
|
||||
int64_t num_samples = 0;
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 11, 10, false);
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)), Repeat(4)});
|
||||
tree->Prepare();
|
||||
|
@ -322,7 +323,8 @@ TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) {
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<PKSampler>(3, false, 4);
|
||||
int64_t num_samples = 0;
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<PKSampler>(num_samples, 3, false, 4);
|
||||
int32_t res[] = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3}; // ground truth label
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))});
|
||||
|
@ -349,39 +351,16 @@ TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestImageFolderSampler, TestImageFolderNumSamples) {
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, nullptr, {}, 11), Repeat(2)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Return code error detected during tree launch: " << common::SafeCStr(rc.ToString()) << ".";
|
||||
EXPECT_TRUE(false);
|
||||
} else {
|
||||
DatasetIterator di(tree);
|
||||
TensorMap tensor_map;
|
||||
di.GetNextAsMap(&tensor_map);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
uint64_t i = 0;
|
||||
int32_t label = 0;
|
||||
while (tensor_map.size() != 0) {
|
||||
tensor_map["label"]->GetItemAt<int32_t>(&label, {});
|
||||
EXPECT_TRUE(0 == label);
|
||||
MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n";
|
||||
i++;
|
||||
di.GetNextAsMap(&tensor_map);
|
||||
}
|
||||
EXPECT_TRUE(i == 22);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) {
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
std::map<std::string, int32_t> map;
|
||||
map["class3"] = 333;
|
||||
map["class1"] = 111;
|
||||
map["wrong folder name"] = 1234; // this is skipped
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, nullptr, map, 20, true)});
|
||||
int64_t num_samples = 20;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(seq_sampler), map, true)});
|
||||
int64_t res[2] = {111, 333};
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
|
@ -408,33 +387,12 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestImageFolderSampler, TestImageFolderDatasetSize) {
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
int64_t num_rows = 0;
|
||||
int64_t num_classes = 0;
|
||||
ImageFolderOp::CountRowsAndClasses(folder_path, 15, {}, &num_rows, &num_classes);
|
||||
EXPECT_TRUE(num_rows == 15 && num_classes == 4);
|
||||
ImageFolderOp::CountRowsAndClasses(folder_path, 44, {}, &num_rows, &num_classes);
|
||||
EXPECT_TRUE(num_rows == 44 && num_classes == 4);
|
||||
ImageFolderOp::CountRowsAndClasses(folder_path, 0, {}, &num_rows, &num_classes);
|
||||
EXPECT_TRUE(num_rows == 44 && num_classes == 4);
|
||||
ImageFolderOp::CountRowsAndClasses(folder_path, 55, {}, &num_rows, &num_classes);
|
||||
EXPECT_TRUE(num_rows == 44 && num_classes == 4);
|
||||
ImageFolderOp::CountRowsAndClasses(folder_path, 44, {}, &num_rows, &num_classes, 2, 3);
|
||||
EXPECT_TRUE(num_rows == 15 && num_classes == 4);
|
||||
ImageFolderOp::CountRowsAndClasses(folder_path, 33, {}, &num_rows, &num_classes, 0, 3);
|
||||
EXPECT_TRUE(num_rows == 15 && num_classes == 4);
|
||||
ImageFolderOp::CountRowsAndClasses(folder_path, 13, {}, &num_rows, &num_classes, 0, 11);
|
||||
EXPECT_TRUE(num_rows == 4 && num_classes == 4);
|
||||
ImageFolderOp::CountRowsAndClasses(folder_path, 3, {}, &num_rows, &num_classes, 0, 11);
|
||||
EXPECT_TRUE(num_rows == 3 && num_classes == 4);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) {
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(4, 0, false);
|
||||
int64_t num_samples = 5;
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 4, 0, false);
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
// numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {}, 5)});
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {})});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
int32_t labels[5] = {0, 0, 0, 1, 1};
|
||||
|
@ -460,10 +418,11 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding2) {
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(4, 3, false);
|
||||
int64_t num_samples = 12;
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 4, 3, false);
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
// numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode
|
||||
auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {}, 12)});
|
||||
auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {})});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
uint32_t labels[11] = {0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3};
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "dataset/core/client.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
#include "dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "dataset/util/de_error.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
@ -42,14 +43,13 @@ std::shared_ptr<RepeatOp> Repeat(int repeatCnt);
|
|||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
std::shared_ptr<ManifestOp> Manifest(int32_t num_works, int32_t rows, int32_t conns, const std::string &file,
|
||||
std::string usage = "train", std::unique_ptr<Sampler> sampler = nullptr,
|
||||
std::map<std::string, int32_t> map = {}, uint64_t num_samples = 0,
|
||||
bool decode = false) {
|
||||
std::string usage = "train", std::shared_ptr<Sampler> sampler = nullptr,
|
||||
std::map<std::string, int32_t> map = {}, bool decode = false) {
|
||||
std::shared_ptr<ManifestOp> so;
|
||||
ManifestOp::Builder builder;
|
||||
Status rc = builder.SetNumWorkers(num_works).SetManifestFile(file).SetRowsPerBuffer(
|
||||
rows).SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetClassIndex(map).SetDecode(decode)
|
||||
.SetNumSamples(num_samples).SetUsage(usage).Build(&so);
|
||||
.SetUsage(usage).Build(&so);
|
||||
return so;
|
||||
}
|
||||
|
||||
|
@ -86,7 +86,8 @@ TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) {
|
|||
|
||||
TEST_F(MindDataTestManifest, TestSubsetRandomSamplerManifest) {
|
||||
std::vector<int64_t> indices({1});
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices);
|
||||
int64_t num_samples = 0;
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices);
|
||||
std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
|
||||
// Expect 6 samples for label 0 and 1
|
||||
auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(sampler))});
|
||||
|
@ -145,7 +146,10 @@ TEST_F(MindDataTestManifest, MindDataTestManifestClassIndex) {
|
|||
|
||||
TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) {
|
||||
std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
|
||||
auto tree = Build({Manifest(16, 2, 32, file, "train", nullptr, {}, 1), Repeat(4)});
|
||||
int64_t num_samples = 1;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}), Repeat(4)});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -171,7 +175,10 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) {
|
|||
|
||||
TEST_F(MindDataTestManifest, MindDataTestManifestEval) {
|
||||
std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
|
||||
auto tree = Build({Manifest(16, 2, 32, file, "eval", nullptr, {}, 1)});
|
||||
int64_t num_samples = 1;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
auto tree = Build({Manifest(16, 2, 32, file, "eval", std::move(seq_sampler), {})});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
|
|
@ -120,9 +120,8 @@ class MindDataTestMapOp : public UT::DatasetOpTesting {
|
|||
};
|
||||
|
||||
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
|
||||
bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
|
||||
std::map<std::string, int32_t> map = {}, int64_t num_samples = 0,
|
||||
bool decode = false);
|
||||
bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
|
||||
std::map<std::string, int32_t> map = {}, bool decode = false);
|
||||
|
||||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
|
|
|
@ -53,13 +53,11 @@ Status Create1DTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements,
|
|||
DataType::Type data_type = DataType::DE_UINT32);
|
||||
|
||||
std::shared_ptr<MnistOp> CreateMnist(int64_t num_wrks, int64_t rows, int64_t conns, std::string path,
|
||||
bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr,
|
||||
int64_t num_samples = 0) {
|
||||
bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr) {
|
||||
std::shared_ptr<MnistOp> so;
|
||||
MnistOp::Builder builder;
|
||||
Status rc = builder.SetNumWorkers(num_wrks).SetDir(path).SetRowsPerBuffer(rows)
|
||||
.SetOpConnectorSize(conns).SetSampler(std::move(sampler))
|
||||
.SetNumSamples(num_samples).Build(&so);
|
||||
.SetOpConnectorSize(conns).SetSampler(std::move(sampler)).Build(&so);
|
||||
return so;
|
||||
}
|
||||
|
||||
|
@ -74,7 +72,10 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
|
|||
// appear in this dataset
|
||||
// Example: python tests/dataset/data/prep_data.py
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, nullptr, 10), Repeat(2)});
|
||||
int64_t num_samples = 10;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2)});
|
||||
tree->Prepare();
|
||||
uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
|
||||
Status rc = tree->Launch();
|
||||
|
@ -101,7 +102,10 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
|
|||
|
||||
TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) {
|
||||
std::string folder_path = datasets_root_path_ + "/testMnistData/";
|
||||
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, nullptr, 10), Repeat(2), Batch(5)});
|
||||
int64_t num_samples = 10;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2), Batch(5)});
|
||||
tree->Prepare();
|
||||
uint32_t res[4][5] = { {0, 0, 0, 0, 0 },
|
||||
{0, 0, 0, 0, 0 },
|
||||
|
|
|
@ -43,20 +43,11 @@ class MindDataTestStandAloneSampler : public UT::DatasetOpTesting {
|
|||
protected:
|
||||
class MockStorageOp : public RandomAccessOp {
|
||||
public:
|
||||
MockStorageOp(int64_t val) : m_val_(val) {}
|
||||
|
||||
Status GetNumSamples(int64_t *ptr) const override {
|
||||
(*ptr) = m_val_;
|
||||
return Status::OK();
|
||||
MockStorageOp(int64_t val){
|
||||
// row count is in base class as protected member
|
||||
// GetNumRowsInDataset does not need an override, the default from base class is fine.
|
||||
num_rows_ = val;
|
||||
}
|
||||
|
||||
Status GetNumRowsInDataset(int64_t *ptr) const override {
|
||||
(*ptr) = m_val_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t m_val_;
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -73,8 +64,9 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
|
|||
MockStorageOp mock(20);
|
||||
std::unique_ptr<DataBuffer> db;
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
int64_t num_samples = 0;
|
||||
for (int i = 0; i < 6; i++) {
|
||||
std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true));
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 3, i % 3, (i < 3 ? false : true));
|
||||
sampler->HandshakeRandomAccessOp(&mock);
|
||||
sampler->GetNextBuffer(&db);
|
||||
db->GetTensor(&tensor, 0, 0);
|
||||
|
@ -92,7 +84,9 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
|
|||
std::shared_ptr<Tensor> label1, label2;
|
||||
CreateINT64Tensor(&label1, 3, reinterpret_cast<unsigned char *>(res));
|
||||
CreateINT64Tensor(&label2, 2, reinterpret_cast<unsigned char *>(res + 3));
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3);
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(num_samples, start_index, 3);
|
||||
std::unique_ptr<DataBuffer> db;
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
sampler->HandshakeRandomAccessOp(&mock);
|
||||
|
|
|
@ -31,26 +31,17 @@ class MindDataTestSubsetRandomSampler : public UT::Common {
|
|||
public:
|
||||
class DummyRandomAccessOp : public RandomAccessOp {
|
||||
public:
|
||||
DummyRandomAccessOp(int64_t num_rows) : num_rows_(num_rows) {};
|
||||
Status GetNumSamples(int64_t *num) const {
|
||||
*num = num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetNumRowsInDataset(int64_t *num) const {
|
||||
*num = num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t num_rows_;
|
||||
DummyRandomAccessOp(int64_t num_rows) {
|
||||
num_rows_ = num_rows; // base class
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
|
||||
std::vector<int64_t> in({0, 1, 2, 3, 4});
|
||||
std::unordered_set<int64_t> in_set(in.begin(), in.end());
|
||||
SubsetRandomSampler sampler(in);
|
||||
int64_t num_samples = 0;
|
||||
SubsetRandomSampler sampler(num_samples, in);
|
||||
|
||||
DummyRandomAccessOp dummyRandomAccessOp(5);
|
||||
sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
@ -77,8 +68,9 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
|
|||
TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
|
||||
int64_t total_samples = 100000 - 5;
|
||||
int64_t samples_per_buffer = 10;
|
||||
int64_t num_samples = 0;
|
||||
std::vector<int64_t> input(total_samples, 1);
|
||||
SubsetRandomSampler sampler(input, samples_per_buffer);
|
||||
SubsetRandomSampler sampler(num_samples, input, samples_per_buffer);
|
||||
|
||||
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||
sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
@ -109,7 +101,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) {
|
|||
TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
|
||||
std::vector<int64_t> in({0, 1, 2, 3, 4});
|
||||
std::unordered_set<int64_t> in_set(in.begin(), in.end());
|
||||
SubsetRandomSampler sampler(in);
|
||||
int64_t num_samples = 0;
|
||||
SubsetRandomSampler sampler(num_samples, in);
|
||||
|
||||
DummyRandomAccessOp dummyRandomAccessOp(5);
|
||||
sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
|
|
@ -35,19 +35,11 @@ class MindDataTestWeightedRandomSampler : public UT::Common {
|
|||
public:
|
||||
class DummyRandomAccessOp : public RandomAccessOp {
|
||||
public:
|
||||
DummyRandomAccessOp(uint64_t num_rows) : num_rows_(num_rows) {};
|
||||
Status GetNumSamples(int64_t *num) const {
|
||||
*num = num_rows_;
|
||||
return Status::OK();
|
||||
DummyRandomAccessOp(uint64_t num_rows) {
|
||||
// row count is in base class as protected member
|
||||
// GetNumRowsInDataset does not need an override, the default from base class is fine.
|
||||
num_rows_ = num_rows;
|
||||
}
|
||||
|
||||
Status GetNumRowsInDataset(int64_t *num) const {
|
||||
*num = num_rows_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
uint64_t num_rows_;
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -59,7 +51,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
|
|||
std::vector<uint64_t> freq(total_samples, 0);
|
||||
|
||||
// create sampler with replacement = true
|
||||
WeightedRandomSampler m_sampler(weights, num_samples, true);
|
||||
WeightedRandomSampler m_sampler(num_samples, weights, true);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
|
@ -89,7 +81,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
|
|||
std::vector<uint64_t> freq(total_samples, 0);
|
||||
|
||||
// create sampler with replacement = replacement
|
||||
WeightedRandomSampler m_sampler(weights, num_samples, false);
|
||||
WeightedRandomSampler m_sampler(num_samples, weights, false);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
|
@ -125,7 +117,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) {
|
|||
std::vector<double> weights(total_samples, std::rand() % 100);
|
||||
|
||||
// create sampler with replacement = replacement
|
||||
WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer);
|
||||
WeightedRandomSampler m_sampler(num_samples, weights, true, samples_per_buffer);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
|
@ -161,7 +153,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) {
|
|||
std::vector<uint64_t> freq(total_samples, 0);
|
||||
|
||||
// create sampler with replacement = replacement
|
||||
WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer);
|
||||
WeightedRandomSampler m_sampler(num_samples, weights, false, samples_per_buffer);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
|
@ -202,7 +194,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
|
|||
std::vector<uint64_t> freq(total_samples, 0);
|
||||
|
||||
// create sampler with replacement = true
|
||||
WeightedRandomSampler m_sampler(weights, num_samples, true);
|
||||
WeightedRandomSampler m_sampler(num_samples, weights, true);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
|
@ -247,7 +239,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
|
|||
std::vector<uint64_t> freq(total_samples, 0);
|
||||
|
||||
// create sampler with replacement = true
|
||||
WeightedRandomSampler m_sampler(weights, num_samples, false);
|
||||
WeightedRandomSampler m_sampler(num_samples, weights, false);
|
||||
DummyRandomAccessOp dummyRandomAccessOp(total_samples);
|
||||
m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ def test_imagefolder_numsamples():
|
|||
assert num_iter == 10
|
||||
|
||||
random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
|
||||
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
|
||||
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator():
|
||||
|
@ -67,7 +67,7 @@ def test_imagefolder_numsamples():
|
|||
assert num_iter == 3
|
||||
|
||||
random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
|
||||
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler)
|
||||
data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
|
||||
|
||||
num_iter = 0
|
||||
for item in data1.create_dict_iterator():
|
||||
|
|
|
@ -162,8 +162,8 @@ def test_voc_shardings(print_res=False):
|
|||
voc_dir = "../data/dataset/testVOC2012"
|
||||
|
||||
def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
|
||||
sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
|
||||
data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_samples=num_samples)
|
||||
sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
|
||||
data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler)
|
||||
data1 = data1.repeat(repeat_cnt)
|
||||
res = []
|
||||
for item in data1.create_dict_iterator(): # each data is a dictionary
|
||||
|
|
|
@ -35,18 +35,13 @@ def test_exception_01():
|
|||
|
||||
def test_exception_02():
|
||||
"""
|
||||
Test multiple exceptions with invalid input
|
||||
Test exceptions with invalid input, and test valid input
|
||||
"""
|
||||
logger.info("test_exception_02")
|
||||
num_samples = 0
|
||||
with pytest.raises(ValueError) as info:
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||
assert "num_samples must be greater than 0" in str(info.value)
|
||||
|
||||
num_samples = -1
|
||||
with pytest.raises(ValueError) as info:
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||
assert "num_samples must be greater than 0" in str(info.value)
|
||||
assert "num_samples cannot be less than 0" in str(info.value)
|
||||
|
||||
num_samples = 1
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples)
|
||||
|
|
|
@ -544,7 +544,7 @@ def test_distributed_sampler():
|
|||
def test_num_samples():
|
||||
source = [(np.array([x]),) for x in range(64)]
|
||||
num_samples = 32
|
||||
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples=num_samples)
|
||||
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples))
|
||||
ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples)
|
||||
ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples)
|
||||
|
||||
|
@ -660,4 +660,6 @@ if __name__ == "__main__":
|
|||
test_sequential_sampler()
|
||||
test_distributed_sampler()
|
||||
test_random_sampler()
|
||||
test_num_samples()
|
||||
test_num_samples_underflow()
|
||||
test_schema()
|
||||
|
|
|
@ -28,8 +28,8 @@ def test_sequential_sampler(print_res=False):
|
|||
map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
|
||||
|
||||
def test_config(num_samples, num_repeats=None):
|
||||
sampler = ds.SequentialSampler()
|
||||
data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
|
||||
sampler = ds.SequentialSampler(num_samples=num_samples)
|
||||
data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
|
||||
if num_repeats is not None:
|
||||
data1 = data1.repeat(num_repeats)
|
||||
res = []
|
||||
|
@ -43,6 +43,7 @@ def test_sequential_sampler(print_res=False):
|
|||
|
||||
assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2]
|
||||
assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2
|
||||
assert test_config(num_samples=0, num_repeats=2) == [0, 1, 2, 3, 4] * 2
|
||||
assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2
|
||||
|
||||
|
||||
|
@ -119,8 +120,8 @@ def test_python_sampler():
|
|||
return iter([i for i in range(self.dataset_size)])
|
||||
|
||||
class Sp2(ds.Sampler):
|
||||
def __init__(self):
|
||||
super(Sp2, self).__init__()
|
||||
def __init__(self, num_samples=None):
|
||||
super(Sp2, self).__init__(num_samples)
|
||||
# at this stage, self.dataset_size and self.num_samples are not yet known
|
||||
self.cnt = 0
|
||||
|
||||
|
@ -130,8 +131,8 @@ def test_python_sampler():
|
|||
def reset(self):
|
||||
self.cnt = (self.cnt + 1) % self.dataset_size
|
||||
|
||||
def test_config(num_samples, num_repeats, sampler):
|
||||
data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler)
|
||||
def test_config(num_repeats, sampler):
|
||||
data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
|
||||
if num_repeats is not None:
|
||||
data1 = data1.repeat(num_repeats)
|
||||
res = []
|
||||
|
@ -154,8 +155,8 @@ def test_python_sampler():
|
|||
assert data[0] == (np.array(i),)
|
||||
i = i - 1
|
||||
|
||||
assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
|
||||
assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
|
||||
assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
|
||||
assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
|
||||
test_generator()
|
||||
|
||||
sp1 = Sp1().create()
|
||||
|
@ -169,9 +170,8 @@ def test_subset_sampler():
|
|||
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
|
||||
map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
|
||||
|
||||
def test_config(num_samples, start_index, subset_size):
|
||||
_ = num_samples
|
||||
sampler = ds.SubsetSampler(start_index, subset_size)
|
||||
def test_config(start_index, num_samples):
|
||||
sampler = ds.SequentialSampler(start_index, num_samples)
|
||||
d = ds.ManifestDataset(manifest_file, sampler=sampler)
|
||||
|
||||
res = []
|
||||
|
@ -180,19 +180,15 @@ def test_subset_sampler():
|
|||
|
||||
return res
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
test_config(5, 0, 0)
|
||||
assert "subset_size <= 0" in str(info.value)
|
||||
|
||||
assert test_config(5, 0, 1) == [0]
|
||||
assert test_config(5, 0, 2) == [0, 1]
|
||||
assert test_config(5, 0, 3) == [0, 1, 2]
|
||||
assert test_config(5, 0, 4) == [0, 1, 2, 3]
|
||||
assert test_config(5, 0, 5) == [0, 1, 2, 3, 4]
|
||||
assert test_config(5, 1, 1) == [1]
|
||||
assert test_config(5, 2, 3) == [2, 3, 4]
|
||||
assert test_config(5, 3, 2) == [3, 4]
|
||||
assert test_config(5, 4, 1) == [4]
|
||||
assert test_config(0, 1) == [0]
|
||||
assert test_config(0, 2) == [0, 1]
|
||||
assert test_config(0, 3) == [0, 1, 2]
|
||||
assert test_config(0, 4) == [0, 1, 2, 3]
|
||||
assert test_config(0, 5) == [0, 1, 2, 3, 4]
|
||||
assert test_config(1, 1) == [1]
|
||||
assert test_config(2, 3) == [2, 3, 4]
|
||||
assert test_config(3, 2) == [3, 4]
|
||||
assert test_config(4, 1) == [4]
|
||||
|
||||
|
||||
def test_sampler_chain():
|
||||
|
@ -200,11 +196,11 @@ def test_sampler_chain():
|
|||
map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
|
||||
|
||||
def test_config(num_shards, shard_id):
|
||||
sampler = ds.DistributedSampler(num_shards, shard_id, False)
|
||||
sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=False, num_samples=5)
|
||||
child_sampler = ds.SequentialSampler()
|
||||
sampler.add_child(child_sampler)
|
||||
|
||||
data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler)
|
||||
data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
|
||||
|
||||
res = []
|
||||
for item in data1.create_dict_iterator():
|
||||
|
@ -234,6 +230,11 @@ def test_add_sampler_invalid_input():
|
|||
data1.use_sampler("sampler")
|
||||
assert "not an instance of a sampler" in str(info.value)
|
||||
|
||||
sampler = ds.SequentialSampler()
|
||||
with pytest.raises(ValueError) as info:
|
||||
data2 = ds.ManifestDataset(manifest_file, sampler=sampler, num_samples=20)
|
||||
assert "Conflicting arguments during sampler assignments" in str(info.value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sequential_sampler(True)
|
||||
|
|
Loading…
Reference in New Issue