!3133 [MD] fix num_sample in distributed sampler

Merge pull request !3133 from liyong126/md_num_of_samples
This commit is contained in:
mindspore-ci-bot 2020-07-18 15:36:33 +08:00 committed by Gitee
commit 4945d34a41
8 changed files with 108 additions and 11 deletions

View File

@@ -784,7 +784,7 @@ void bindSamplerOps(py::module *m) {
(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
.def(py::init<int64_t, int64_t, bool, uint32_t>());
.def(py::init<int64_t, int64_t, bool, uint32_t, int64_t>());
(void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
*m, "MindrecordRandomSampler")

View File

@@ -29,9 +29,10 @@ namespace mindspore {
namespace mindrecord {
class ShardDistributedSample : public ShardSample {
public:
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed,
int no_of_samples = 0);
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, int no_of_samples = 0);
void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }

View File

@@ -32,7 +32,7 @@ class ShardSample : public ShardOperator {
ShardSample(int num, int den);
ShardSample(int num, int den, int par);
ShardSample(int num, int den, int par, int no_of_samples = 0);
ShardSample(const std::vector<int64_t> &indices, uint32_t seed);

View File

@@ -23,16 +23,17 @@ using mindspore::MsLogLevel::ERROR;
namespace mindspore {
namespace mindrecord {
// Distributed sampler: split the dataset 1/num_shards ways and serve the
// split belonging to shard_id.
// no_of_samples == 0 means "no explicit per-shard cap" (take the full split);
// this matches the Python side, which passes 0 when num_samples is None.
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle,
                                               uint32_t seed, int no_of_samples)
    : ShardSample(1, num_shards, shard_id, no_of_samples),
      shuffle_(shuffle),
      no_of_padded_samples_(no_of_padded_samples),
      first_epoch_(true) {
  // Shuffle (when enabled) is applied per-sample with a caller-supplied seed.
  shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
}
// Convenience overload: no padded samples (delegates with no_of_padded_samples = 0).
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed,
                                               int no_of_samples)
    : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed, no_of_samples) {}
int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (no_of_padded_samples_ <= 0) {

View File

@@ -38,11 +38,11 @@ ShardSample::ShardSample(int num, int den)
indices_({}),
sampler_type_(kCustomTopPercentSampler) {}
// Take fraction num/den of the dataset, serving partition `par`.
// no_of_samples == 0 means "no explicit cap"; a positive value limits how many
// tasks Execute() emits for this partition.
ShardSample::ShardSample(int num, int den, int par, int no_of_samples)
    : numerator_(num),
      denominator_(den),
      partition_id_(par),
      no_of_samples_(no_of_samples),
      indices_({}),
      sampler_type_(kCustomTopPercentSampler) {}
@@ -110,8 +110,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python
}
} else {
int count = 0;
for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
if (no_of_samples_ != 0 && count == no_of_samples_) break;
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start
count++;
}
}
std::swap(tasks, new_tasks);
@@ -121,8 +124,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
return FAILED;
}
total_no = static_cast<int>(tasks.permutation_.size());
int count = 0;
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
if (no_of_samples_ != 0 && count == no_of_samples_) break;
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
count++;
}
std::swap(tasks, new_tasks);
}

View File

@@ -270,7 +270,9 @@ class DistributedSampler(BuiltinSampler):
return c_sampler
def create_for_minddataset(self):
    """Build the C++ MindrecordDistributedSampler backing this sampler.

    ``num_samples=None`` is mapped to 0, which the C++ sampler treats as
    "no per-shard sample cap" (see ShardSample::Execute in the same commit).

    Returns:
        The constructed C++ sampler with any child sampler attached.
    """
    num_samples = self.num_samples if self.num_samples is not None else 0
    c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
                                                 self.seed, num_samples)
    c_child_sampler = self.create_child_for_minddataset()
    c_sampler.add_child(c_child_sampler)
    return c_sampler

View File

@@ -238,6 +238,72 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
assert partitions(5) == 2
assert partitions(9) == 2
def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file):
    """Sharded MindDataset with num_samples=1: every shard yields exactly 1 row."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        # Iterate every shard; returns the row count of the last shard read.
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=1)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
        return num_iter

    assert partitions(4) == 1
    assert partitions(5) == 1
    assert partitions(9) == 1
def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file):
    """Sharded MindDataset with num_samples=2: every shard yields exactly 2 rows."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        # Iterate every shard; returns the row count of the last shard read.
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=2)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
        return num_iter

    assert partitions(4) == 2
    assert partitions(5) == 2
    assert partitions(9) == 2
def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
    """Sharded MindDataset with num_samples=3: shard size caps the request.

    With 4 shards each shard can supply 3 rows; with 5 or 9 shards the shard
    split is smaller than num_samples, so only 2 rows come back.
    """
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        # Iterate every shard; returns the row count of the last shard read.
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=3)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
        return num_iter

    assert partitions(4) == 3
    assert partitions(5) == 2
    assert partitions(9) == 2
def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
"""tutorial for cv minddataset."""

View File

@@ -228,3 +228,24 @@ def test_minddataset_shard_id_bigger_than_num_shard():
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
def test_cv_minddataset_partition_num_samples_equals_0():
    """num_samples=0 must be rejected by MindDataset argument validation."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4

    def partitions(num_shards):
        # Constructing the dataset with num_samples=0 is expected to raise.
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=0)
            num_iter = 0
            for _ in data_set.create_dict_iterator():
                num_iter += 1

    with pytest.raises(Exception) as error_info:
        partitions(5)
    assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info)
    # Clean up the mindrecord files created at the top of the test.
    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))