diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/mindrecord/include/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/mindrecord/include/bindings.cc index 5c785fcd375..9d3a36f5349 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/mindrecord/include/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/mindrecord/include/bindings.cc @@ -48,12 +48,12 @@ PYBIND_REGISTER( ShardPkSample, 1, ([](const py::module *m) { (void)py::class_>( *m, "MindrecordPkSampler") - .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { + .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle, int64_t num_samples) { if (shuffle == true) { return std::make_shared(kColumn, kVal, std::numeric_limits::max(), - GetSeed()); + GetSeed(), num_samples); } else { - return std::make_shared(kColumn, kVal); + return std::make_shared(kColumn, kVal, num_samples); } })); })); diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h index 04f47db358f..1adb7bbdcd9 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h @@ -29,19 +29,23 @@ namespace mindspore { namespace mindrecord { class ShardPkSample : public ShardCategory { public: - ShardPkSample(const std::string &category_field, int64_t num_elements); + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples); - ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, int64_t num_samples); - ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed, + int64_t num_samples); ~ShardPkSample() override{}; MSRStatus SufExecute(ShardTask &tasks) override; + int64_t GetNumSamples() const { return num_samples_; } + private: bool shuffle_; std::shared_ptr shuffle_op_; + int64_t num_samples_; }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h index e08375b10fe..6f185d5a4e4 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h @@ -49,6 +49,7 @@ #include "minddata/mindrecord/include/shard_error.h" #include "minddata/mindrecord/include/shard_index_generator.h" #include "minddata/mindrecord/include/shard_operator.h" +#include "minddata/mindrecord/include/shard_pk_sample.h" #include "minddata/mindrecord/include/shard_reader.h" #include "minddata/mindrecord/include/shard_sample.h" #include "minddata/mindrecord/include/shard_shuffle.h" diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h index 6074a036da2..a9507f0fb54 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h @@ -53,7 +53,8 @@ class ShardTask { std::tuple, std::vector, json> &GetRandomTask(); - static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements); + static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements, + int64_t num_samples); uint32_t categories; diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc index c42b732463b..3aa40434f06 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -827,6 +827,12 @@ MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths std::string category_field = category_op->GetCategoryField(); auto num_classes = GetNumClasses(category_field); num_samples = category_op->GetNumSamples(num_samples, num_classes); + if (std::dynamic_pointer_cast(op)) { + auto tmp = std::dynamic_pointer_cast(op)->GetNumSamples(); + if (tmp != 0) { + num_samples = std::min(num_samples, tmp); + } + } } else if (std::dynamic_pointer_cast(op)) { if (std::dynamic_pointer_cast(op)) { auto sampler_op = std::dynamic_pointer_cast(op); @@ -958,6 +964,14 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector(op); auto categories = category_op->GetCategories(); int64_t num_elements = category_op->GetNumElements(); + int64_t num_samples = 0; + if (std::dynamic_pointer_cast(op)) { + num_samples = std::dynamic_pointer_cast(op)->GetNumSamples(); + if (num_samples < 0) { + MS_LOG(ERROR) << "Parameter num_samples is not positive or zero"; + return FAILED; + } + } if (num_elements <= 0) { MS_LOG(ERROR) << "Parameter num_element is not positive"; return FAILED; @@ -1006,7 +1020,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vectorGetReplacement(), num_elements); + tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples); if (SUCCESS != (*category_op)(tasks_)) { return FAILED; } diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc index 081a48352de..ed4cf019dc7 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc @@ -22,15 +22,18 @@ using mindspore::MsLogLevel::ERROR; namespace mindspore { namespace mindrecord { -ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements) - : ShardCategory(category_field, num_elements, std::numeric_limits::max(), true), shuffle_(false) {} - -ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories) - : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {} +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples) + : ShardCategory(category_field, num_elements, std::numeric_limits::max(), true), + shuffle_(false), + num_samples_(num_samples) {} ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, - uint32_t seed) - : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) { + int64_t num_samples) + : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false), num_samples_(num_samples) {} + +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, + uint32_t seed, int64_t num_samples) + : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true), num_samples_(num_samples) { shuffle_op_ = std::make_shared(seed, kShuffleSample); // do shuffle and replacement } diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc index 972e3b2d14b..bfacc90ce60 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc @@ -86,7 +86,8 @@ std::tuple, std::vector, json> &ShardTa return task_list_[dis(gen)]; } -ShardTask ShardTask::Combine(std::vector &category_tasks, bool replacement, int64_t num_elements) { +ShardTask ShardTask::Combine(std::vector &category_tasks, bool replacement, int64_t num_elements, + int64_t num_samples) { ShardTask res; if (category_tasks.empty()) return res; auto total_categories = category_tasks.size(); @@ -96,9 +97,12 @@ ShardTask ShardTask::Combine(std::vector &category_tasks, bool replac for (uint32_t i = 1; i < total_categories; i++) { minTasks = std::min(minTasks, category_tasks[i].Size()); } + int64_t count = 0; for (uint32_t task_no = 0; task_no < minTasks; task_no++) { for (uint32_t i = 0; i < total_categories; i++) { + if (num_samples != 0 && count == num_samples) break; res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast(task_no)))); + count++; } } } else { @@ -109,9 +113,12 @@ ShardTask ShardTask::Combine(std::vector &category_tasks, bool replac if (num_elements != std::numeric_limits::max()) { maxTasks = static_cast(num_elements); } + int64_t count = 0; for (uint32_t i = 0; i < total_categories; i++) { for (uint32_t j = 0; j < maxTasks; j++) { + if (num_samples != 0 && count == num_samples) break; res.InsertTask(category_tasks[i].GetRandomTask()); + count++; } } } diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 22c0e44d0d4..1cc7efb914f 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -359,7 +359,8 @@ class PKSampler(BuiltinSampler): if not self.class_column or not isinstance(self.class_column, str): raise ValueError("class_column should be a not empty string value, \ but got class_column={}".format(class_column)) - c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) + num_samples = self.num_samples if self.num_samples is not None else 0 + c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples) c_child_sampler = self.create_child_for_minddataset() c_sampler.add_child(c_child_sampler) return c_sampler diff --git a/mindspore/mindrecord/tools/tfrecord_to_mr.py b/mindspore/mindrecord/tools/tfrecord_to_mr.py index c351f4307f4..7b716979429 100644 --- a/mindspore/mindrecord/tools/tfrecord_to_mr.py +++ b/mindspore/mindrecord/tools/tfrecord_to_mr.py @@ -104,7 +104,7 @@ class TFRecordToMR: Args: source (str): the TFRecord file to be transformed. destination (str): the MindRecord file path to tranform into. - feature_dict (dict): a dictionary than states the feature type, i.e. + feature_dict (dict): a dictionary that states the feature type, i.e. feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \ "yyyy": tf.io.FixedLenFeature([], tf.int64)} diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 2137fb4a13f..7b9186ac37d 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -162,7 +162,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) { auto column_list = std::vector{"file_name", "label"}; std::vector> ops; - ops.push_back(std::make_shared("label", 2)); + ops.push_back(std::make_shared("label", 2, 0)); ShardReader dataset; dataset.Open({file_name},true, 4, column_list, ops); @@ -187,7 +187,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { auto column_list = std::vector{"file_name", "label"}; std::vector> ops; - ops.push_back(std::make_shared("label", 2, 3, 0)); + ops.push_back(std::make_shared("label", 2, 3, 0, 0)); ShardReader dataset; dataset.Open({file_name},true, 4, column_list, ops); @@ -204,7 +204,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { } dataset.Finish(); ASSERT_TRUE(i == 6); -} // namespace mindrecord +} TEST_F(TestShardOperator, TestShardCategory) { MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py index 9c110c0e1f8..b60302c3e37 100644 --- a/tests/ut/python/dataset/test_minddataset_sampler.py +++ b/tests/ut/python/dataset/test_minddataset_sampler.py @@ -101,7 +101,6 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): "-------------- item[label]: {} ----------------------------".format(item["label"])) num_iter += 1 - def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] @@ -120,9 +119,51 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): logger.info( "-------------- item[label]: {} ----------------------------".format(item["label"])) num_iter += 1 + assert num_iter == 9 -def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file): +def test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(3, None, True, 'label', 5) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + + assert data_set.get_dataset_size() == 5 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format(to_str(item["file_name"]))) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + assert num_iter == 5 + +def test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(3, None, True, 'label', 10) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + + assert data_set.get_dataset_size() == 9 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format(to_str(item["file_name"]))) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + assert num_iter == 9 + + +def test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] num_readers = 4 @@ -139,6 +180,45 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file): logger.info( "-------------- item[label]: {} ----------------------------".format(item["label"])) num_iter += 1 + assert num_iter == 15 + +def test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(5, None, True, 'label', 20) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + assert data_set.get_dataset_size() == 15 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format(to_str(item["file_name"]))) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + assert num_iter == 15 + +def test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(5, None, True, 'label', 10) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + assert data_set.get_dataset_size() == 10 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info( + "-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format(to_str(item["file_name"]))) + logger.info( + "-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + assert num_iter == 10 def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):