From e2ea1fa0dfbc78ee00b6da9954a43230772ed695 Mon Sep 17 00:00:00 2001
From: liyong
Date: Thu, 16 Jul 2020 19:07:29 +0800
Subject: [PATCH] activate num_samples in distributed samplers

---
 .../minddata/dataset/api/python_bindings.cc   |  2 +-
 .../include/shard_distributed_sample.h        |  5 +-
 .../mindrecord/include/shard_sample.h         |  2 +-
 .../meta/shard_distributed_sample.cc          |  9 +--
 .../minddata/mindrecord/meta/shard_sample.cc  | 10 ++-
 mindspore/dataset/engine/samplers.py          |  4 +-
 tests/ut/python/dataset/test_minddataset.py   | 66 +++++++++++++++++++
 .../dataset/test_minddataset_exception.py     | 21 ++++++
 8 files changed, 108 insertions(+), 11 deletions(-)

diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
index 08016ee0613..94c4ec40d70 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
@@ -784,7 +784,7 @@ void bindSamplerOps(py::module *m) {
   (void)py::class_>(*m, "MindrecordDistributedSampler")
-    .def(py::init());
+    .def(py::init());
   (void)py::class_>(
     *m, "MindrecordRandomSampler")
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
index f166ec1e6c6..9244c16f9f5 100644
--- a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
@@ -29,9 +29,10 @@ namespace mindspore {
 namespace mindrecord {
 class ShardDistributedSample : public ShardSample {
  public:
-  ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
+  ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed,
+                         int no_of_samples = 0);
 
-  ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
+  ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, int no_of_samples = 0);
 
   void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
index ce813bc4bf4..c3d695e8e8c 100644
--- a/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
@@ -32,7 +32,7 @@ class ShardSample : public ShardOperator {
 
   ShardSample(int num, int den);
 
-  ShardSample(int num, int den, int par);
+  ShardSample(int num, int den, int par, int no_of_samples = 0);
 
   ShardSample(const std::vector &indices, uint32_t seed);
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
index 4c7abbb4b48..6bc1c1408d4 100644
--- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
@@ -23,16 +23,17 @@ using mindspore::MsLogLevel::ERROR;
 namespace mindspore {
 namespace mindrecord {
 ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle,
-                                               uint32_t seed)
-    : ShardSample(1, num_shards, shard_id),
+                                               uint32_t seed, int no_of_samples)
+    : ShardSample(1, num_shards, shard_id, no_of_samples),
       shuffle_(shuffle),
       no_of_padded_samples_(no_of_padded_samples),
       first_epoch_(true) {
   shuffle_op_ = std::make_shared(seed, kShuffleSample);
 }
 
-ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed)
-    : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {}
+ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed,
+                                               int no_of_samples)
+    : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed, no_of_samples) {}
 
 int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (no_of_padded_samples_ <= 0) {
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
index 808ab55bfbe..b8be83735b7 100644
--- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
@@ -38,11 +38,11 @@ ShardSample::ShardSample(int num, int den)
       indices_({}),
       sampler_type_(kCustomTopPercentSampler) {}
 
-ShardSample::ShardSample(int num, int den, int par)
+ShardSample::ShardSample(int num, int den, int par, int no_of_samples)
     : numerator_(num),
       denominator_(den),
       partition_id_(par),
-      no_of_samples_(0),
+      no_of_samples_(no_of_samples),
       indices_({}),
       sampler_type_(kCustomTopPercentSampler) {}
@@ -110,8 +110,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
         new_tasks.InsertTask(tasks.GetTaskByID(index));  // different mod result between c and python
       }
     } else {
+      int count = 0;
       for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
+        if (no_of_samples_ != 0 && count == no_of_samples_) break;
         new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));  // rounding up. if overflow, go back to start
+        count++;
       }
     }
     std::swap(tasks, new_tasks);
@@ -121,8 +124,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
       return FAILED;
     }
     total_no = static_cast(tasks.permutation_.size());
+    int count = 0;
     for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
+      if (no_of_samples_ != 0 && count == no_of_samples_) break;
       new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
+      count++;
     }
     std::swap(tasks, new_tasks);
   }
diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py
index b74874f9cf3..22c0e44d0d4 100644
--- a/mindspore/dataset/engine/samplers.py
+++ b/mindspore/dataset/engine/samplers.py
@@ -270,7 +270,9 @@ class DistributedSampler(BuiltinSampler):
         return c_sampler
 
     def create_for_minddataset(self):
-        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
+                                                     self.seed, num_samples)
         c_child_sampler = self.create_child_for_minddataset()
         c_sampler.add_child(c_child_sampler)
         return c_sampler
diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py
index 7d613d414f6..8d22bd6c50f 100644
--- a/tests/ut/python/dataset/test_minddataset.py
+++ b/tests/ut/python/dataset/test_minddataset.py
@@ -238,6 +238,72 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
     assert partitions(5) == 2
     assert partitions(9) == 2
 
+def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+
+    def partitions(num_shards):
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id, num_samples=1)
+            num_iter = 0
+            for item in data_set.create_dict_iterator():
+                logger.info("-------------- partition : {} ------------------------".format(partition_id))
+                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
+                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
+                num_iter += 1
+        return num_iter
+
+    assert partitions(4) == 1
+    assert partitions(5) == 1
+    assert partitions(9) == 1
+
+def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+
+    def partitions(num_shards):
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id, num_samples=2)
+            num_iter = 0
+            for item in data_set.create_dict_iterator():
+                logger.info("-------------- partition : {} ------------------------".format(partition_id))
+                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
+                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
+                num_iter += 1
+        return num_iter
+
+    assert partitions(4) == 2
+    assert partitions(5) == 2
+    assert partitions(9) == 2
+
+def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+
+    def partitions(num_shards):
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id, num_samples=3)
+            num_iter = 0
+            for item in data_set.create_dict_iterator():
+                logger.info("-------------- partition : {} ------------------------".format(partition_id))
+                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
+                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
+                num_iter += 1
+        return num_iter
+
+    assert partitions(4) == 3
+    assert partitions(5) == 2
+    assert partitions(9) == 2
+
 def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
     """tutorial for cv minddataset."""
diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py
index 0b4d0dfc8fe..0bfb7a03427 100644
--- a/tests/ut/python/dataset/test_minddataset_exception.py
+++ b/tests/ut/python/dataset/test_minddataset_exception.py
@@ -228,3 +228,24 @@ def test_minddataset_shard_id_bigger_than_num_shard():
 
     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))
+
+def test_cv_minddataset_partition_num_samples_equals_0():
+    """tutorial for cv minddataset."""
+    create_cv_mindrecord(1)
+    columns_list = ["data", "label"]
+    num_readers = 4
+
+    def partitions(num_shards):
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id, num_samples=0)
+            num_iter = 0
+            for _ in data_set.create_dict_iterator():
+                num_iter += 1
+    with pytest.raises(Exception) as error_info:
+        partitions(5)
+    assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info)
+
+    os.remove(CV_FILE_NAME)
+    os.remove("{}.db".format(CV_FILE_NAME))
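
Usage note (not part of the patch): with num_samples now forwarded from DistributedSampler to
MindrecordDistributedSampler, a sharded MindDataset read returns at most num_samples records per
shard, which is what the new unit tests assert. A minimal sketch follows; the file name
"test.mindrecord" and the column list are placeholders for an existing MindRecord dataset.

    import mindspore.dataset as ds

    # Read shard 0 of 4 and keep at most 2 samples from that shard.
    # "test.mindrecord" and the columns below are hypothetical.
    data_set = ds.MindDataset("test.mindrecord", ["data", "label"],
                              num_shards=4, shard_id=0, num_samples=2)

    num_iter = 0
    for _ in data_set.create_dict_iterator():
        num_iter += 1
    print(num_iter)  # no more than 2 once this patch is applied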