diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc
index 0194785090b..9b87f044f59 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.cc
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc
@@ -856,9 +856,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
         (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -893,9 +891,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
         (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -930,9 +926,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
         (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -966,9 +960,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
         (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -1001,9 +993,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
         (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -1039,10 +1029,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
     if (!value.is_none()) {
-      if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
-      } else if (key == "num_parallel_workers") {
+      if (key == "num_parallel_workers") {
         (void)builder->SetNumWorkers(ToInt(value));
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@@ -1121,8 +1111,6 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *
         (void)builder->SetDecode(ToBool(value));
       } else if (key == "extensions") {
         (void)builder->SetExtensions(ToStringSet(value));
-      } else if (key == "num_samples") {
-        (void)builder->SetNumSamples(ToInt(value));
       } else if (key == "dataset_type") {
         (void)builder->SetDatasetType(ToString(value));
       }
@@ -1153,7 +1141,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *
       (void)builder->SetShuffleFiles(ToBool(value));
     } else if (key == "num_samples") {
-      (void)builder->SetNumSamples(ToInt(value));
+      (void)builder->SetTotalRows(ToInt(value));
     } else if (key == "num_shards") {
       (void)builder->SetNumDevices(ToInt(value));
     } else if (key == "shard_id") {
diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc
index 55918d8b432..a38a88beaba 100644
--- a/mindspore/ccsrc/dataset/api/python_bindings.cc
+++ b/mindspore/ccsrc/dataset/api/python_bindings.cc
@@ -49,7 +49,6 @@
 #include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
 #include "dataset/engine/datasetops/source/sampler/python_sampler.h"
@@ -143,17 +142,16 @@ void bindDatasetOps(py::module *m) {
     });

   (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
-    .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
+    .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
       int64_t count = 0;
-      THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
+      THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
       return count;
     });

   (void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
-    .def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
+    .def_static("get_num_rows_and_classes", [](const std::string &path) {
       int64_t count = 0, num_classes = 0;
-      THROW_IF_ERROR(
-        ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
+      THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
       return py::make_tuple(count, num_classes);
     });

@@ -172,22 +170,21 @@ void bindDatasetOps(py::module *m) {
   (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
     .def_static("get_num_rows_and_classes",
-                [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
+                [](const std::string &file, const py::dict &dict, const std::string &usage) {
                   int64_t count = 0, num_classes = 0;
-                  THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
+                  THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
                   return py::make_tuple(count, num_classes);
                 })
-    .def_static("get_class_indexing",
-                [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
-                  std::map<std::string, int32_t> output_class_indexing;
-                  THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
-                  return output_class_indexing;
-                });
+    .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) {
+      std::map<std::string, int32_t> output_class_indexing;
+      THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
+      return output_class_indexing;
+    });

   (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
-    .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
+    .def_static("get_num_rows", [](const std::string &dir) {
       int64_t count = 0;
-      THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
+      THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
       return count;
     });

@@ -206,13 +203,13 @@ void bindDatasetOps(py::module *m) {
                 [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
                    const py::dict &dict, int64_t numSamples) {
                   int64_t count = 0;
-                  THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
+                  THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
                   return count;
                 })
     .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
-                                         const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
+                                         const std::string &task_mode, const py::dict &dict) {
       std::map<std::string, int32_t> output_class_indexing;
-      THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing));
+      THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
       return output_class_indexing;
     });
 }
@@ -452,25 +449,19 @@ void bindSamplerOps(py::module *m) {
   (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");

   (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
-    .def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
-         py::arg("seed"));
+    .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());

   (void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
-    .def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
+    .def(py::init<int64_t, int64_t, bool>());

   (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
-    .def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
-         py::arg("num_samples"))
-    .def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
+    .def(py::init<int64_t, bool, bool>());

   (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
-    .def(py::init<>());
-
-  (void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
-    .def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
+    .def(py::init<int64_t, int64_t>());

   (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
-    .def(py::init<std::vector<int64_t>>(), py::arg("indices"));
+    .def(py::init<std::vector<int64_t>>());

   (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
     *m, "MindrecordSubsetRandomSampler")
@@ -487,11 +478,10 @@ void bindSamplerOps(py::module *m) {
     }));

   (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
-    .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
-         py::arg("replacement"));
+    .def(py::init<int64_t, std::vector<double>, bool>());

   (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
-    .def(py::init<py::object>(), py::arg("pySampler"));
+    .def(py::init<int64_t, py::object>());
 }

 void bindInfoObjects(py::module *m) {
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
index 8f8c57b0126..f016481af69 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace dataset {
-CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr), builder_num_samples_(0) {
+CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_rows_per_buffer_ = cfg->rows_per_buffer();
@@ -38,7 +38,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
   MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << ".";
   RETURN_IF_NOT_OK(SanityCheck());
   if (builder_sampler_ == nullptr) {
-    builder_sampler_ = std::make_shared<SequentialSampler>();
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
   }

   builder_schema_ = std::make_unique<DataSchema>();
@@ -47,10 +49,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
   // label is like this:0 1 0 0 1......
   RETURN_IF_NOT_OK(
     builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
-  *op =
-    std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, builder_op_connector_size_,
-                               builder_decode_, builder_dataset_type_, builder_extensions_, std::move(builder_schema_),
-                               std::move(builder_sampler_), builder_num_samples_);
+  *op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
+                                   builder_op_connector_size_, builder_decode_, builder_dataset_type_,
+                                   builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
   if (*op == nullptr) {
     return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null");
   }
@@ -68,7 +69,7 @@ Status CelebAOp::Builder::SanityCheck() {

 CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
                    bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
-                   std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler, int64_t num_samples)
+                   std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_workers, queue_size),
       rows_per_buffer_(rows_per_buffer),
       folder_path_(dir),
@@ -77,8 +78,6 @@ CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::stri
       data_schema_(std::move(schema)),
       sampler_(std::move(sampler)),
       num_rows_in_attr_file_(0),
-      num_rows_exact_(0),
-      num_samples_(num_samples),
       dataset_type_(dataset_type) {
   // Set the column name map (base class field)
   for (int32_t index = 0; index < data_schema_->NumColumns(); index++) {
@@ -202,13 +201,6 @@ Status CelebAOp::ParseImageAttrInfo() {
   RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
   while (!image_infos.empty() && needMoreData) {
     for (uint32_t index = 0; index < image_infos.size(); index++) {
-      if (num_samples_ != 0 && image_labels_vec_.size() >= num_samples_) {
-        MS_LOG(WARNING) << "Image number(" << image_labels_vec_.size() << " is more than"
-                        << " rows num eval attr file(" << num_rows_in_attr_file_ << ") or num samples(" << num_samples_
-                        << ").";
-        needMoreData = false;
-        break;
-      }
       std::string image_info = image_infos[index];
       std::vector<std::string> split = Split(image_info);
       std::pair<std::string, std::vector<int32_t>> image_labels;
@@ -239,14 +231,13 @@ Status CelebAOp::ParseImageAttrInfo() {
     RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
   }

-  num_rows_exact_ = image_labels_vec_.size();
-  num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_exact_) ? num_rows_exact_ : num_samples_;
-  if (num_rows_exact_ == 0) {
+  num_rows_ = image_labels_vec_.size();
+  if (num_rows_ == 0) {
     RETURN_STATUS_UNEXPECTED(
       "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
       "validation first.");
   }
-  MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_exact_ << ".";
+  MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << ".";
   return Status::OK();
 }
@@ -268,28 +259,6 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) {
   return split;
 }

-// Derived from RandomAccessOp
-Status CelebAOp::GetNumSamples(int64_t *num) const {
-  if (num == nullptr || num_samples_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_samples_;
-  return Status::OK();
-}
-
-Status CelebAOp::GetNumRowsInDataset(int64_t *num) const {
-  if (num == nullptr || num_rows_exact_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API "
-      "validation first.");
-  }
-
-  *num = num_rows_exact_;
-  return Status::OK();
-}
-
 // Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
 Status CelebAOp::operator()() {
   RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
@@ -310,9 +279,8 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
     RETURN_IF_NOT_OK((*data_buffer)->PopRow(&sample_row));
     std::shared_ptr<Tensor> sample_ids = sample_row[0];
     for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
-      if ((*itr) >= num_rows_exact_) {
-        MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_exact_
-                        << ".";
+      if ((*itr) >= num_rows_) {
+        MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_ << ".";
         continue;
       }
       keys.push_back(*itr);
@@ -446,7 +414,7 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const {
   // Call the super class for displaying any common detailed info
   ParallelOp::Print(out, show_all);
   // Then show any custom derived-internal stuff
-  out << "\nNumber of rows:" << num_rows_exact_ << "\nceleba dir: " << folder_path_ << "\n\n";
+  out << "\nNumber of rows:" << num_rows_ << "\nceleba dir: " << folder_path_ << "\n\n";
 }
 }
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h
index e0055441efd..f92ba94df75 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h
@@ -108,14 +108,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
       return *this;
     }

-    // Setter method
-    // @param int64_t num_samples
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumSamples(int64_t num_samples) {
-      builder_num_samples_ = num_samples;
-      return *this;
-    }
-
     // Setter method
     // @param const std::string dataset_type: type to be read
     // @return Builder setter method returns reference to the builder.
@@ -141,7 +133,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
     std::set<std::string> builder_extensions_;
     std::shared_ptr<Sampler> builder_sampler_;
     std::unique_ptr<DataSchema> builder_schema_;
-    int64_t builder_num_samples_;
     std::string builder_dataset_type_;
   };

@@ -153,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
   CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
            const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
-           std::shared_ptr<Sampler> sampler, int64_t num_samples);
+           std::shared_ptr<Sampler> sampler);

   ~CelebAOp() override = default;

@@ -163,16 +154,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // @return Status - The error code return
   Status operator()() override;

-  // Method derived from RandomAccess Op, enable Sampler to get numRows
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumSamples(int64_t *num) const override;
-
-  // Method derived from RandomAccess Op, enable Sampler to get numRows
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumRowsInDataset(int64_t *num) const override;
-
   // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
   // @param int32_t worker_id - id of each worker
   // @return Status - The error code return
@@ -233,11 +214,9 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   std::shared_ptr<Sampler> sampler_;
   std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_;
   int64_t num_rows_in_attr_file_;  // rows number specified in attr file
-  int64_t num_rows_exact_;         // exact rows number,maybe is less than rows_num_in_attr_file_
   QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
   WaitPost wp_;
   std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
-  int64_t num_samples_;
   std::string dataset_type_;
   std::ifstream partition_file_;
 };
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
index d0a17b56f9a..5867d5e7eaa 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
@@ -35,7 +35,7 @@ constexpr uint32_t kCifarImageChannel = 3;
 constexpr uint32_t kCifarBlockImageNum = 5;
 constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel;

-CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) {
+CifarOp::Builder::Builder() : sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   num_workers_ = cfg->num_parallel_workers();
   rows_per_buffer_ = cfg->rows_per_buffer();
@@ -46,7 +46,9 @@ CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) {
 Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
   if (sampler_ == nullptr) {
-    sampler_ = std::make_shared<SequentialSampler>();
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
   }
   schema_ = std::make_unique<DataSchema>();
   TensorShape scalar = TensorShape::CreateScalar();
@@ -62,7 +64,7 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
       ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
   }

-  *ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_, num_samples_,
+  *ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
                                    std::move(schema_), std::move(sampler_));
   return Status::OK();
 }
@@ -76,16 +78,13 @@ Status CifarOp::Builder::SanityCheck() {
 }

 CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
-                 int32_t queue_size, int64_t num_samples, std::unique_ptr<DataSchema> data_schema,
-                 std::shared_ptr<Sampler> sampler)
+                 int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_works, queue_size),
       cifar_type_(type),
       rows_per_buffer_(rows_per_buf),
       folder_path_(file_dir),
-      num_samples_(num_samples),
       data_schema_(std::move(data_schema)),
       sampler_(std::move(sampler)),
-      num_rows_(0),
       row_cnt_(0),
       buf_cnt_(0) {
   // set the column name map (base class field)
@@ -112,8 +111,7 @@ Status CifarOp::operator()() {
     for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
       keys.push_back(*itr);
       row_cnt_++;
-      if ((*itr) >= num_rows_) continue;    // index out of bound, skipping
-      if (row_cnt_ >= num_samples_) break;  // enough row read, break for loop
+      if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
       if (row_cnt_ % rows_per_buffer_ == 0) {
         RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
           std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
@@ -255,30 +253,6 @@ Status CifarOp::InitSampler() {
   return Status::OK();
 }

-// Derived from RandomAccessOp
-Status CifarOp::GetNumSamples(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
-    std::string err_msg = "There is no valid data matching the dataset API " + api +
-                          ".Please check file path or dataset API validation first.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  (*num) = num_samples_;
-  return Status::OK();
-}
-
-// Derived from RandomAccessOp
-Status CifarOp::GetNumRowsInDataset(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
-    std::string err_msg = "There is no valid data matching the dataset API " + api +
-                          ".Please check file path or dataset API validation first.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  (*num) = num_rows_;
-  return Status::OK();
-}
-
 Status CifarOp::ReadCifarBlockDataAsync() {
   TaskManager::FindMe()->Post();
   RETURN_IF_NOT_OK(GetCifarFiles());
@@ -404,7 +378,6 @@ Status CifarOp::ParseCifarData() {
   }
   cifar_image_label_pairs_.shrink_to_fit();
   num_rows_ = cifar_image_label_pairs_.size();
-  num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
   if (num_rows_ == 0) {
     std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
     std::string err_msg = "There is no valid data matching the dataset API " + api +
@@ -432,11 +405,11 @@ Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
   return Status::OK();
 }

-Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count) {
+Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) {
   // the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block()
   std::shared_ptr<CifarOp> op;
   *count = 0;
-  RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetNumSamples(numSamples).SetCifarType(isCIFAR10).Build(&op));
+  RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op));
   RETURN_IF_NOT_OK(op->GetCifarFiles());
   if (op->cifar_type_ == kCifar10) {
     constexpr int64_t num_cifar10_records = 10000;
@@ -448,7 +421,6 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool
       }
       *count = *count + num_cifar10_records;
     }
-    *count = *count < numSamples || numSamples == 0 ? *count : numSamples;
     return Status::OK();
   } else {
     int64_t num_cifar100_records = 0;
@@ -470,7 +442,7 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool
         RETURN_STATUS_UNEXPECTED(err_msg);
       }
     }
-    *count = num_cifar100_records < numSamples || numSamples == 0 ? num_cifar100_records : numSamples;
+    *count = num_cifar100_records;
     return Status::OK();
   }
 }
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h
index ade0998c30b..35c0121818e 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h
@@ -73,14 +73,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
       return *this;
     }

-    // Setter method
-    // @param uint64_t num_samples
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumSamples(uint64_t num_samples) {
-      num_samples_ = num_samples;
-      return *this;
-    }
-
     // Setter method
     // @param std::shared_ptr<Sampler> sampler
     // @return Builder setter method returns reference to the builder.
@@ -121,7 +113,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   private:
    std::string dir_;
    int32_t num_workers_;
-   uint64_t num_samples_;
    int32_t rows_per_buffer_;
    int32_t op_connect_size_;
    std::shared_ptr<Sampler> sampler_;
@@ -137,7 +128,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // @param uint32_t - queueSize - connector queue size
   // @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
   CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size,
-          int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+          std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);

   // Destructor.
   ~CifarOp() = default;
@@ -152,16 +143,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // @return Status - The error code return
   Status operator()() override;

-  // Method derived from RandomAccess Op, enable Sampler to get numRows
-  // @param uint64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumSamples(int64_t *num) const override;
-
-  // Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
-  // @param uint64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumRowsInDataset(int64_t *num) const override;
-
   // A print method typically used for debugging
   // @param out
   // @param show_all
@@ -169,11 +150,10 @@ class CifarOp : public ParallelOp, public RandomAccessOp {

   // Function to count the number of samples in the CIFAR dataset
   // @param dir path to the CIFAR directory
-  // @param numSamples maximum number of samples requested
   // @param isCIFAR10 true if CIFAR10 and false if CIFAR100
-  // @param count output arg that will hold the minimum of the actual dataset size and numSamples
+  // @param count output arg that will hold the actual dataset size
   // @return
-  static Status CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count);
+  static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);

  private:
   // Initialize Sampler, calls sampler->Init() within
@@ -227,10 +207,8 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   CifarType cifar_type_;
   int32_t rows_per_buffer_;
   std::string folder_path_;
-  int64_t num_samples_;
   std::unique_ptr<DataSchema> data_schema_;
   std::shared_ptr<Sampler> sampler_;
-  int64_t num_rows_;
   int64_t row_cnt_;
   int64_t buf_cnt_;
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
index ce8fef7404b..bc5caf3b7e3 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
@@ -26,8 +26,7 @@
 namespace mindspore {
 namespace dataset {
-ImageFolderOp::Builder::Builder()
-    : builder_decode_(false), builder_recursive_(false), builder_num_samples_(0), builder_sampler_(nullptr) {
+ImageFolderOp::Builder::Builder() : builder_decode_(false), builder_recursive_(false), builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_rows_per_buffer_ = cfg->rows_per_buffer();
@@ -37,7 +36,9 @@ ImageFolderOp::Builder::Builder()
 Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
   if (builder_sampler_ == nullptr) {
-    builder_sampler_ = std::make_shared<SequentialSampler>();
+    int64_t num_samples = 0;  // default num samples of 0 means to sample entire set of data
+    int64_t start_index = 0;
+    builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
   }
   builder_schema_ = std::make_unique<DataSchema>();
   TensorShape scalar = TensorShape::CreateScalar();
@@ -46,9 +47,9 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
   RETURN_IF_NOT_OK(builder_schema_->AddColumn(
     ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
   *ptr = std::make_shared<ImageFolderOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
-                                         builder_op_connector_size_, builder_num_samples_, builder_recursive_,
-                                         builder_decode_, builder_extensions_, builder_labels_to_read_,
-                                         std::move(builder_schema_), std::move(builder_sampler_));
+                                         builder_op_connector_size_, builder_recursive_, builder_decode_,
+                                         builder_extensions_, builder_labels_to_read_, std::move(builder_schema_),
+                                         std::move(builder_sampler_));
   return Status::OK();
 }
@@ -61,20 +62,18 @@ Status ImageFolderOp::Builder::SanityCheck() {
 }

 ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
-                             int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts,
+                             bool recursive, bool do_decode, const std::set<std::string> &exts,
                              const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema,
                              std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_wkrs, queue_size),
       rows_per_buffer_(rows_per_buffer),
       folder_path_(file_dir),
-      num_samples_(num_samples),
       recursive_(recursive),
       decode_(do_decode),
       extensions_(exts),
       class_index_(map),
       data_schema_(std::move(data_schema)),
       sampler_(std::move(sampler)),
-      num_rows_(0),
       row_cnt_(0),
       buf_cnt_(0),
       sampler_ind_(0),
@@ -117,7 +116,6 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
   }
   image_label_pairs_.shrink_to_fit();
   num_rows_ = image_label_pairs_.size();
-  num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
   // free memory of two queues used for pre-scan
   folder_name_queue_->Reset();
   image_name_queue_->Reset();
@@ -138,8 +136,7 @@ Status ImageFolderOp::operator()() {
     std::shared_ptr<Tensor> sample_ids = sample_row[0];
     if (sample_ids->type() != DataType(DataType::DE_INT64)) RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64");
     for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
-      if ((*itr) >= num_rows_) continue;    // index out of bound, skipping
-      if (row_cnt_ >= num_samples_) break;  // enough row read, break for loop
+      if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
       keys.push_back(*itr);
       row_cnt_++;
       if (row_cnt_ % rows_per_buffer_ == 0) {
@@ -272,28 +269,6 @@ Status ImageFolderOp::InitSampler() {
   return Status::OK();
 }

-// Derived from RandomAccessOp
-Status ImageFolderOp::GetNumSamples(int64_t *num) const {
-  if (num == nullptr || num_samples_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_samples_;
-  return Status::OK();
-}
-
-// Derived from RandomAccessOp
-Status ImageFolderOp::GetNumRowsInDataset(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_rows_;
-  return Status::OK();
-}
-
 // Derived from RandomAccessOp
 Status ImageFolderOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
   if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
@@ -413,16 +388,14 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() {
   return Status::OK();
 }

-Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t &num_samples,
-                                          const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
-                                          int64_t dev_id, int64_t num_dev) {
+Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
+                                          int64_t *num_classes, int64_t dev_id, int64_t num_dev) {
   Path dir(path);
   std::string err_msg = "";
   int64_t row_cnt = 0;
   err_msg += (dir.Exists() == false || dir.IsDirectory() == false) ? "unable to open dir " + path : "";
   err_msg += (num_classes == nullptr || num_rows == nullptr) ? "num_class/num_rows is null\n" : "";
   err_msg += (dev_id >= num_dev || num_dev <= 0) ? "invalid sharding config\n" : "";
-  err_msg += num_samples < 0 ? "num_samples can't be negative! set it to 0 to use all samples\n" : "";
   if (err_msg.empty() == false) {
     RETURN_STATUS_UNEXPECTED(err_msg);
   }
@@ -441,10 +414,6 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t
     while (dir_itr->hasNext()) {
       if (exts.empty() || exts.find(subdir.Extension()) != exts.end()) {
         ++row_cnt;
-        if (row_cnt == num_samples * num_dev) {
-          (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
-          return Status::OK();
-        }
       }
     }
     foldernames.pop();
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
index 72d47224fb8..7acecedba61 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
@@ -107,14 +107,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
       return *this;
     }

-    // Setter method
-    // @param int64_t num_samples
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumSamples(int64_t num_samples) {
-      builder_num_samples_ = num_samples;
-      return *this;
-    }
-
     // Setter method
     // @param std::shared_ptr<Sampler> sampler
     // @return Builder setter method returns reference to the builder.
@@ -153,7 +145,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
     bool builder_recursive_;
     std::string builder_dir_;
     int32_t builder_num_workers_;
-    int64_t builder_num_samples_;
     int32_t builder_rows_per_buffer_;
     int32_t builder_op_connector_size_;
     std::set<std::string> builder_extensions_;
@@ -169,10 +160,9 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   // @param int32_t queue_size - connector queue size
   // @param std::set<std::string> exts - set of file extensions to read, if empty, read everything under the dir
   // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
-  ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
-                int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts,
-                const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema>,
-                std::shared_ptr<Sampler> sampler);
+  ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive,
+                bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
+                std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler);

   // Destructor.
   ~ImageFolderOp() = default;
@@ -198,16 +188,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   // @return Status - The error code return
   Status operator()() override;

-  // Method derived from RandomAccess Op, enable Sampler to get numRows
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumSamples(int64_t *num) const override;
-
-  // Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset
-  // @param int64_t num - to return numRows
-  // @return Status - The error code return
-  Status GetNumRowsInDataset(int64_t *num) const override;
-
   // Method derived from RandomAccess Op, enable Sampler to get all ids for each class
   // @param (std::map<int64_t, std::vector<int64_t>> * map - key label, val all ids for this class
   // @return Status - The error code return
@@ -221,9 +201,8 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   // This function is a hack! It is to return the num_class and num_rows the old storageOp does. The result
   // returned by this function may not be consistent with what image_folder_op is going to return
   // user this at your own risk!
-  static Status CountRowsAndClasses(const std::string &path, const int64_t &num_samples,
-                                    const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
-                                    int64_t dev_id = 0, int64_t num_dev = 1);
+  static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
+                                    int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1);

   // Base-class override for NodePass visitor acceptor.
   // @param p - Pointer to the NodePass to be accepted.
@@ -266,14 +245,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   int32_t rows_per_buffer_;
   std::string folder_path_;  // directory of image folder
-  int64_t num_samples_;
   bool recursive_;
   bool decode_;
   std::set<std::string> extensions_;  // extensions allowed
   std::map<std::string, int32_t> class_index_;
   std::unique_ptr<DataSchema> data_schema_;
   std::shared_ptr<Sampler> sampler_;
-  int64_t num_rows_;  // total number of images in ImageFolder
   int64_t row_cnt_;
   int64_t buf_cnt_;
   int64_t sampler_ind_;
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
index 5892b10701f..70867ed2906 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
@@ -29,7 +29,7 @@
 namespace mindspore {
 namespace dataset {
-ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_(0), builder_decode_(false) {
+ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_decode_(false) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_rows_per_buffer_ = cfg->rows_per_buffer();
@@ -39,16 +39,18 @@ ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_
 Status ManifestOp::Builder::Build(std::shared_ptr<ManifestOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
   if (builder_sampler_ == nullptr) {
-    builder_sampler_ = std::make_shared<SequentialSampler>();
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
   }
   builder_schema_ = std::make_unique<DataSchema>();
   RETURN_IF_NOT_OK(
     builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
   RETURN_IF_NOT_OK(
     builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
-  *ptr = std::make_shared<ManifestOp>(
-    builder_num_workers_, builder_rows_per_buffer_, builder_file_, builder_op_connector_size_, builder_num_samples_,
-    builder_decode_, builder_labels_to_read_, std::move(builder_schema_), std::move(builder_sampler_), builder_usage_);
+  *ptr = std::make_shared<ManifestOp>(builder_num_workers_, builder_rows_per_buffer_, builder_file_,
+                                      builder_op_connector_size_, builder_decode_, builder_labels_to_read_,
+                                      std::move(builder_schema_), std::move(builder_sampler_), builder_usage_);
   return Status::OK();
 }
@@ -59,9 +61,9 @@ Status ManifestOp::Builder::SanityCheck() {
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
 }

-ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size,
-                       int64_t num_samples, bool decode, const std::map<std::string, int32_t> &class_index,
-                       std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler, std::string usage)
+ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
+                       const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
+                       std::shared_ptr<Sampler> sampler, std::string usage)
     : ParallelOp(num_works, queue_size),
       rows_per_buffer_(rows_per_buffer),
       io_block_pushed_(0),
@@ -71,8 +73,6 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
       file_(file),
       class_index_(class_index),
       sampler_(std::move(sampler)),
-      num_samples_(num_samples),
-      num_rows_(0),
       decode_(decode),
       usage_(usage),
       buf_cnt_(0) {
@@ -101,8 +101,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
     RETURN_IF_NOT_OK((*sampler_buffer)->PopRow(&sample_row));
     std::shared_ptr<Tensor> sample_ids = sample_row[0];
     for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
-      if ((*itr) >= num_rows_) continue;    // index out of bound, skipping
-      if (row_cnt_ >= num_samples_) break;  // enough row read, break for loop
+      if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
       keys.push_back(*itr);
       row_cnt_++;
       if (row_cnt_ % rows_per_buffer_ == 0) {
@@ -269,28 +268,6 @@ Status ManifestOp::InitSampler() {
   return Status::OK();
 }

-// Derived from RandomAccessOp
-Status ManifestOp::GetNumSamples(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_samples_;
-  return Status::OK();
-}
-
-// Derived from RandomAccessOp
-Status ManifestOp::GetNumRowsInDataset(int64_t *num) const {
-  if (num == nullptr || num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
-      "validation first.");
-  }
-  (*num) = num_rows_;
-  return Status::OK();
-}
-
 // Derived from RandomAccessOp
 Status ManifestOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
   if (cls_ids == nullptr || !cls_ids->empty() || image_labelname_.empty()) {
@@ -408,7 +385,6 @@ Status ManifestOp::CountDatasetInfo() {
   }

   num_rows_ = static_cast<int64_t>(image_labelname_.size());
-  num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_;
   if (num_rows_ == 0) {
     RETURN_STATUS_UNEXPECTED(
       "There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API "
@@ -417,8 +393,8 @@ Status ManifestOp::CountDatasetInfo() {
   return Status::OK();
 }

-Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict,
-                                  const std::string &usage, int64_t *count, int64_t *numClasses) {
+Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
+                                  int64_t *count, int64_t *numClasses) {
   // the logic of counting the number of samples is copied from ParseManifestFile()
   std::map<std::string, int32_t> map;
   for (auto p : dict) {
@@ -428,17 +404,15 @@ Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, c
   std::shared_ptr<ManifestOp> op;
   *count = 0;
-  RETURN_IF_NOT_OK(
-    Builder().SetManifestFile(file).SetNumSamples(numSamples).SetClassIndex(map).SetUsage(usage).Build(&op));
+  RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op));
   RETURN_IF_NOT_OK(op->ParseManifestFile());
   *numClasses = static_cast<int64_t>(op->label_index_.size());
   *count = static_cast<int64_t>(op->image_labelname_.size());
-  *count = (*count < numSamples || numSamples == 0) ? *count : numSamples;
   return Status::OK();
 }

-Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict,
-                                    const std::string &usage, std::map<std::string, int32_t> *output_class_indexing) {
+Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
+                                    std::map<std::string, int32_t> *output_class_indexing) {
   std::map<std::string, int32_t> input_class_indexing;
   for (auto p : dict) {
     (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
@@ -449,12 +423,7 @@ Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples,
     *output_class_indexing = input_class_indexing;
   } else {
     std::shared_ptr<ManifestOp> op;
-    RETURN_IF_NOT_OK(Builder()
-                       .SetManifestFile(file)
-                       .SetNumSamples(numSamples)
-                       .SetClassIndex(input_class_indexing)
-                       .SetUsage(usage)
-                       .Build(&op));
+    RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(input_class_indexing).SetUsage(usage).Build(&op));
     RETURN_IF_NOT_OK(op->ParseManifestFile());
     RETURN_IF_NOT_OK(op->CountDatasetInfo());
     uint32_t count = 0;
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h
index e015496acce..5283a3ca380 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h
@@ -86,14 +86,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
       return *this;
     }

-    // Setter method
-    // @param int64_t num_samples
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumSamples(int64_t num_samples) {
-      builder_num_samples_ = num_samples;
-      return *this;
-    }
-
     // Setter method
     // @param std::shared_ptr<Sampler> sampler
     // @return Builder setter method returns reference to the builder.
@@ -129,7 +121,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { private: std::shared_ptr builder_sampler_; - int64_t builder_num_samples_; bool builder_decode_; std::string builder_file_; @@ -147,8 +138,8 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { // @param std::string - file list of Manifest // @param int32_t queue_size - connector queue size // @param td::unique_ptr sampler - sampler tells ImageFolderOp what to read - ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, int64_t num_samples, - bool decode, const std::map &class_index, std::unique_ptr data_schema, + ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, + const std::map &class_index, std::unique_ptr data_schema, std::shared_ptr sampler, std::string usage); // Destructor. ~ManifestOp() = default; @@ -164,16 +155,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { // @return Status - The error code return Status operator()() override; - // Method derived from RandomAccess Op, enable Sampler to get numRows - // @param int64_t num - to return numRows - // @return Status - The error code return - Status GetNumSamples(int64_t *num) const override; - - // Method derived from RandomAccess Op, enable Sampler to get total number of Rows in dataset - // @param int64_t num - to return numRows - // @return Status - The error code return - Status GetNumRowsInDataset(int64_t *num) const override; - // Method derived from RandomAccess Op, enable Sampler to get all ids for each class // @param (std::map> * map - key label, val all ids for this class // @return Status - The error code return @@ -184,12 +165,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { // @param show_all void Print(std::ostream &out, bool show_all) const override; - static Status CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict, - const std::string &usage, int64_t *count, int64_t *numClasses); + static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count, + int64_t *numClasses); // Get str-to-int mapping from label name to index - static Status GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict, - const std::string &usage, std::map *output_class_indexing); + static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, + std::map *output_class_indexing); private: // Initialize Sampler, calls sampler->Init() within @@ -240,8 +221,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { std::string file_; // file that store the information of images std::map class_index_; std::shared_ptr sampler_; - int64_t num_samples_; - int64_t num_rows_; bool decode_; std::string usage_; int64_t buf_cnt_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index 358dd07872b..52899873382 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -91,7 +91,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf block_reader_(block_reader), buffers_needed_(0), buf_cnt_(0), - num_rows_(0), ended_worker_(0), buffer_water_mark_(0) { io_blk_queues_.Init(num_workers_, op_connector_queue_size); diff --git 
a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc index 53c32b19042..4ecc1e96ee2 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc @@ -31,7 +31,7 @@ const int32_t kMnistLabelFileMagicNumber = 2049; const int32_t kMnistImageRows = 28; const int32_t kMnistImageCols = 28; -MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr) { +MnistOp::Builder::Builder() : builder_sampler_(nullptr) { std::shared_ptr cfg = GlobalContext::config_manager(); builder_num_workers_ = cfg->num_parallel_workers(); builder_rows_per_buffer_ = cfg->rows_per_buffer(); @@ -41,7 +41,9 @@ MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr) Status MnistOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(SanityCheck()); if (builder_sampler_ == nullptr) { - builder_sampler_ = std::make_shared(); + int64_t num_samples = 0; + int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); } builder_schema_ = std::make_unique(); RETURN_IF_NOT_OK( @@ -49,9 +51,8 @@ Status MnistOp::Builder::Build(std::shared_ptr *ptr) { TensorShape scalar = TensorShape::CreateScalar(); RETURN_IF_NOT_OK(builder_schema_->AddColumn( ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - *ptr = - std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, builder_op_connector_size_, - builder_num_samples_, std::move(builder_schema_), std::move(builder_sampler_)); + *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, + builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_)); return Status::OK(); } @@ -60,17 +61,14 @@ Status MnistOp::Builder::SanityCheck() { std::string err_msg; err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : ""; err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : ""; - err_msg += builder_num_samples_ < 0 ? "Number of samples is set to negative\n" : ""; return err_msg.empty() ? 
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); } MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, - int64_t num_samples, std::unique_ptr data_schema, std::shared_ptr sampler) + std::unique_ptr data_schema, std::shared_ptr sampler) : ParallelOp(num_workers, queue_size), buf_cnt_(0), row_cnt_(0), - num_rows_(0), - num_samples_(num_samples), folder_path_(folder_path), rows_per_buffer_(rows_per_buffer), sampler_(std::move(sampler)), @@ -84,8 +82,7 @@ MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folde Status MnistOp::TraversalSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { - if ((*itr) >= num_rows_) continue; // index out of bound, skipping - if (row_cnt_ >= num_samples_) break; // enough row read, break for loop + if ((*itr) >= num_rows_) continue; // index out of bound, skipping keys->push_back(*itr); row_cnt_++; if (row_cnt_ % rows_per_buffer_ == 0) { @@ -219,17 +216,6 @@ Status MnistOp::InitSampler() { return Status::OK(); } -// Derived from RandomAccessOp -Status MnistOp::GetNumSamples(int64_t *num) const { - if (num == nullptr || num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API " - "validation first."); - } - (*num) = num_samples_; - return Status::OK(); -} - // Derived from RandomAccessOp Status MnistOp::GetClassIds(std::map> *cls_ids) const { if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) { @@ -364,7 +350,6 @@ Status MnistOp::ParseMnistData() { } image_label_pairs_.shrink_to_fit(); num_rows_ = image_label_pairs_.size(); - num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_; return Status::OK(); } @@ -414,11 +399,11 @@ Status MnistOp::LaunchThreadsAndInitOp() { return Status::OK(); } -Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count) { +Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) { // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader() std::shared_ptr op; *count = 0; - RETURN_IF_NOT_OK(Builder().SetDir(dir).SetNumSamples(numSamples).Build(&op)); + RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op)); RETURN_IF_NOT_OK(op->WalkAllFiles()); @@ -440,19 +425,6 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64 label_reader.close(); } - *count = (numSamples == 0 || *count < numSamples) ? 
*count : numSamples; - - return Status::OK(); -} - -// Derived from RandomAccessOp -Status MnistOp::GetNumRowsInDataset(int64_t *num) const { - if (num == nullptr || num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API " - "validation first."); - } - (*num) = num_rows_; return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h index 397a51710e2..4a31b7cf666 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h @@ -78,14 +78,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { return *this; } - // Setter method - // @param int64_t num_samples - // @return Builder setter method returns reference to the builder. - Builder &SetNumSamples(int64_t num_samples) { - builder_num_samples_ = num_samples; - return *this; - } - // Setter method // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. @@ -114,7 +106,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { private: std::string builder_dir_; int32_t builder_num_workers_; - int64_t builder_num_samples_; int32_t builder_rows_per_buffer_; int32_t builder_op_connector_size_; std::shared_ptr builder_sampler_; @@ -126,11 +117,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp { // @param int32_t rows_per_buffer - number of images (rows) in each buffer // @param std::string folder_path - dir directory of mnist // @param int32_t queue_size - connector queue size - // @param int64_t num_samples - number of samples to read // @param std::unique_ptr data_schema - the schema of the mnist dataset // @param td::unique_ptr sampler - sampler tells MnistOp what to read MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, - int64_t num_samples, std::unique_ptr data_schema, std::shared_ptr sampler); + std::unique_ptr data_schema, std::shared_ptr sampler); // Destructor. 
~MnistOp() = default; @@ -146,16 +136,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { // @return Status - The error code return Status operator()() override; - // Method derived from RandomAccess Op, enable Sampler to get numRows - // @param int64_t num - to return numRows - // @return Status - The error code return - Status GetNumSamples(int64_t *num) const override; - - // Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset - // @param int64_t num - to return numRows - // @return Status - The error code return - Status GetNumRowsInDataset(int64_t *num) const override; - // Method derived from RandomAccess Op, enable Sampler to get all ids for each class // @param (std::map> * map - key label, val all ids for this class // @return Status - The error code return @@ -167,11 +147,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp { void Print(std::ostream &out, bool show_all) const override; // Function to count the number of samples in the MNIST dataset - // @param dir path to the MNSIT directory - // @param numSamples maximum number of samples requested + // @param dir path to the MNIST directory // @param count output arg that will hold the minimum of the actual dataset size and numSamples // @return - static Status CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count); + static Status CountTotalRows(const std::string &dir, int64_t *count); private: // Initialize Sampler, calls sampler->Init() within @@ -244,9 +223,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { int64_t buf_cnt_; int64_t row_cnt_; - int64_t num_rows_; // total number of images in Mnist WaitPost wp_; - int64_t num_samples_; std::string folder_path_; // directory of image folder int32_t rows_per_buffer_; std::shared_ptr sampler_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt index 152b887ef44..5209d9ba4ad 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt @@ -8,6 +8,5 @@ add_library(engine-datasetops-source-sampler OBJECT sampler.cc sequential_sampler.cc subset_random_sampler.cc - subset_sampler.cc weighted_random_sampler.cc ) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index d4e5a732db7..77207e9a6cc 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -23,8 +23,9 @@ namespace mindspore { namespace dataset { -DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle, uint32_t seed) - : Sampler(), +DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, + uint32_t seed) + : Sampler(num_samples, std::numeric_limits::max()), cnt_(0), seed_(seed == std::numeric_limits::max() ? GetSeed() : seed), device_id_(dev_id), @@ -32,6 +33,11 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu shuffle_(shuffle) {} Status DistributedSampler::InitSampler() { + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. 
+ if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n"); CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0, diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h index 29b5cda0da6..aeea2bfe5dd 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h @@ -27,10 +27,11 @@ namespace mindspore { namespace dataset { class DistributedSampler : public Sampler { public: - // @param int64_t numDev - // @param int64_t devId + // @param num_samples + // @param int64_t num_dev + // @param int64_t dev_id // @param bool shuffle - DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle = true, + DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, uint32_t seed = std::numeric_limits::max()); // default destructor diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc index 72c2cc18746..48c59c45032 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc @@ -20,12 +20,11 @@ namespace mindspore { namespace dataset { -PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer) - : Sampler(samples_per_buffer), +PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), shuffle_(shuffle), seed_(GetSeed()), next_id_(0), - num_pk_samples_(0), samples_per_class_(val) {} Status PKSampler::InitSampler() { @@ -36,22 +35,34 @@ Status PKSampler::InitSampler() { } } rnd_.seed(seed_++); - num_pk_samples_ = samples_per_class_ * static_cast(labels_.size()); - samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_; - num_samples_ = num_pk_samples_; + + // The special handshake gives the list of classes and ids, but it does not set the num_rows_ to + // capture the total number of possible sample ids. + // Compute that here for this case to find the total number of samples that are available to return. + // (in this case, samples per class * total classes). + num_rows_ = samples_per_class_ * static_cast(labels_.size()); + + // The user may have chosen to sample less than the total amount. + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exist in the dataset, adjust the num_samples accordingly. + if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } + + samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ?
num_samples_ : samples_per_buffer_; if (shuffle_ == true) { std::shuffle(labels_.begin(), labels_.end(), rnd_); } else { std::sort(labels_.begin(), labels_.end()); } - CHECK_FAIL_RETURN_UNEXPECTED(num_pk_samples_ > 0, "num_class or K (num samples per class) is not positive"); + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_class or K (num samples per class) is not positive"); return Status::OK(); } Status PKSampler::GetNextBuffer(std::unique_ptr *out_buffer) { - if (next_id_ > num_pk_samples_ || num_pk_samples_ == 0) { + if (next_id_ > num_samples_ || num_samples_ == 0) { RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler"); - } else if (next_id_ == num_pk_samples_) { + } else if (next_id_ == num_samples_) { (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { @@ -60,8 +71,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr *out_buffer) { (*out_buffer) = std::make_unique(next_id_, DataBuffer::kDeBFlagNone); std::shared_ptr sample_ids; - int64_t last_id = - (samples_per_buffer_ + next_id_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_ + next_id_; + int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_)); int64_t *id_ptr = reinterpret_cast(sample_ids->GetMutableBuffer()); while (next_id_ < last_id) { @@ -85,7 +95,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr *out_buffer) { } Status PKSampler::Reset() { - CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late"); + CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); next_id_ = 0; rnd_.seed(seed_++); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h index 14f598a9ce4..a8538874ec4 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h @@ -28,10 +28,11 @@ namespace mindspore { namespace dataset { class PKSampler : public Sampler { // NOT YET FINISHED public: - // @param int64_t kVal + // @param num_samples - the number of samples to draw. value of 0 means to take the full amount + // @param int64_t val // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit PKSampler(int64_t val, bool shuffle = false, + explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer = std::numeric_limits::max()); // default destructor @@ -42,8 +43,9 @@ class PKSampler : public Sampler { // NOT YET FINISHED // @return - The error code return Status GetNextBuffer(std::unique_ptr *out_buffer) override; - // first handshake between StorageOp and Sampler - // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() + // first handshake between leaf source op and Sampler. This func will determine the amount of data + // in the dataset that we can sample from. 
+ // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is // @return Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; @@ -58,7 +60,6 @@ class PKSampler : public Sampler { // NOT YET FINISHED bool shuffle_; uint32_t seed_; int64_t next_id_; - int64_t num_pk_samples_; int64_t samples_per_class_; std::mt19937 rnd_; std::vector labels_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc index ca999e31a53..bff11a0b448 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -20,8 +20,8 @@ namespace mindspore { namespace dataset { -PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer) - : Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} +PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} Status PythonSampler::GetNextBuffer(std::unique_ptr *out_buffer) { if (need_to_reset_) { @@ -65,6 +65,11 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr *out_buffer) { Status PythonSampler::InitSampler() { CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0"); + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exist in the dataset, adjust the num_samples accordingly. + if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } { py::gil_scoped_acquire gil_acquire; if (Py_IsInitialized() == 0) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h index b8734fee6af..bba9804952c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h @@ -26,8 +26,11 @@ namespace dataset { class PythonSampler : public Sampler { public: // Constructor - // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit PythonSampler(py::object py_sampler_instance, + // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the + // data from the dataset. + // @param py_sampler_instance - the Python instance of the sampler + // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer = std::numeric_limits::max()); // Destructor.
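Note: every InitSampler() touched in this patch now applies the same num_samples convention: 0 is a sentinel meaning "sample the entire dataset", and a request larger than the available row count is silently capped. A minimal Python sketch of that shared rule follows; the helper name and the assertions are illustrative only and are not part of this patch.

def effective_num_samples(num_samples, num_rows):
    # 0 is the sentinel meaning "take every row in the dataset".
    if num_samples < 0:
        raise ValueError("num_samples cannot be negative")
    if num_samples == 0 or num_samples > num_rows:
        return num_rows
    return num_samples

assert effective_num_samples(0, 100) == 100    # sentinel: whole dataset
assert effective_num_samples(250, 100) == 100  # oversized request is capped
assert effective_num_samples(30, 100) == 30    # in-range request honored as-is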
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc index 0de55e0fb48..2adf6bc8c79 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -22,12 +22,11 @@ namespace mindspore { namespace dataset { -RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples, +RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, int64_t samples_per_buffer) - : Sampler(samples_per_buffer), + : Sampler(num_samples, samples_per_buffer), seed_(GetSeed()), replacement_(replacement), - user_num_samples_(num_samples), next_id_(0), reshuffle_each_epoch_(reshuffle_each_epoch), dist(nullptr) {} @@ -70,27 +69,25 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr *out_buffer) { } Status RandomSampler::InitSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive."); - + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exist in the dataset, adjust the num_samples accordingly. + if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); + samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; rnd_.seed(seed_); if (replacement_ == false) { - num_samples_ = std::min(num_samples_, num_rows_); - num_samples_ = std::min(num_samples_, user_num_samples_); - shuffled_ids_.reserve(num_rows_); for (int64_t i = 0; i < num_rows_; i++) { shuffled_ids_.push_back(i); } std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); } else { - num_samples_ = std::min(num_samples_, user_num_samples_); dist = std::make_unique>(0, num_rows_ - 1); } - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive."); - samples_per_buffer_ = samples_per_buffer_ > num_samples_ ?
num_samples_ : samples_per_buffer_; - return Status::OK(); } @@ -119,7 +116,6 @@ void RandomSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): RandomSampler\n"; if (show_all) { - out << "user_num_samples_: " << user_num_samples_ << '\n'; out << "num_samples_: " << num_samples_ << '\n'; out << "next_id_: " << next_id_ << '\n'; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h index 352751dbb8b..bb8bb724289 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h @@ -27,11 +27,11 @@ namespace dataset { class RandomSampler : public Sampler { public: // Constructor + // @param int64_t num_samples - number of samples to draw // @param bool replacement - put the id back / or not after a sample - // @param int64_t numSamples - number samples to draw - // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true, - int64_t num_samples = std::numeric_limits::max(), + // @param reshuffle_each_epoch - T/F to reshuffle after epoch + // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, int64_t samples_per_buffer = std::numeric_limits::max()); // Destructor. @@ -55,7 +55,6 @@ class RandomSampler : public Sampler { private: uint32_t seed_; bool replacement_; - int64_t user_num_samples_; std::vector shuffled_ids_; // only used for NO REPLACEMENT int64_t next_id_; std::mt19937 rnd_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index 600d8c576b1..7c96a2c54aa 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -19,8 +19,25 @@ namespace mindspore { namespace dataset { -Sampler::Sampler(int64_t samples_per_buffer) - : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} +Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { + // The sampler base class itself does not compute its own num_rows_ value. + // Instead, this value is computed by the derived leaf op during its own initialization + // after it has interacted with its storage layers. + // Here, it is just a getter method to return the value. However, it is invalid if there is + // not a value set for this count, so generate a failure if that is the case.
+ if (num == nullptr || num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed its num rows yet."); + } + (*num) = num_rows_; + return Status::OK(); +} + +Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) + : DatasetOp(0), + num_rows_(0), + num_samples_(num_samples), + samples_per_buffer_(samples_per_buffer), + col_desc_(nullptr) {} Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { std::shared_ptr child_sampler; @@ -36,10 +53,10 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { } CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n"); - RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_)); + + // If there's a child sampler, set the row count to be its sample count if (HasChildSampler()) { - int64_t child_num_samples = child_sampler->num_samples(); - num_rows_ = child_num_samples; + num_rows_ = child_sampler->num_samples_; } else { RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); } @@ -105,7 +122,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { } Status Sampler::SetNumSamples(int64_t num_samples) { - CHECK_FAIL_RETURN_UNEXPECTED(num_samples > 0, "num_samples is negative or 0"); + CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "num_samples is negative"); num_samples_ = num_samples; return Status::OK(); } @@ -116,6 +133,16 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) { return Status::OK(); } +// inline op doesn't have its own consumer, it's assigned from parent +int32_t Sampler::num_consumers() const { + if (parent_.empty() || parent_[0] == nullptr) { + MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0."; + return 0; + } else { + return parent_[0]->num_consumers(); + } +} + Status Sampler::AddChild(std::shared_ptr child) { if (child == nullptr) { return Status::OK(); @@ -155,5 +182,14 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { return Status::OK(); } +// inline op doesn't have its own producers, it's assigned from child +int32_t Sampler::num_producers() const { + if (child_.empty() || child_[0] == nullptr) { + MS_LOG(WARNING) << "Sampler with no child, num_producers is 0."; + return 0; + } else { + return child_[0]->num_producers(); + } +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h index 936a80bb381..8880e5e9f8a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h @@ -33,23 +33,10 @@ namespace dataset { // must inherit from if those leaf operators wish to support sampling. class RandomAccessOp { public: - // Sampler get numRows from StorageOp - // @param int64_t num - return number of rows, normally num of samples - // @return - The error code return - virtual Status GetNumSamples(int64_t *num_samples) const { - // CI complains num_samples not used if the following line is not added - CHECK_FAIL_RETURN_UNEXPECTED(num_samples != nullptr, "num_samples == nullptr"); - RETURN_STATUS_UNEXPECTED("function GetNumSamples needs to overridden to support this sampler"); - } - - // Sampler get number of rows in the dataset!
+ // Sampler get number of rows in the dataset // @param int64_t num - return number of rows for this dataset // @return - The error code return - virtual Status GetNumRowsInDataset(int64_t *num_rows) const { - // CI complains num_rows not used if the following line is not added - CHECK_FAIL_RETURN_UNEXPECTED(num_rows != nullptr, "num_rows == nullptr"); - RETURN_STATUS_UNEXPECTED("function GetNumRowsInDataset needs to overridden to support this sampler"); - } + Status GetNumRowsInDataset(int64_t *num_rows) const; // sampler gets label, imageIds from storageOp, this function is unique to PK // @param std::map> * map @@ -60,12 +47,20 @@ class RandomAccessOp { // default destructor virtual ~RandomAccessOp() = default; + + protected: + // The number of rows in the dataset itself. This is the before-sampling value, the + // total count of rows. A sampler may choose to sample less than this amount. + int64_t num_rows_; }; class Sampler : public DatasetOp { public: + // Constructor + // @param int64_t num_samples: the user-requested number of sample ids to generate. A value of 0 + // indicates that the sampler should produce the complete set of ids. // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit Sampler(int64_t samples_per_buffer = std::numeric_limits::max()); + explicit Sampler(int64_t num_samples, int64_t samples_per_buffer); // default destructor ~Sampler() = default; @@ -84,33 +79,36 @@ class Sampler : public DatasetOp { // @return - The error code return Status Reset() override = 0; - // setter function for num_rows_ - Status SetNumRowsInDataset(int64_t num_rows); - - // setter function for num_samples_ - Status SetNumSamples(int64_t num_samples); - - int64_t num_samples() { return num_samples_; } - - // first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples - // @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds() + // first handshake between leaf source op and Sampler. This func will determine the amount of data + // in the dataset that we can sample from. + // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is // @return virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op); // initialize sampler and perform checks on certain vars virtual Status InitSampler() { return Status::OK(); } - // Not meant to be called + // setter for num samples + // @param num_samples - the number of samples to assign. + // @return status error code + Status SetNumSamples(int64_t num_samples); + + // setter for the number of records in the dataset + // @param num_rows - the number of records + // @return status error code + Status SetNumRowsInDataset(int64_t num_rows); + + // Sampler is an inlined op and has no workers. Producers and consumers are computed. // @return int32_t num_workers() const final { return 0; } - // Not meant to be called + // Identify num consumers (inlined op) // @return - int32_t num_consumers() const final { return 0; } + int32_t num_consumers() const final; - // Not meant to be called + // Identify num producers (inlined op) // @return - int32_t num_producers() const final { return 0; } + int32_t num_producers() const final; // Not meant to be called! // @return - The error code return @@ -151,10 +149,11 @@ // output. Otherwise, num_rows_ is the number of rows in the dataset. int64_t num_rows_; - // Number of ids this sampler will return.
+ // The user may want to sample less than the full amount of data. num_samples_ reduces the number + // of ids returned as requested by the user. Derived classes will choose how to sample the smaller + // amount. int64_t num_samples_; - // The max number of ids a DataBuffer returned by this sampler will contain. int64_t samples_per_buffer_; std::unique_ptr col_desc_; std::unique_ptr child_ids_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index 789f232e1e0..b26fc630671 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -20,34 +20,42 @@ namespace mindspore { namespace dataset { -SequentialSampler::SequentialSampler(int64_t samples_per_buffer) : Sampler(samples_per_buffer), next_id_(0) {} +SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} Status SequentialSampler::GetNextBuffer(std::unique_ptr *out_buffer) { - if (next_id_ > num_samples_) { - RETURN_STATUS_UNEXPECTED("Sequential Sampler Internal Error"); - } else if (next_id_ == num_samples_) { + if (id_count_ > num_samples_) { + RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); + } else if (id_count_ == num_samples_) { (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); } - (*out_buffer) = std::make_unique(next_id_, DataBuffer::kDeBFlagNone); + (*out_buffer) = std::make_unique(current_id_, DataBuffer::kDeBFlagNone); std::shared_ptr sampleIds; - int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; - RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_)); + + // Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for + // samples per buffer though. + int64_t remaining_ids = num_samples_ - id_count_; + int64_t num_elements = std::min(remaining_ids, samples_per_buffer_); + + RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements)); int64_t *idPtr = reinterpret_cast(sampleIds->GetMutableBuffer()); - while (next_id_ < lastId) { - int64_t sampled_id = next_id_; + for (int64_t i = 0; i < num_elements; i++) { + int64_t sampled_id = current_id_; if (HasChildSampler()) { RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); } *idPtr = sampled_id; - next_id_++; + current_id_++; // Move the current id to the next one in the sequence idPtr++; } + id_count_ += num_elements; // Count the packed ids towards our overall sample count + TensorRow row(1, sampleIds); (*out_buffer)->set_tensor_table(std::make_unique(1, row)); } @@ -55,19 +63,24 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr *out_buffer) } Status SequentialSampler::InitSampler() { - num_samples_ = (num_samples_ <= 0) ?
num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set - if (HasChildSampler()) { - num_samples_ = std::min(num_samples_, num_rows_); + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0, "num_samples < 0\n"); + // Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample + // the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data. + int64_t available_row_count = num_rows_ - start_index_; + if (num_samples_ == 0 || num_samples_ > available_row_count) { + num_samples_ = available_row_count; } - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; return Status::OK(); } Status SequentialSampler::Reset() { - CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); - next_id_ = 0; + CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); + current_id_ = start_index_; + id_count_ = 0; if (HasChildSampler()) { RETURN_IF_NOT_OK(child_[0]->Reset()); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h index 4e195d75dbb..46cfb7a3047 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -26,8 +26,12 @@ namespace dataset { class SequentialSampler : public Sampler { public: // Constructor + // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the + // full amount of ids from the dataset + // @param start_index - The starting index value // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit SequentialSampler(int64_t samples_per_buffer = std::numeric_limits::max()); + explicit SequentialSampler(int64_t num_samples, int64_t start_index, + int64_t samples_per_buffer = std::numeric_limits::max()); // Destructor. ~SequentialSampler() = default; @@ -48,7 +52,9 @@ class SequentialSampler : public Sampler { void Print(std::ostream &out, bool show_all) const override; private: - int64_t next_id_; + int64_t current_id_; // The id sequencer. Each new id increments from this + int64_t start_index_; // The starting id. current_id_ begins from here. + int64_t id_count_; // An internal counter that tracks how many ids have been produced }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc index ca1160299a2..0dfeb1a191b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc @@ -27,22 +27,28 @@ namespace mindspore { namespace dataset { // Constructor. 
-SubsetRandomSampler::SubsetRandomSampler(const std::vector &indices, int64_t samples_per_buffer) - : Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} +SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector &indices, + int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} // Initialized this Sampler. Status SubsetRandomSampler::InitSampler() { CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); - num_samples_ = indices_.size(); - + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // In this case, the ids are provided by the user. Cap the num_samples on the number of ids given. + if (num_samples_ == 0 || num_samples_ > static_cast(indices_.size())) { + num_samples_ = static_cast(indices_.size()); + } // Initialize random generator with seed from config manager rand_gen_.seed(GetSeed()); - if (static_cast(samples_per_buffer_) > indices_.size()) { - samples_per_buffer_ = static_cast(indices_.size()); + if (samples_per_buffer_ > num_samples_) { + samples_per_buffer_ = num_samples_; } + // num_samples_ could be smaller than the total number of input ids. + // We will shuffle the full set of ids, but only select the first num_samples_ of them later. std::shuffle(indices_.begin(), indices_.end(), rand_gen_); return Status::OK(); @@ -68,7 +74,7 @@ Status SubsetRandomSampler::Reset() { // Get the sample ids. Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr *out_buffer) { // All samples have been drawn - if (sample_id_ == indices_.size()) { + if (sample_id_ == num_samples_) { (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { @@ -80,8 +86,8 @@ int64_t last_id = sample_id_ + samples_per_buffer_; // Handling the return all samples at once, and when last draw is not a full batch. - if (static_cast(last_id) > indices_.size()) { - last_id = indices_.size(); + if (last_id > num_samples_) { + last_id = num_samples_; } // Allocate tensor diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h index 1f4c155748b..d1ab13c5404 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h @@ -28,10 +28,11 @@ namespace dataset { class SubsetRandomSampler : public Sampler { public: // Constructor. + // @param num_samples The number of samples to draw. 0 for the full amount. // @param indices List of indices from where we will randomly draw samples. // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. - explicit SubsetRandomSampler(const std::vector &indices, + explicit SubsetRandomSampler(int64_t num_samples, const std::vector &indices, std::int64_t samples_per_buffer = std::numeric_limits::max()); // Destructor.
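Note: the SubsetSampler deleted below is superseded by SequentialSampler, which gains a start_index in this change, so a contiguous subset is now just a sequential walk that starts partway into the dataset. A rough Python model of the replacement semantics follows; the function and argument names are illustrative only, not part of this patch.

def sequential_ids(num_rows, start_index=0, num_samples=0):
    # Mirrors the bounds checks and capping in SequentialSampler::InitSampler().
    if not 0 <= start_index < num_rows:
        raise ValueError("start_index out of range")
    available = num_rows - start_index
    if num_samples == 0 or num_samples > available:
        num_samples = available
    return list(range(start_index, start_index + num_samples))

assert sequential_ids(10, start_index=7) == [7, 8, 9]  # what SubsetSampler(7, 3) produced
assert sequential_ids(10, start_index=2, num_samples=4) == [2, 3, 4, 5]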
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc deleted file mode 100644 index 0ae7a7d5031..00000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.cc +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/subset_sampler.h" - -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" - -namespace mindspore { -namespace dataset { -// Constructor. -SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size) - : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {} - -Status SubsetSampler::InitSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n"); - CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); - CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); - CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n"); - - num_samples_ = subset_size_; - - return Status::OK(); -} - -Status SubsetSampler::Reset() { - current_id_ = 0; - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->Reset()); - } - - return Status::OK(); -} - -Status SubsetSampler::GetNextBuffer(std::unique_ptr *out_buffer) { - if (current_id_ > subset_size_) { - RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error"); - } else if (current_id_ == subset_size_) { - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); - } - - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); - std::shared_ptr sampled_ids; - RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_)); - - int64_t *sampled_ids_start_addr = reinterpret_cast(sampled_ids->GetMutableBuffer()); - - while (current_id_ < subset_size_) { - int64_t sampled_id = start_index_ + current_id_; - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); - } - - *(sampled_ids_start_addr + current_id_) = sampled_id; - current_id_++; - } - - TensorRow sampled_ids_row(1, sampled_ids); - (*out_buffer)->set_tensor_table(std::make_unique(1, sampled_ids_row)); - } - - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h deleted file mode 100644 index 5e8774f6732..00000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_sampler.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file 
except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ - -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { - -class SubsetSampler : public Sampler { - public: - // Constructor. - // @param start_index The index we start sampling from. - explicit SubsetSampler(int64_t start_index, int64_t subset_size); - - // Destructor. - ~SubsetSampler() = default; - - // Initialize the sampler. - // @return Status - Status InitSampler() override; - - // Reset the internal variable to the initial state and reshuffle the indices. - // @return Status - Status Reset() override; - - // Get the sample ids. - // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. - // @note the sample ids (int64_t) will be placed in one Tensor. - Status GetNextBuffer(std::unique_ptr *out_buffer) override; - - private: - int64_t start_index_; - int64_t subset_size_; - int64_t current_id_; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc index 5027dcdd67b..96b2571786a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -27,25 +27,28 @@ namespace mindspore { namespace dataset { // Constructor. -WeightedRandomSampler::WeightedRandomSampler(const std::vector &weights, int64_t num_samples, bool replacement, +WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement, int64_t samples_per_buffer) - : Sampler(samples_per_buffer), + : Sampler(num_samples, samples_per_buffer), weights_(weights), replacement_(replacement), sample_id_(0), - buffer_id_(0), - user_num_samples_(num_samples) {} + buffer_id_(0) {} // Initialized this Sampler. Status WeightedRandomSampler::InitSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive"); + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exist in the dataset, adjust the num_samples accordingly. + if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive"); CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); - num_samples_ = user_num_samples_; // Initialize random generator with seed from config manager rand_gen_.seed(GetSeed()); - samples_per_buffer_ = (samples_per_buffer_ > user_num_samples_) ?
user_num_samples_ : samples_per_buffer_; + samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; if (!replacement_) { exp_dist_ = std::make_unique>(1); @@ -67,8 +70,8 @@ void WeightedRandomSampler::InitOnePassSampling() { } // Partial sort the first `numSamples` elements. - std::partial_sort(val_idx.begin(), val_idx.begin() + user_num_samples_, val_idx.end()); - for (int64_t i = 0; i < user_num_samples_; i++) { + std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end()); + for (int64_t i = 0; i < num_samples_; i++) { onepass_ids_.push_back(val_idx[i].second); } } @@ -98,11 +101,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr *out_buf "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); } - if (!replacement_ && (weights_.size() < static_cast(user_num_samples_))) { + if (!replacement_ && (weights_.size() < static_cast(num_samples_))) { RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); } - if (sample_id_ == user_num_samples_) { + if (sample_id_ == num_samples_) { (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); } else { if (HasChildSampler()) { @@ -114,8 +117,8 @@ int64_t last_id = sample_id_ + samples_per_buffer_; // Handling the return all samples at once, and when last draw is not a full batch. - if (last_id > user_num_samples_) { - last_id = user_num_samples_; + if (last_id > num_samples_) { + last_id = num_samples_; } // Allocate tensor. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h index 5381bb64b08..775176ccdac 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h @@ -29,12 +29,12 @@ namespace dataset { class WeightedRandomSampler : public Sampler { public: // Constructor. - // @param weights A lift of sample weights. // @param num_samples Number of samples to be drawn. + // @param weights A list of sample weights. // @param replacement Determine if samples are drawn with/without replacement. // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. - WeightedRandomSampler(const std::vector &weights, int64_t num_samples, bool replacement = true, + WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement, int64_t samples_per_buffer = std::numeric_limits::max()); // Destructor. @@ -69,9 +69,6 @@ class WeightedRandomSampler : public Sampler { // Random engine and device std::mt19937 rand_gen_; - // num_samples from user - int64_t user_num_samples_; - // Discrete distribution for generating weighted random numbers with replacement.
std::unique_ptr> discrete_dist_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc index e51eb4e00d5..40ffe7e9abb 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc @@ -33,7 +33,7 @@ namespace mindspore { namespace dataset { TextFileOp::Builder::Builder() - : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { + : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -62,7 +62,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr *op) { builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); std::shared_ptr text_file_op = std::make_shared( - builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, + builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_); RETURN_IF_NOT_OK(text_file_op->Init()); @@ -71,14 +71,14 @@ Status TextFileOp::Builder::Build(std::shared_ptr *op) { return Status::OK(); } -TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, +TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr schema, std::vector text_files_list, int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) : ParallelOp(num_workers, op_connector_size), device_id_(device_id), num_devices_(num_device), rows_per_buffer_(rows_per_buffer), - num_samples_(num_samples), + total_rows_(total_rows), text_files_list_(std::move(text_files_list)), shuffle_files_(shuffle_files), data_schema_(std::move(schema)), @@ -104,9 +104,9 @@ void TextFileOp::Print(std::ostream &out, bool show_all) const { // Call the super class for displaying any common detailed info ParallelOp::Print(out, show_all); // Then show any custom derived-internal stuff - out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ - << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ - << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nText files list:\n"; + out << "\nRows per buffer: " << rows_per_buffer_ << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ + << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? 
"yes" : "no") + << "\nText files list:\n"; for (int i = 0; i < text_files_list_.size(); ++i) { out << " " << text_files_list_[i]; } @@ -404,9 +404,9 @@ Status TextFileOp::operator()() { RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); if (buffer->eoe()) { workers_done++; - } else if (num_samples_ == 0 || rows_read < num_samples_) { - if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { - int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); + } else if (total_rows_ == 0 || rows_read < total_rows_) { + if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) { + int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read); RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); } rows_read += buffer->NumRows(); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h index 8b8eda00feb..63dae54930d 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h @@ -107,8 +107,8 @@ class TextFileOp : public ParallelOp { // Setter method. // @return Builder - setter method returns reference to the builder. - Builder &SetNumSamples(int64_t num_samples) { - builder_num_samples_ = num_samples; + Builder &SetTotalRows(int64_t total_rows) { + builder_total_rows_ = total_rows; return *this; } @@ -118,7 +118,7 @@ class TextFileOp : public ParallelOp { int32_t builder_num_workers_; int32_t builder_op_connector_size_; int64_t builder_rows_per_buffer_; - int64_t builder_num_samples_; + int64_t builder_total_rows_; int32_t builder_worker_connector_size_; std::vector builder_text_files_list_; bool builder_shuffle_files_; @@ -136,7 +136,7 @@ class TextFileOp : public ParallelOp { // @param columns_to_load - the names of the columns to load data from. // @param shuffle_files - whether or not to shuffle the files before reading data. // @param equal_rows_per_shard - whether or not to get equal rows for each process. 
- TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); @@ -246,7 +246,7 @@ class TextFileOp : public ParallelOp { int32_t device_id_; int32_t num_devices_; int64_t rows_per_buffer_; - int64_t num_samples_; + int64_t total_rows_; std::vector text_files_list_; bool shuffle_files_; std::unique_ptr data_schema_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc index d96b3a8872d..2cbb1756b9e 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc @@ -44,7 +44,7 @@ const char kSegmentationExtension[] = ".png"; const char kAnnotationExtension[] = ".xml"; const char kImageSetsExtension[] = ".txt"; -VOCOp::Builder::Builder() : builder_decode_(false), builder_num_samples_(0), builder_sampler_(nullptr) { +VOCOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { std::shared_ptr cfg = GlobalContext::config_manager(); builder_num_workers_ = cfg->num_parallel_workers(); builder_rows_per_buffer_ = cfg->rows_per_buffer(); @@ -55,7 +55,9 @@ VOCOp::Builder::Builder() : builder_decode_(false), builder_num_samples_(0), bui Status VOCOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(SanityCheck()); if (builder_sampler_ == nullptr) { - builder_sampler_ = std::make_shared(); + int64_t num_samples = 0; + int64_t start_index = 0; + builder_sampler_ = std::make_shared(num_samples, start_index); } builder_schema_ = std::make_unique(); if (builder_task_type_ == TaskType::Segmentation) { @@ -71,8 +73,7 @@ Status VOCOp::Builder::Build(std::shared_ptr *ptr) { } *ptr = std::make_shared(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_, builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, - builder_num_samples_, builder_decode_, std::move(builder_schema_), - std::move(builder_sampler_)); + builder_decode_, std::move(builder_schema_), std::move(builder_sampler_)); return Status::OK(); } @@ -81,20 +82,16 @@ Status VOCOp::Builder::SanityCheck() { std::string err_msg; err_msg += dir.IsDirectory() == false ? "VOC path is invalid or not set\n" : ""; err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : ""; - err_msg += builder_num_samples_ < 0 ? "num_samples is negative\n" : ""; return err_msg.empty() ?
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); } VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, - int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr data_schema, - std::shared_ptr sampler) + int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler) : ParallelOp(num_workers, queue_size), decode_(decode), row_cnt_(0), buf_cnt_(0), - num_rows_(0), - num_samples_(num_samples), task_type_(task_type), task_mode_(task_mode), folder_path_(folder_path), @@ -112,7 +109,6 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std: Status VOCOp::TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { if ((*itr) > num_rows_) continue; - if (row_cnt_ == num_samples_) break; keys->push_back(*itr); row_cnt_++; if (row_cnt_ % rows_per_buffer_ == 0) { @@ -187,16 +183,6 @@ Status VOCOp::Reset() { return Status::OK(); } -Status VOCOp::GetNumSamples(int64_t *num) const { - if (num == nullptr || num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API VOCDataset.Please check file path or dataset API " - "validation first."); - } - (*num) = num_samples_; - return Status::OK(); -} - Status VOCOp::LoadTensorRow(const std::string &image_id, TensorRow *trow) { if (task_type_ == TaskType::Segmentation) { std::shared_ptr image, target; @@ -280,7 +266,6 @@ Status VOCOp::ParseImageIds() { in_file.close(); image_ids_.shrink_to_fit(); num_rows_ = image_ids_.size(); - num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_; return Status::OK(); } @@ -305,7 +290,6 @@ Status VOCOp::ParseAnnotationIds() { } num_rows_ = image_ids_.size(); - num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_; return Status::OK(); } @@ -432,19 +416,8 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescripto return Status::OK(); } -// Derived from RandomAccessOp -Status VOCOp::GetNumRowsInDataset(int64_t *num) const { - if (num == nullptr || num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API VOCDataset.Please check file path or dataset API " - "validation first."); - } - (*num) = num_rows_; - return Status::OK(); -} - Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t numSamples, int64_t *count) { + const py::dict &dict, int64_t *count) { if (task_type == "Detection") { std::map input_class_indexing; for (auto p : dict) { @@ -464,14 +437,12 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ RETURN_IF_NOT_OK(op->ParseImageIds()); *count = static_cast(op->image_ids_.size()); } - *count = (numSamples == 0 || *count < numSamples) ? 
*count : numSamples; return Status::OK(); } Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t numSamples, - std::map *output_class_indexing) { + const py::dict &dict, std::map *output_class_indexing) { std::map input_class_indexing; for (auto p : dict) { (void)input_class_indexing.insert(std::pair(py::reinterpret_borrow(p.first), diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h index 203ec05fabb..8a823614619 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h @@ -116,14 +116,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp { return *this; } - // Setter method. - // @param int64_t num_samples - // @return Builder setter method returns reference to the builder. - Builder &SetNumSamples(int64_t num_samples) { - builder_num_samples_ = num_samples; - return *this; - } - // Setter method. // @param std::shared_ptr sampler // @return Builder setter method returns reference to the builder. @@ -157,7 +149,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp { int32_t builder_num_workers_; int32_t builder_op_connector_size_; int32_t builder_rows_per_buffer_; - int64_t builder_num_samples_; std::shared_ptr builder_sampler_; std::unique_ptr builder_schema_; std::map builder_labels_to_read_; @@ -171,14 +162,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp { // @param int32_t num_workers - number of workers reading images in parallel // @param int32_t rows_per_buffer - number of images (rows) in each buffer // @param int32_t queue_size - connector queue size - // @param int64_t num_samples - number of samples to read // @param bool decode - whether to decode images // @param std::unique_ptr data_schema - the schema of the VOC dataset // @param std::shared_ptr sampler - sampler tells VOCOp what to read VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, - int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr data_schema, - std::shared_ptr sampler); + int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler); // Destructor ~VOCOp() = default; @@ -194,15 +183,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp { // @return Status - The error code return Status operator()() override; - // Method derived from RandomAccessOp, enable Sampler to get numRows - // @param uint64_t num - to return numRows - // return Status - The error code return - Status GetNumSamples(int64_t *num) const override; - - // Method derived from RandomAccessOp, enable Sampler to get total number of rows in dataset - // @param uint64_t num - to return numRows - Status GetNumRowsInDataset(int64_t *num) const override; - // A print method typically used for debugging // @param out // @param show_all @@ -212,10 +192,9 @@ class VOCOp : public ParallelOp, public RandomAccessOp { // @param const std::string &task_type - task type of reading voc job // @param const std::string &task_mode - task mode of reading voc job // @param const py::dict &dict - input dict of class index - // @param int64_t numSamples - samples number of VOCDataset // @param int64_t *count - output rows number of VOCDataset static Status CountTotalRows(const std::string &dir, const std::string &task_type, const 
std::string &task_mode, - const py::dict &dict, int64_t numSamples, int64_t *count); + const py::dict &dict, int64_t *count); // @param const std::string &dir - VOC dir path // @param const std::string &task_type - task type of reading voc job // @param const std::string &task_mode - task mode of reading voc job // @param const py::dict &dict - input dict of class index - // @param int64_t numSamples - samples number of VOCDataset // @param std::map *output_class_indexing - output class index of VOCDataset static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t numSamples, - std::map *output_class_indexing); + const py::dict &dict, std::map *output_class_indexing); private: // Initialize Sampler, calls sampler->Init() within @@ -283,8 +261,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp { bool decode_; int64_t row_cnt_; int64_t buf_cnt_; - int64_t num_rows_; - int64_t num_samples_; std::string folder_path_; TaskType task_type_; std::string task_mode_; diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index ceca1881120..0631ade36aa 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \ Schema, Shuffle, zip, RandomDataset from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ - WeightedRandomSampler, SubsetSampler, Sampler + WeightedRandomSampler, Sampler from .engine.serializer_deserializer import serialize, deserialize, show from .engine.graphdata import GraphData diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 8d05aebbd1e..0afb6ce6b0a 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1261,8 +1261,8 @@ class MappableDataset(SourceDataset): def _get_sampler_dataset_size(self): if self.sampler is not None: - if hasattr(self.sampler, 'get_dataset_size'): - return self.sampler.get_dataset_size() + if hasattr(self.sampler, 'get_num_samples'): + return self.sampler.get_num_samples() if hasattr(self.sampler, '__len__'): return len(self.sampler) @@ -1355,7 +1355,7 @@ class MappableDataset(SourceDataset): random_sampler.reshuffle_each_epoch = False ds.add_sampler(random_sampler) - subset_sampler = samplers.SubsetSampler(current_split_start_index, size) + subset_sampler = samplers.SequentialSampler(current_split_start_index, size) ds.add_sampler(subset_sampler) # add sequential sampler, so that if user calls use_sampler, we will @@ -2226,31 +2226,45 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): num_shards (int): Number of shard for sharding. shard_id (int): Shard ID. """ + if input_sampler is not None: + # If the user provided a sampler, then it doesn't matter what the other args are because + # we are being asked specifically to use the given sampler. + # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all + # be None. Consider this example: + # sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle) + # data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1) + # In this case, the user has given different sample-related arguments that contradict each other.
+ # To prevent this, only allow the user to manually specify the sampler if those arguments are all None. + if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler, + samplers.RandomSampler, samplers.SubsetRandomSampler, + samplers.WeightedRandomSampler, samplers.Sampler)) and + (num_shards is not None or shard_id is not None or shuffle is not None or num_samples is not None)): + raise ValueError( + 'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},' + ' shard_id: {}, shuffle: {}'.format(num_samples, num_shards, shard_id, shuffle)) + return input_sampler if shuffle is None: - if input_sampler is not None: - # If shuffle is not specified, user provided sampler, use user's sampler - return input_sampler if num_shards is not None: # If shuffle is not specified, sharding enabled, use distributed random sampler shuffle = True - return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle) + return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples) # If shuffle is not specified, sharding disabled, use random sampler if num_samples is not None: return samplers.RandomSampler(replacement=True, num_samples=num_samples) - return samplers.RandomSampler() + return samplers.RandomSampler(num_samples=num_samples) if shuffle is True: if num_shards is not None: # If shuffle enabled, sharding enabled, use distributed random sampler - return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle) + return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples) # If shuffle enabled, sharding disabled, use random sampler if num_samples is not None: return samplers.RandomSampler(replacement=True, num_samples=num_samples) - return samplers.RandomSampler() + return samplers.RandomSampler(num_samples=num_samples) if num_shards is not None: # If shuffle disabled, sharding enabled, use distributed sequential sampler - return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle) + return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples) # If shuffle disabled, sharding disabled, use sequential sampler - return samplers.SequentialSampler() + return samplers.SequentialSampler(num_samples=num_samples) class ImageFolderDatasetV2(MappableDataset): @@ -2370,11 +2384,7 @@ class ImageFolderDatasetV2(MappableDataset): Return: Number, number of batches. """ - if self.num_samples is None: - num_samples = 0 - else: - num_samples = self.num_samples - num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0] + num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0] rows_per_shard = get_num_rows(num_rows, self.num_shards) rows_from_sampler = self._get_sampler_dataset_size() @@ -2390,11 +2400,7 @@ class ImageFolderDatasetV2(MappableDataset): Return: Number, number of classes. """ - if self.num_samples is None: - num_samples = 0 - else: - num_samples = self.num_samples - return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1] + return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[1] def is_shuffled(self): if self.shuffle_level is None: @@ -2503,12 +2509,7 @@ class MnistDataset(MappableDataset): Return: Number, number of batches. 
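To make the new contract concrete: a minimal usage sketch of the rule enforced above, assuming the patched mindspore.dataset API (the VOC directory path is a placeholder, not part of this change):

import mindspore.dataset as ds

voc_dir = "path/to/VOC2012"  # placeholder path

# An explicit sampler combined with num_shards/shard_id now raises a ValueError.
sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=True)
try:
    data = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
except ValueError as err:
    print(err)  # Conflicting arguments during sampler assignments. ...

# The supported style after this change: carry every sampling argument on the sampler itself.
sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=True, num_samples=100)
data = ds.VOCDataset(voc_dir, decode=True, sampler=sampler)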
""" - if self.num_samples is None: - num_samples = 0 - else: - num_samples = self.num_samples - - num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples) + num_rows = MnistOp.get_num_rows(self.dataset_dir) rows_per_shard = get_num_rows(num_rows, self.num_shards) rows_from_sampler = self._get_sampler_dataset_size() @@ -2956,11 +2957,8 @@ class GeneratorDataset(MappableDataset): if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SubsetRandomSampler, samplers.WeightedRandomSampler, samplers.Sampler)): - if num_samples is None: - num_samples = len(source) sampler_instance = self.sampler.create() sampler_instance.set_num_rows(len(source)) - sampler_instance.set_num_samples(num_samples) sampler_instance.initialize() if num_parallel_workers > 1: self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers)) @@ -3304,17 +3302,12 @@ class ManifestDataset(MappableDataset): Return: Number, number of batches. """ - if self.num_samples is None: - num_samples = 0 - else: - num_samples = self.num_samples - if self.class_indexing is None: class_indexing = dict() else: class_indexing = self.class_indexing - num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0] + num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0] rows_per_shard = get_num_rows(num_rows, self.num_shards) rows_from_sampler = self._get_sampler_dataset_size() @@ -3330,17 +3323,12 @@ class ManifestDataset(MappableDataset): Return: Number, number of classes. """ - if self.num_samples is None: - num_samples = 0 - else: - num_samples = self.num_samples - if self.class_indexing is None: class_indexing = dict() else: class_indexing = self.class_indexing - return ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[1] + return ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[1] def get_class_indexing(self): """ @@ -3349,17 +3337,12 @@ class ManifestDataset(MappableDataset): Return: Dict, A str-to-int mapping from label name to index. """ - if self.num_samples is None: - num_samples = 0 - else: - num_samples = self.num_samples - if self.class_indexing is None: class_indexing = dict() else: class_indexing = self.class_indexing - return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage) + return ManifestOp.get_class_indexing(self.dataset_file, class_indexing, self.usage) def is_shuffled(self): if self.shuffle_level is None: @@ -3473,12 +3456,8 @@ class Cifar10Dataset(MappableDataset): Return: Number, number of batches. """ - if self.num_samples is None: - num_samples = 0 - else: - num_samples = self.num_samples - num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True) + num_rows = CifarOp.get_num_rows(self.dataset_dir, True) rows_per_shard = get_num_rows(num_rows, self.num_shards) rows_from_sampler = self._get_sampler_dataset_size() @@ -3597,12 +3576,8 @@ class Cifar100Dataset(MappableDataset): Return: Number, number of batches. 
""" - if self.num_samples is None: - num_samples = 0 - else: - num_samples = self.num_samples - num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False) + num_rows = CifarOp.get_num_rows(self.dataset_dir, False) rows_per_shard = get_num_rows(num_rows, self.num_shards) rows_from_sampler = self._get_sampler_dataset_size() @@ -3631,7 +3606,7 @@ class RandomDataset(SourceDataset): Args: num_samples (int): number of samples to generate. schema (str or Schema, optional): Path to the json schema file or schema object (default=None). - If the schema is not provided, the meta data from the TFRecord file is considered the schema. + If the schema is not provided, the random dataset generates a random schema. columns_list (list[str], optional): List of columns to be read (default=None, read all columns) num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). @@ -3644,9 +3619,12 @@ class RandomDataset(SourceDataset): schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it self.schema = schema self.columns_list = columns_list - self.num_samples = num_samples if schema_obj is not None and num_samples is None: self.num_samples = schema_obj.num_rows + elif num_samples is None: + self.num_samples = 0 + else: + self.num_samples = num_samples def get_args(self): args = super().get_args() @@ -4015,17 +3993,12 @@ class VOCDataset(MappableDataset): if self.task != "Detection": raise NotImplementedError() - if self.num_samples is None: - num_samples = 0 - else: - num_samples = self.num_samples - if self.class_indexing is None: class_indexing = dict() else: class_indexing = self.class_indexing - return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples) + return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing) def is_shuffled(self): if self.shuffle_level is None: @@ -4205,9 +4178,11 @@ class TextFileDataset(SourceDataset): if self._dataset_size is None: num_rows = TextFileOp.get_num_rows(self.dataset_files) num_rows = get_num_rows(num_rows, self.num_shards) - if self.num_samples is None: - return num_rows - return min(self.num_samples, num_rows) + # If the user gave a num samples in the dataset, then the sampler will limit the rows returned + # to that amount. Account for that here in the row count + if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples: + num_rows = self.num_samples + return num_rows return self._dataset_size def is_shuffled(self): diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 8951a1c4a08..3ae917dd41c 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -22,7 +22,6 @@ User can also define custom sampler by extending from Sampler class. import numpy as np import mindspore._c_dataengine as cde - class Sampler: """ Base class for user defined sampler. @@ -44,10 +43,10 @@ class Sampler: >>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler()) """ - def __init__(self): + def __init__(self, num_samples=None): self.dataset_size = 0 - self.num_samples = 0 self.child_sampler = None + self.num_samples = num_samples def __iter__(self): """ @@ -84,7 +83,8 @@ class Sampler: # Instance fetcher # Do not override this method! 
def create(self): - c_sampler = cde.PythonSampler(self) + num_samples = self.num_samples if self.num_samples is not None else 0 + c_sampler = cde.PythonSampler(num_samples, self) c_child_sampler = self.create_child() c_sampler.add_child(c_child_sampler) return c_sampler @@ -114,7 +114,7 @@ return self.child_sampler.is_sharded() - def get_dataset_size(self): + def get_num_samples(self): return self._get_indices().size @@ -124,8 +124,9 @@ class BuiltinSampler: User should not extend this class. """ - def __init__(self): + def __init__(self, num_samples=None): self.child_sampler = None + self.num_samples = num_samples def create(self): pass @@ -149,11 +150,37 @@ def is_sharded(self): raise NotImplementedError("Sampler must implement is_sharded.") - def get_dataset_size(self): - if self.child_sampler is not None: - return self.child_sampler.get_dataset_size() + def get_num_samples(self): + """ + All samplers can contain a numeric num_samples value, or it can be set to None. + A child sampler can exist or be None. + If a child sampler exists, its sample count can likewise be a numeric value or None. + Given these conditions, we need to output what the sampler count is for this sampler. + The following table shows the possible results from calling this function. - return None + child sampler num_samples child_samples result + ------------- ----------- ------------- -------- + T x y min(x, y) + T x None x + T None y y + T None None None + None x n/a x + None None n/a None + + Returns: + int, the number of samples, or None + """ + if self.child_sampler is not None: + child_samples = self.child_sampler.get_num_samples() + if self.num_samples is not None: + if child_samples is not None: + return min(self.num_samples, child_samples) + + return self.num_samples + + return child_samples + + return self.num_samples class DistributedSampler(BuiltinSampler): @@ -164,6 +191,7 @@ class DistributedSampler(BuiltinSampler): num_shards (int): Number of shards to divide the dataset into. shard_id (int): Shard ID of the current shard within num_shards. shuffle (bool, optional): If true, the indices are shuffled (default=True). + num_samples (int, optional): The number of samples to draw (default=None, all elements). Examples: >>> import mindspore.dataset as ds >>> @@ -180,7 +208,7 @@ ValueError: If shuffle is not a boolean value. """ - def __init__(self, num_shards, shard_id, shuffle=True): + def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None): if num_shards <= 0: raise ValueError("num_shards should be a positive integer value, but got num_shards={}".format(num_shards)) @@ -194,12 +222,13 @@ self.shard_id = shard_id self.shuffle = shuffle self.seed = 0 - super().__init__() + super().__init__(num_samples) def create(self): + num_samples = self.num_samples if self.num_samples is not None else 0 # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle self.seed += 1 - c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed) + c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id, self.shuffle, self.seed) c_child_sampler = self.create_child() c_sampler.add_child(c_child_sampler) return c_sampler @@ -226,6 +255,7 @@ class PKSampler(BuiltinSampler): num_class (int, optional): Number of classes to sample (default=None, all classes). 
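A short sketch of the resolution table above, assuming the patched samplers.py and the existing add_child chaining API used elsewhere in this diff:

import mindspore.dataset as ds

# Parent count x = 8, child count y = 5 -> min(x, y).
parent = ds.RandomSampler(num_samples=8)
parent.add_child(ds.SequentialSampler(num_samples=5))
assert parent.get_num_samples() == 5

# Parent has no count (None) -> the child count is reported.
parent = ds.RandomSampler()
parent.add_child(ds.SequentialSampler(num_samples=5))
assert parent.get_num_samples() == 5

# No count anywhere -> None.
assert ds.RandomSampler().get_num_samples() is None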
shuffle (bool, optional): If true, the class IDs are shuffled (default=False). class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset. + num_samples (int, optional): The number of samples to draw (default=None, all elements). Examples: >>> import mindspore.dataset as ds >>> @@ -242,7 +272,7 @@ ValueError: If shuffle is not boolean. """ - def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'): + def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None): if num_val <= 0: raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val)) @@ -255,10 +285,11 @@ self.num_val = num_val self.shuffle = shuffle self.class_column = class_column # work for minddataset - super().__init__() + super().__init__(num_samples) def create(self): - c_sampler = cde.PKSampler(self.num_val, self.shuffle) + num_samples = self.num_samples if self.num_samples is not None else 0 + c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle) c_child_sampler = self.create_child() c_sampler.add_child(c_child_sampler) return c_sampler @@ -309,23 +340,18 @@ class RandomSampler(BuiltinSampler): raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement)) if num_samples is not None: - if num_samples <= 0: - raise ValueError("num_samples should be a positive integer " + if num_samples < 0: + raise ValueError("num_samples should be a non-negative integer " "value, but got num_samples={}".format(num_samples)) self.deterministic = False self.replacement = replacement - self.num_samples = num_samples self.reshuffle_each_epoch = True - super().__init__() + super().__init__(num_samples) def create(self): - c_sampler = None - if self.num_samples is None: - c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch) - else: - c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples) - + num_samples = self.num_samples if self.num_samples is not None else 0 + c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch) c_child_sampler = self.create_child() c_sampler.add_child(c_child_sampler) return c_sampler @@ -339,14 +365,15 @@ return self.child_sampler.is_sharded() - def get_dataset_size(self): - return self.num_samples - class SequentialSampler(BuiltinSampler): """ Samples the dataset elements sequentially, same as not having a sampler. + Args: + start_index (int, optional): Index to start sampling at. (default=None, starts at the first index) + num_samples (int, optional): Number of elements to sample (default=None, all elements). + Examples: >>> import mindspore.dataset as ds >>> @@ -357,66 +384,14 @@ >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler) """ - def create(self): - c_sampler = cde.SequentialSampler() - c_child_sampler = self.create_child() - c_sampler.add_child(c_child_sampler) - return c_sampler - - def is_shuffled(self): - if self.child_sampler is None: - return False - - return self.child_sampler.is_shuffled() - - def is_sharded(self): - if self.child_sampler is None: - return False - - return self.child_sampler.is_sharded() - - -class SubsetSampler(BuiltinSampler): - """ - Samples a subset of elements consecutively from a given index. - - Args: - start_index (int): Index to start sampling at. 
- subset_size (int): How many samples to include in this subset. - - Examples: - >>> import mindspore.dataset as ds - >>> - >>> dataset_dir = "path/to/imagefolder_directory" - >>> - >>> # creates a SubsetSampler, will sample the next 5 images from the 100th image. - >>> sampler = ds.SubsetSampler(100, 5) - >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler) - - Raises: - ValueError: If start_index is not a positive int. - ValueError: If subset_size is not a positive int. - - def __init__(self, start_index, subset_size): - if not isinstance(start_index, int): - raise ValueError("start_index should be an int.") - - if start_index < 0: - raise ValueError("start_index should not be negative.") - - if not isinstance(subset_size, int): - raise ValueError("start_index should be an int") - - if subset_size < 0: - raise ValueError("subset_size should not be negative.") - + def __init__(self, start_index=None, num_samples=None): self.start_index = start_index - self.subset_size = subset_size - super().__init__() + super().__init__(num_samples) def create(self): - c_sampler = cde.SubsetSampler(self.start_index, self.subset_size) + start_index = self.start_index if self.start_index is not None else 0 + num_samples = self.num_samples if self.num_samples is not None else 0 + c_sampler = cde.SequentialSampler(num_samples, start_index) c_child_sampler = self.create_child() c_sampler.add_child(c_child_sampler) return c_sampler @@ -433,9 +408,6 @@ class SubsetSampler(BuiltinSampler): return self.child_sampler.is_sharded() - def get_dataset_size(self): - return self.subset_size - class SubsetRandomSampler(BuiltinSampler): """ @@ -443,6 +415,7 @@ class SubsetRandomSampler(BuiltinSampler): Args: indices (list[int]): A sequence of indices. + num_samples (int, optional): Number of elements to sample (default=None, all elements). Examples: >>> import mindspore.dataset as ds @@ -456,15 +429,16 @@ >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler) """ - def __init__(self, indices): + def __init__(self, indices, num_samples=None): if not isinstance(indices, list): indices = [indices] self.indices = indices - super().__init__() + super().__init__(num_samples) def create(self): - c_sampler = cde.SubsetRandomSampler(self.indices) + num_samples = self.num_samples if self.num_samples is not None else 0 + c_sampler = cde.SubsetRandomSampler(num_samples, self.indices) c_child_sampler = self.create_child() c_sampler.add_child(c_child_sampler) return c_sampler @@ -481,9 +455,11 @@ def _create_for_minddataset(self): return cde.MindrecordSubsetRandomSampler(self.indices) - - def get_dataset_size(self): - return len(self.indices) + def get_num_samples(self): + num_samples = super().get_num_samples() + if num_samples is None: + return len(self.indices) + return min(len(self.indices), num_samples) class WeightedRandomSampler(BuiltinSampler): """ @@ -492,7 +466,7 @@ Args: weights (list[float]): A sequence of weights, not necessarily summing up to 1. - num_samples (int): Number of elements to sample. + num_samples (int, optional): Number of elements to sample (default=None, all elements). replacement (bool, optional): If True, put the sample ID back for the next draw (default=True). Examples: >>> import mindspore.dataset as ds >>> @@ -511,24 +485,25 @@ ValueError: If replacement is not boolean. 
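With SubsetSampler deleted, a migration sketch for old call sites, assuming the patched SequentialSampler signature (the dataset directory is a placeholder):

import mindspore.dataset as ds

# Before this change: sampler = ds.SubsetSampler(100, 5)
# After: a SequentialSampler with start_index and num_samples covers the same case,
# sampling the 5 images starting at index 100.
sampler = ds.SequentialSampler(start_index=100, num_samples=5)
data = ds.ImageFolderDatasetV2("path/to/imagefolder_directory",
                               num_parallel_workers=8, sampler=sampler)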
""" - def __init__(self, weights, num_samples, replacement=True): + def __init__(self, weights, num_samples=None, replacement=True): if not isinstance(weights, list): weights = [weights] - if num_samples <= 0: - raise ValueError("num_samples should be a positive integer " - "value, but got num_samples={}".format(num_samples)) + if num_samples is not None: + if num_samples < 0: + raise ValueError("num_samples should be a positive integer " + "value, but got num_samples={}".format(num_samples)) if not isinstance(replacement, bool): raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement)) self.weights = weights - self.num_samples = num_samples self.replacement = replacement - super().__init__() + super().__init__(num_samples) def create(self): - c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement) + num_samples = self.num_samples if self.num_samples is not None else 0 + c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement) c_child_sampler = self.create_child() c_sampler.add_child(c_child_sampler) return c_sampler @@ -541,6 +516,3 @@ class WeightedRandomSampler(BuiltinSampler): return False return self.child_sampler.is_sharded() - - def get_dataset_size(self): - return self.num_samples diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index 688ef167537..6c30a51516e 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -161,6 +161,20 @@ def traverse(node): else: node_repr[k] = v + # If a sampler exists in this node, then the following 4 arguments must be set to None: + # num_samples, shard_id, num_shards, shuffle + # These arguments get moved into the sampler itself, so they are no longer needed to + # be set at the dataset level. + if 'sampler' in node_args.keys(): + if 'num_samples' in node_repr.keys(): + node_repr['num_samples'] = None + if 'shuffle' in node_repr.keys(): + node_repr['shuffle'] = None + if 'num_shards' in node_repr.keys(): + node_repr['num_shards'] = None + if 'shard_id' in node_repr.keys(): + node_repr['shard_id'] = None + # Leaf node doesn't have input attribute. 
if not node.input: return node_repr diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 049931c80e6..4893aace361 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -283,8 +283,8 @@ def check_num_parallel_workers(value): def check_num_samples(value): check_type(value, 'num_samples', int) - if value <= 0: - raise ValueError("num_samples must be greater than 0!") + if value < 0: + raise ValueError("num_samples cannot be less than 0!") def check_dataset_dir(dataset_dir): diff --git a/tests/ut/cpp/dataset/celeba_op_test.cc b/tests/ut/cpp/dataset/celeba_op_test.cc index 35be4d73787..5fa50a85ff1 100644 --- a/tests/ut/cpp/dataset/celeba_op_test.cc +++ b/tests/ut/cpp/dataset/celeba_op_test.cc @@ -39,14 +39,13 @@ std::shared_ptr Repeat(int repeat_cnt); std::shared_ptr Build(std::vector> ops); std::shared_ptr Celeba(int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, - const std::string &dir, int64_t num_samples = 0, - std::unique_ptr sampler = nullptr, bool decode = false, - const std::string &dataset_type="all") { + const std::string &dir, std::shared_ptr sampler = nullptr, + bool decode = false, const std::string &dataset_type="all") { std::shared_ptr so; CelebAOp::Builder builder; Status rc = builder.SetNumWorkers(num_workers).SetCelebADir(dir).SetRowsPerBuffer(rows_per_buffer) .SetOpConnectorSize(queue_size).SetSampler(std::move(sampler)).SetDecode(decode) - .SetNumSamples(num_samples).SetDatasetType(dataset_type).Build(&so); + .SetDatasetType(dataset_type).Build(&so); return so; } @@ -116,11 +115,12 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) { TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) { std::vector indices({1}); - std::unique_ptr sampler = std::make_unique(indices); + int64_t num_samples = 0; + std::shared_ptr sampler = std::make_shared(num_samples, indices); uint32_t expect_labels[1][40] = {{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}}; std::string dir = datasets_root_path_ + "/testCelebAData/"; uint32_t count = 0; - auto tree = Build({Celeba(16, 2, 32, dir, 0, std::move(sampler))}); + auto tree = Build({Celeba(16, 2, 32, dir, std::move(sampler))}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { @@ -143,25 +143,3 @@ TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) { EXPECT_TRUE(count == 1); } } - -TEST_F(MindDataTestCelebaDataset, TestCelebaNumSamples) { - std::string dir = datasets_root_path_ + "/testCelebAData/"; - uint32_t count = 0; - auto tree = Build({Celeba(16, 2, 32, dir, 1)}); - tree->Prepare(); - Status rc = tree->Launch(); - if (rc.IsError()) { - MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << "."; - EXPECT_TRUE(false); - } else { - DatasetIterator di(tree); - TensorMap tersor_map; - di.GetNextAsMap(&tersor_map); - EXPECT_TRUE(rc.IsOk()); - while (tersor_map.size() != 0) { - count++; - di.GetNextAsMap(&tersor_map); - } - EXPECT_TRUE(count == 1); - } -} diff --git a/tests/ut/cpp/dataset/cifar_op_test.cc b/tests/ut/cpp/dataset/cifar_op_test.cc index 8eeeba76afe..2992bc91a8a 100644 --- a/tests/ut/cpp/dataset/cifar_op_test.cc +++ b/tests/ut/cpp/dataset/cifar_op_test.cc @@ -45,13 +45,12 @@ std::shared_ptr Repeat(int repeatCnt); std::shared_ptr Build(std::vector> ops); std::shared_ptr Cifarop(uint64_t num_works, uint64_t rows, uint64_t conns, std::string path, - std::unique_ptr sampler = nullptr, - uint64_t num_samples = 0, bool cifar10 = 
true) { + std::shared_ptr sampler = nullptr, bool cifar10 = true) { std::shared_ptr so; CifarOp::Builder builder; Status rc = builder.SetNumWorkers(num_works).SetCifarDir(path).SetRowsPerBuffer(rows) .SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetCifarType(cifar10) - .SetNumSamples(num_samples).Build(&so); + .Build(&so); return so; } @@ -66,7 +65,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) { //appear in this dataset //Example: python tests/dataset/data/prep_data.py std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100)}); + auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr)}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { @@ -79,7 +78,8 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) { EXPECT_TRUE(rc.IsOk()); uint64_t i = 0; uint32_t label = 0; - while (tensor_map.size() != 0) { + // Note: only iterating first 100 rows then break out. + while (tensor_map.size() != 0 && i < 100) { tensor_map["label"]->GetItemAt(&label, {}); MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; i++; @@ -92,9 +92,9 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) { TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) { uint32_t original_seed = GlobalContext::config_manager()->seed(); GlobalContext::config_manager()->set_seed(0); - std::unique_ptr sampler = std::make_unique(true, true, 12); + std::shared_ptr sampler = std::make_unique(12, true, true); std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)}); + auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler))}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { @@ -118,34 +118,9 @@ TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) { GlobalContext::config_manager()->set_seed(original_seed); } -TEST_F(MindDataTestCifarOp, TestCifar10NumSample) { - std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; - auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100)}); - tree->Prepare(); - Status rc = tree->Launch(); - if (rc.IsError()) { - MS_LOG(ERROR) << "Return code error detected during tree launch: " << common::SafeCStr(rc.ToString()) << "."; - EXPECT_TRUE(false); - } else { - DatasetIterator di(tree); - TensorMap tensor_map; - di.GetNextAsMap(&tensor_map); - EXPECT_TRUE(rc.IsOk()); - uint64_t i = 0; - uint32_t label = 0; - while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); - MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; - i++; - di.GetNextAsMap(&tensor_map); - } - EXPECT_TRUE(i == 100); - } -} - TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar100) { std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; - auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100, false)}); + auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, false)}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { @@ -159,7 +134,8 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar100) { uint64_t i = 0; uint32_t coarse = 0; uint32_t fine = 0; - while (tensor_map.size() != 0) { + // only iterate to 100 then break out of loop + while (tensor_map.size() != 0 && i < 100) { tensor_map["coarse_label"]->GetItemAt(&coarse, {}); tensor_map["fine_label"]->GetItemAt(&fine, {}); MS_LOG(DEBUG) << "row: " << 
i << "\t" << tensor_map["image"]->shape() << " coarse:" diff --git a/tests/ut/cpp/dataset/image_folder_op_test.cc b/tests/ut/cpp/dataset/image_folder_op_test.cc index 380b7cd02b5..cd72d8f18a8 100644 --- a/tests/ut/cpp/dataset/image_folder_op_test.cc +++ b/tests/ut/cpp/dataset/image_folder_op_test.cc @@ -50,9 +50,8 @@ std::shared_ptr Repeat(int repeat_cnt); std::shared_ptr Build(std::vector> ops); std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, - bool shuf = false, std::unique_ptr sampler = nullptr, - std::map map = {}, int64_t num_samples = 0, - bool decode = false) { + bool shuf = false, std::shared_ptr sampler = nullptr, + std::map map = {}, bool decode = false) { std::shared_ptr so; ImageFolderOp::Builder builder; Status rc = builder.SetNumWorkers(num_works) @@ -63,7 +62,6 @@ std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int6 .SetSampler(std::move(sampler)) .SetClassIndex(map) .SetDecode(decode) - .SetNumSamples(num_samples) .Build(&so); return so; } @@ -138,7 +136,8 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) { TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) { int32_t original_seed = GlobalContext::config_manager()->seed(); GlobalContext::config_manager()->set_seed(0); - std::unique_ptr sampler = std::make_unique(true, true, 12); + int64_t num_samples = 12; + std::shared_ptr sampler = std::make_unique(num_samples, true, true); int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1}; // ground truth label std::string folder_path = datasets_root_path_ + "/testPK/data"; auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))}); @@ -200,7 +199,8 @@ TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeatBatch) TEST_F(MindDataTestImageFolderSampler, TestSubsetRandomSamplerImageFolder) { // id range 0 - 10 is label 0, and id range 11 - 21 is label 1 std::vector indices({0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11}); - std::unique_ptr sampler = std::make_unique(indices); + int64_t num_samples = 0; + std::shared_ptr sampler = std::make_shared(num_samples, indices); std::string folder_path = datasets_root_path_ + "/testPK/data"; // Expect 6 samples for label 0 and 1 int res[2] = {6, 6}; @@ -237,8 +237,8 @@ TEST_F(MindDataTestImageFolderSampler, TestWeightedRandomSamplerImageFolder) { std::vector weights(total_samples, std::rand() % 100); // create sampler with replacement = replacement - std::unique_ptr sampler = - std::make_unique(weights, num_samples, true, samples_per_buffer); + std::shared_ptr sampler = + std::make_shared(num_samples, weights, true, samples_per_buffer); std::string folder_path = datasets_root_path_ + "/testPK/data"; auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))}); @@ -295,7 +295,8 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderClassIndex) { } TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) { - std::unique_ptr sampler = std::make_unique(11, 10, false); + int64_t num_samples = 0; + std::shared_ptr sampler = std::make_shared(num_samples, 11, 10, false); std::string folder_path = datasets_root_path_ + "/testPK/data"; auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)), Repeat(4)}); tree->Prepare(); @@ -322,7 +323,8 @@ TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) { } TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) { - std::unique_ptr sampler = std::make_unique(3, false, 4); + int64_t num_samples = 0; + 
std::shared_ptr sampler = std::make_shared(num_samples, 3, false, 4); int32_t res[] = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3}; // ground truth label std::string folder_path = datasets_root_path_ + "/testPK/data"; auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))}); @@ -349,39 +351,16 @@ TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) { } } -TEST_F(MindDataTestImageFolderSampler, TestImageFolderNumSamples) { - std::string folder_path = datasets_root_path_ + "/testPK/data"; - auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, nullptr, {}, 11), Repeat(2)}); - tree->Prepare(); - Status rc = tree->Launch(); - if (rc.IsError()) { - MS_LOG(ERROR) << "Return code error detected during tree launch: " << common::SafeCStr(rc.ToString()) << "."; - EXPECT_TRUE(false); - } else { - DatasetIterator di(tree); - TensorMap tensor_map; - di.GetNextAsMap(&tensor_map); - EXPECT_TRUE(rc.IsOk()); - uint64_t i = 0; - int32_t label = 0; - while (tensor_map.size() != 0) { - tensor_map["label"]->GetItemAt(&label, {}); - EXPECT_TRUE(0 == label); - MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; - i++; - di.GetNextAsMap(&tensor_map); - } - EXPECT_TRUE(i == 22); - } -} - TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) { std::string folder_path = datasets_root_path_ + "/testPK/data"; std::map map; map["class3"] = 333; map["class1"] = 111; map["wrong folder name"] = 1234; // this is skipped - auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, nullptr, map, 20, true)}); + int64_t num_samples = 20; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(seq_sampler), map, true)}); int64_t res[2] = {111, 333}; tree->Prepare(); Status rc = tree->Launch(); @@ -408,33 +387,12 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) { } } -TEST_F(MindDataTestImageFolderSampler, TestImageFolderDatasetSize) { - std::string folder_path = datasets_root_path_ + "/testPK/data"; - int64_t num_rows = 0; - int64_t num_classes = 0; - ImageFolderOp::CountRowsAndClasses(folder_path, 15, {}, &num_rows, &num_classes); - EXPECT_TRUE(num_rows == 15 && num_classes == 4); - ImageFolderOp::CountRowsAndClasses(folder_path, 44, {}, &num_rows, &num_classes); - EXPECT_TRUE(num_rows == 44 && num_classes == 4); - ImageFolderOp::CountRowsAndClasses(folder_path, 0, {}, &num_rows, &num_classes); - EXPECT_TRUE(num_rows == 44 && num_classes == 4); - ImageFolderOp::CountRowsAndClasses(folder_path, 55, {}, &num_rows, &num_classes); - EXPECT_TRUE(num_rows == 44 && num_classes == 4); - ImageFolderOp::CountRowsAndClasses(folder_path, 44, {}, &num_rows, &num_classes, 2, 3); - EXPECT_TRUE(num_rows == 15 && num_classes == 4); - ImageFolderOp::CountRowsAndClasses(folder_path, 33, {}, &num_rows, &num_classes, 0, 3); - EXPECT_TRUE(num_rows == 15 && num_classes == 4); - ImageFolderOp::CountRowsAndClasses(folder_path, 13, {}, &num_rows, &num_classes, 0, 11); - EXPECT_TRUE(num_rows == 4 && num_classes == 4); - ImageFolderOp::CountRowsAndClasses(folder_path, 3, {}, &num_rows, &num_classes, 0, 11); - EXPECT_TRUE(num_rows == 3 && num_classes == 4); -} - TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) { - std::unique_ptr sampler = std::make_unique(4, 0, false); + int64_t num_samples = 5; + std::shared_ptr sampler = std::make_shared(num_samples, 4, 0, false); std::string folder_path = datasets_root_path_ + 
"/testPK/data"; // numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode - auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {}, 5)}); + auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {})}); tree->Prepare(); Status rc = tree->Launch(); int32_t labels[5] = {0, 0, 0, 1, 1}; @@ -460,10 +418,11 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) { } TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding2) { - std::unique_ptr sampler = std::make_unique(4, 3, false); + int64_t num_samples = 12; + std::shared_ptr sampler = std::make_shared(num_samples, 4, 3, false); std::string folder_path = datasets_root_path_ + "/testPK/data"; // numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode - auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {}, 12)}); + auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {})}); tree->Prepare(); Status rc = tree->Launch(); uint32_t labels[11] = {0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3}; diff --git a/tests/ut/cpp/dataset/manifest_op_test.cc b/tests/ut/cpp/dataset/manifest_op_test.cc index f662f98fc8e..35773f6bbbf 100644 --- a/tests/ut/cpp/dataset/manifest_op_test.cc +++ b/tests/ut/cpp/dataset/manifest_op_test.cc @@ -23,6 +23,7 @@ #include "dataset/core/client.h" #include "dataset/core/global_context.h" #include "dataset/engine/datasetops/source/manifest_op.h" +#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" #include "dataset/util/de_error.h" #include "dataset/util/status.h" @@ -42,14 +43,13 @@ std::shared_ptr Repeat(int repeatCnt); std::shared_ptr Build(std::vector> ops); std::shared_ptr Manifest(int32_t num_works, int32_t rows, int32_t conns, const std::string &file, - std::string usage = "train", std::unique_ptr sampler = nullptr, - std::map map = {}, uint64_t num_samples = 0, - bool decode = false) { + std::string usage = "train", std::shared_ptr sampler = nullptr, + std::map map = {}, bool decode = false) { std::shared_ptr so; ManifestOp::Builder builder; Status rc = builder.SetNumWorkers(num_works).SetManifestFile(file).SetRowsPerBuffer( rows).SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetClassIndex(map).SetDecode(decode) - .SetNumSamples(num_samples).SetUsage(usage).Build(&so); + .SetUsage(usage).Build(&so); return so; } @@ -86,7 +86,8 @@ TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) { TEST_F(MindDataTestManifest, TestSubsetRandomSamplerManifest) { std::vector indices({1}); - std::unique_ptr sampler = std::make_unique(indices); + int64_t num_samples = 0; + std::shared_ptr sampler = std::make_shared(num_samples, indices); std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; // Expect 6 samples for label 0 and 1 auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(sampler))}); @@ -145,7 +146,10 @@ TEST_F(MindDataTestManifest, MindDataTestManifestClassIndex) { TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) { std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; - auto tree = Build({Manifest(16, 2, 32, file, "train", nullptr, {}, 1), Repeat(4)}); + int64_t num_samples = 1; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}), Repeat(4)}); tree->Prepare(); Status rc = tree->Launch(); 
if (rc.IsError()) { @@ -171,7 +175,10 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) { TEST_F(MindDataTestManifest, MindDataTestManifestEval) { std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; - auto tree = Build({Manifest(16, 2, 32, file, "eval", nullptr, {}, 1)}); + int64_t num_samples = 1; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + auto tree = Build({Manifest(16, 2, 32, file, "eval", std::move(seq_sampler), {})}); tree->Prepare(); Status rc = tree->Launch(); if (rc.IsError()) { diff --git a/tests/ut/cpp/dataset/map_op_test.cc b/tests/ut/cpp/dataset/map_op_test.cc index 7a990074371..881c711093a 100644 --- a/tests/ut/cpp/dataset/map_op_test.cc +++ b/tests/ut/cpp/dataset/map_op_test.cc @@ -120,9 +120,8 @@ class MindDataTestMapOp : public UT::DatasetOpTesting { }; std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, - bool shuf = false, std::unique_ptr sampler = nullptr, - std::map map = {}, int64_t num_samples = 0, - bool decode = false); + bool shuf = false, std::shared_ptr sampler = nullptr, + std::map map = {}, bool decode = false); std::shared_ptr Build(std::vector> ops); diff --git a/tests/ut/cpp/dataset/mnist_op_test.cc b/tests/ut/cpp/dataset/mnist_op_test.cc index 2733597b358..26b7335ad33 100644 --- a/tests/ut/cpp/dataset/mnist_op_test.cc +++ b/tests/ut/cpp/dataset/mnist_op_test.cc @@ -53,13 +53,11 @@ Status Create1DTensor(std::shared_ptr *sample_ids, int64_t num_elements, DataType::Type data_type = DataType::DE_UINT32); std::shared_ptr CreateMnist(int64_t num_wrks, int64_t rows, int64_t conns, std::string path, - bool shuf = false, std::unique_ptr sampler = nullptr, - int64_t num_samples = 0) { + bool shuf = false, std::shared_ptr sampler = nullptr) { std::shared_ptr so; MnistOp::Builder builder; Status rc = builder.SetNumWorkers(num_wrks).SetDir(path).SetRowsPerBuffer(rows) - .SetOpConnectorSize(conns).SetSampler(std::move(sampler)) - .SetNumSamples(num_samples).Build(&so); + .SetOpConnectorSize(conns).SetSampler(std::move(sampler)).Build(&so); return so; } @@ -74,7 +72,10 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) { // appear in this dataset // Example: python tests/dataset/data/prep_data.py std::string folder_path = datasets_root_path_ + "/testMnistData/"; - auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, nullptr, 10), Repeat(2)}); + int64_t num_samples = 10; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2)}); tree->Prepare(); uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; Status rc = tree->Launch(); @@ -101,7 +102,10 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) { TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) { std::string folder_path = datasets_root_path_ + "/testMnistData/"; - auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, nullptr, 10), Repeat(2), Batch(5)}); + int64_t num_samples = 10; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2), Batch(5)}); tree->Prepare(); uint32_t res[4][5] = { {0, 0, 0, 0, 0 }, {0, 0, 0, 0, 0 }, diff --git a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc index 6ab7d0498f9..39fe56e163a 
100644 --- a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc +++ b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc @@ -43,20 +43,11 @@ class MindDataTestStandAloneSampler : public UT::DatasetOpTesting { protected: class MockStorageOp : public RandomAccessOp { public: - MockStorageOp(int64_t val) : m_val_(val) {} - - Status GetNumSamples(int64_t *ptr) const override { - (*ptr) = m_val_; - return Status::OK(); + MockStorageOp(int64_t val){ + // row count is in base class as protected member + // GetNumRowsInDataset does not need an override, the default from base class is fine. + num_rows_ = val; } - - Status GetNumRowsInDataset(int64_t *ptr) const override { - (*ptr) = m_val_; - return Status::OK(); - } - - private: - int64_t m_val_; }; }; @@ -73,8 +64,9 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { MockStorageOp mock(20); std::unique_ptr db; std::shared_ptr tensor; + int64_t num_samples = 0; for (int i = 0; i < 6; i++) { - std::unique_ptr sampler = std::make_unique(3, i % 3, (i < 3 ? false : true)); + std::shared_ptr sampler = std::make_shared(num_samples, 3, i % 3, (i < 3 ? false : true)); sampler->HandshakeRandomAccessOp(&mock); sampler->GetNextBuffer(&db); db->GetTensor(&tensor, 0, 0); @@ -92,7 +84,9 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) { std::shared_ptr label1, label2; CreateINT64Tensor(&label1, 3, reinterpret_cast(res)); CreateINT64Tensor(&label2, 2, reinterpret_cast(res + 3)); - std::shared_ptr sampler = std::make_shared(3); + int64_t num_samples = 0; + int64_t start_index = 0; + std::shared_ptr sampler = std::make_shared(num_samples, start_index, 3); std::unique_ptr db; std::shared_ptr tensor; sampler->HandshakeRandomAccessOp(&mock); diff --git a/tests/ut/cpp/dataset/subset_random_sampler_test.cc b/tests/ut/cpp/dataset/subset_random_sampler_test.cc index bb8b3439d59..10050dbfb4f 100644 --- a/tests/ut/cpp/dataset/subset_random_sampler_test.cc +++ b/tests/ut/cpp/dataset/subset_random_sampler_test.cc @@ -31,26 +31,17 @@ class MindDataTestSubsetRandomSampler : public UT::Common { public: class DummyRandomAccessOp : public RandomAccessOp { public: - DummyRandomAccessOp(int64_t num_rows) : num_rows_(num_rows) {}; - Status GetNumSamples(int64_t *num) const { - *num = num_rows_; - return Status::OK(); - } - - Status GetNumRowsInDataset(int64_t *num) const { - *num = num_rows_; - return Status::OK(); - } - - private: - int64_t num_rows_; + DummyRandomAccessOp(int64_t num_rows) { + num_rows_ = num_rows; // base class + }; }; }; TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { std::vector in({0, 1, 2, 3, 4}); std::unordered_set in_set(in.begin(), in.end()); - SubsetRandomSampler sampler(in); + int64_t num_samples = 0; + SubsetRandomSampler sampler(num_samples, in); DummyRandomAccessOp dummyRandomAccessOp(5); sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); @@ -77,8 +68,9 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { int64_t total_samples = 100000 - 5; int64_t samples_per_buffer = 10; + int64_t num_samples = 0; std::vector input(total_samples, 1); - SubsetRandomSampler sampler(input, samples_per_buffer); + SubsetRandomSampler sampler(num_samples, input, samples_per_buffer); DummyRandomAccessOp dummyRandomAccessOp(total_samples); sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); @@ -109,7 +101,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { TEST_F(MindDataTestSubsetRandomSampler, TestReset) { std::vector in({0, 1, 2, 3, 4}); 
std::unordered_set in_set(in.begin(), in.end()); - SubsetRandomSampler sampler(in); + int64_t num_samples = 0; + SubsetRandomSampler sampler(num_samples, in); DummyRandomAccessOp dummyRandomAccessOp(5); sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); diff --git a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc index 51a4bc3cb3a..a41dae532f3 100644 --- a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc +++ b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc @@ -35,19 +35,11 @@ class MindDataTestWeightedRandomSampler : public UT::Common { public: class DummyRandomAccessOp : public RandomAccessOp { public: - DummyRandomAccessOp(uint64_t num_rows) : num_rows_(num_rows) {}; - Status GetNumSamples(int64_t *num) const { - *num = num_rows_; - return Status::OK(); + DummyRandomAccessOp(uint64_t num_rows) { + // row count is in base class as protected member + // GetNumRowsInDataset does not need an override, the default from base class is fine. + num_rows_ = num_rows; } - - Status GetNumRowsInDataset(int64_t *num) const { - *num = num_rows_; - return Status::OK(); - } - - private: - uint64_t num_rows_; }; }; @@ -59,7 +51,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { std::vector freq(total_samples, 0); // create sampler with replacement = true - WeightedRandomSampler m_sampler(weights, num_samples, true); + WeightedRandomSampler m_sampler(num_samples, weights, true); DummyRandomAccessOp dummyRandomAccessOp(total_samples); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); @@ -89,7 +81,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { std::vector freq(total_samples, 0); // create sampler with replacement = replacement - WeightedRandomSampler m_sampler(weights, num_samples, false); + WeightedRandomSampler m_sampler(num_samples, weights, false); DummyRandomAccessOp dummyRandomAccessOp(total_samples); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); @@ -125,7 +117,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { std::vector weights(total_samples, std::rand() % 100); // create sampler with replacement = replacement - WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer); + WeightedRandomSampler m_sampler(num_samples, weights, true, samples_per_buffer); DummyRandomAccessOp dummyRandomAccessOp(total_samples); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); @@ -161,7 +153,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { std::vector freq(total_samples, 0); // create sampler with replacement = replacement - WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer); + WeightedRandomSampler m_sampler(num_samples, weights, false, samples_per_buffer); DummyRandomAccessOp dummyRandomAccessOp(total_samples); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); @@ -202,7 +194,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { std::vector freq(total_samples, 0); // create sampler with replacement = true - WeightedRandomSampler m_sampler(weights, num_samples, true); + WeightedRandomSampler m_sampler(num_samples, weights, true); DummyRandomAccessOp dummyRandomAccessOp(total_samples); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); @@ -247,7 +239,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { std::vector freq(total_samples, 0); // create sampler with replacement = true - WeightedRandomSampler 
m_sampler(weights, num_samples, false); + WeightedRandomSampler m_sampler(num_samples, weights, false); DummyRandomAccessOp dummyRandomAccessOp(total_samples); m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); diff --git a/tests/ut/python/dataset/test_datasets_imagefolder.py b/tests/ut/python/dataset/test_datasets_imagefolder.py index a88111ccbe1..8e5679076d4 100644 --- a/tests/ut/python/dataset/test_datasets_imagefolder.py +++ b/tests/ut/python/dataset/test_datasets_imagefolder.py @@ -58,7 +58,7 @@ def test_imagefolder_numsamples(): assert num_iter == 10 random_sampler = ds.RandomSampler(num_samples=3, replacement=True) - data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler) + data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler) num_iter = 0 for item in data1.create_dict_iterator(): @@ -67,7 +67,7 @@ def test_imagefolder_numsamples(): assert num_iter == 3 random_sampler = ds.RandomSampler(num_samples=3, replacement=False) - data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler) + data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler) num_iter = 0 for item in data1.create_dict_iterator(): diff --git a/tests/ut/python/dataset/test_datasets_sharding.py b/tests/ut/python/dataset/test_datasets_sharding.py index 02db3589e6e..94c39fb34c7 100644 --- a/tests/ut/python/dataset/test_datasets_sharding.py +++ b/tests/ut/python/dataset/test_datasets_sharding.py @@ -162,8 +162,8 @@ def test_voc_shardings(print_res=False): voc_dir = "../data/dataset/testVOC2012" def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1): - sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle) - data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_samples=num_samples) + sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples) + data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler) data1 = data1.repeat(repeat_cnt) res = [] for item in data1.create_dict_iterator(): # each data is a dictionary diff --git a/tests/ut/python/dataset/test_exceptions.py b/tests/ut/python/dataset/test_exceptions.py index cb79d456d41..cbfa402bb06 100644 --- a/tests/ut/python/dataset/test_exceptions.py +++ b/tests/ut/python/dataset/test_exceptions.py @@ -35,18 +35,13 @@ def test_exception_01(): def test_exception_02(): """ - Test multiple exceptions with invalid input + Test exceptions with invalid input, and test valid input """ logger.info("test_exception_02") - num_samples = 0 - with pytest.raises(ValueError) as info: - data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) - assert "num_samples must be greater than 0" in str(info.value) - num_samples = -1 with pytest.raises(ValueError) as info: data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) - assert "num_samples must be greater than 0" in str(info.value) + assert "num_samples cannot be less than 0" in str(info.value) num_samples = 1 data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) diff --git a/tests/ut/python/dataset/test_generator.py b/tests/ut/python/dataset/test_generator.py index 30c36cdcb44..926b84a7f44 100644 --- a/tests/ut/python/dataset/test_generator.py +++ b/tests/ut/python/dataset/test_generator.py @@ -544,7 +544,7 @@ def test_distributed_sampler(): def test_num_samples(): 
source = [(np.array([x]),) for x in range(64)] num_samples = 32 - ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples=num_samples) + ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples)) ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples) ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples) @@ -660,4 +660,6 @@ if __name__ == "__main__": test_sequential_sampler() test_distributed_sampler() test_random_sampler() + test_num_samples() + test_num_samples_underflow() test_schema() diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index e5586655791..381b6dafe7a 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -28,8 +28,8 @@ def test_sequential_sampler(print_res=False): map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} def test_config(num_samples, num_repeats=None): - sampler = ds.SequentialSampler() - data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler) + sampler = ds.SequentialSampler(num_samples=num_samples) + data1 = ds.ManifestDataset(manifest_file, sampler=sampler) if num_repeats is not None: data1 = data1.repeat(num_repeats) res = [] @@ -43,6 +43,7 @@ def test_sequential_sampler(print_res=False): assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2] assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2 + assert test_config(num_samples=0, num_repeats=2) == [0, 1, 2, 3, 4] * 2 assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2 @@ -119,8 +120,8 @@ def test_python_sampler(): return iter([i for i in range(self.dataset_size)]) class Sp2(ds.Sampler): - def __init__(self): - super(Sp2, self).__init__() + def __init__(self, num_samples=None): + super(Sp2, self).__init__(num_samples) # at this stage, self.dataset_size and self.num_samples are not yet known self.cnt = 0 @@ -130,8 +131,8 @@ def test_python_sampler(): def reset(self): self.cnt = (self.cnt + 1) % self.dataset_size - def test_config(num_samples, num_repeats, sampler): - data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler) + def test_config(num_repeats, sampler): + data1 = ds.ManifestDataset(manifest_file, sampler=sampler) if num_repeats is not None: data1 = data1.repeat(num_repeats) res = [] @@ -154,8 +155,8 @@ def test_python_sampler(): assert data[0] == (np.array(i),) i = i - 1 - assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] - assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] + assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] + assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] test_generator() sp1 = Sp1().create() @@ -169,9 +170,8 @@ def test_subset_sampler(): manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} - def test_config(num_samples, start_index, subset_size): - _ = num_samples - sampler = ds.SubsetSampler(start_index, subset_size) + def test_config(start_index, num_samples): + sampler = ds.SequentialSampler(start_index, num_samples) d = ds.ManifestDataset(manifest_file, sampler=sampler) res = [] @@ -180,19 +180,15 @@ def test_subset_sampler(): return res - with pytest.raises(RuntimeError) as info: - test_config(5, 0, 0) - assert "subset_size <= 0" in 
str(info.value) - - assert test_config(5, 0, 1) == [0] - assert test_config(5, 0, 2) == [0, 1] - assert test_config(5, 0, 3) == [0, 1, 2] - assert test_config(5, 0, 4) == [0, 1, 2, 3] - assert test_config(5, 0, 5) == [0, 1, 2, 3, 4] - assert test_config(5, 1, 1) == [1] - assert test_config(5, 2, 3) == [2, 3, 4] - assert test_config(5, 3, 2) == [3, 4] - assert test_config(5, 4, 1) == [4] + assert test_config(0, 1) == [0] + assert test_config(0, 2) == [0, 1] + assert test_config(0, 3) == [0, 1, 2] + assert test_config(0, 4) == [0, 1, 2, 3] + assert test_config(0, 5) == [0, 1, 2, 3, 4] + assert test_config(1, 1) == [1] + assert test_config(2, 3) == [2, 3, 4] + assert test_config(3, 2) == [3, 4] + assert test_config(4, 1) == [4] def test_sampler_chain(): @@ -200,11 +196,11 @@ def test_sampler_chain(): map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} def test_config(num_shards, shard_id): - sampler = ds.DistributedSampler(num_shards, shard_id, False) + sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=False, num_samples=5) child_sampler = ds.SequentialSampler() sampler.add_child(child_sampler) - data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler) + data1 = ds.ManifestDataset(manifest_file, sampler=sampler) res = [] for item in data1.create_dict_iterator(): @@ -234,6 +230,11 @@ def test_add_sampler_invalid_input(): data1.use_sampler("sampler") assert "not an instance of a sampler" in str(info.value) + sampler = ds.SequentialSampler() + with pytest.raises(ValueError) as info: + data2 = ds.ManifestDataset(manifest_file, sampler=sampler, num_samples=20) + assert "Conflicting arguments during sampler assignments" in str(info.value) + if __name__ == '__main__': test_sequential_sampler(True)
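Taken together, the tests above pin down the new num_samples semantics. A compact end-to-end sketch, assuming the patched API and the manifest fixture already referenced in this diff:

import mindspore.dataset as ds

manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"  # fixture used by these tests

# num_samples=0 now passes check_num_samples and means "no cap", the same as None.
data = ds.ManifestDataset(manifest_file, sampler=ds.SequentialSampler(num_samples=0))
count = 0
for _ in data.create_dict_iterator():
    count += 1
assert count == 5

# A positive num_samples on the sampler caps the rows, replacing the old dataset-level argument.
data = ds.ManifestDataset(manifest_file, sampler=ds.SequentialSampler(num_samples=3))
count = 0
for _ in data.create_dict_iterator():
    count += 1
assert count == 3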