!3133 [MD] fix num_sample in distributed sampler
Merge pull request !3133 from liyong126/md_num_of_samples
This commit is contained in:
commit
4945d34a41
|
@ -784,7 +784,7 @@ void bindSamplerOps(py::module *m) {
|
|||
|
||||
(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
|
||||
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, bool, uint32_t>());
|
||||
.def(py::init<int64_t, int64_t, bool, uint32_t, int64_t>());
|
||||
|
||||
(void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
|
||||
*m, "MindrecordRandomSampler")
|
||||
|
|
|
@ -29,9 +29,10 @@ namespace mindspore {
|
|||
namespace mindrecord {
|
||||
class ShardDistributedSample : public ShardSample {
|
||||
public:
|
||||
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
|
||||
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed,
|
||||
int no_of_samples = 0);
|
||||
|
||||
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
|
||||
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, int no_of_samples = 0);
|
||||
|
||||
void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class ShardSample : public ShardOperator {
|
|||
|
||||
ShardSample(int num, int den);
|
||||
|
||||
ShardSample(int num, int den, int par);
|
||||
ShardSample(int num, int den, int par, int no_of_samples = 0);
|
||||
|
||||
ShardSample(const std::vector<int64_t> &indices, uint32_t seed);
|
||||
|
||||
|
|
|
@ -23,16 +23,17 @@ using mindspore::MsLogLevel::ERROR;
|
|||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle,
|
||||
uint32_t seed)
|
||||
: ShardSample(1, num_shards, shard_id),
|
||||
uint32_t seed, int no_of_samples)
|
||||
: ShardSample(1, num_shards, shard_id, no_of_samples),
|
||||
shuffle_(shuffle),
|
||||
no_of_padded_samples_(no_of_padded_samples),
|
||||
first_epoch_(true) {
|
||||
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
|
||||
}
|
||||
|
||||
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed)
|
||||
: ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {}
|
||||
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed,
|
||||
int no_of_samples)
|
||||
: ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed, no_of_samples) {}
|
||||
|
||||
int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||
if (no_of_padded_samples_ <= 0) {
|
||||
|
|
|
@ -38,11 +38,11 @@ ShardSample::ShardSample(int num, int den)
|
|||
indices_({}),
|
||||
sampler_type_(kCustomTopPercentSampler) {}
|
||||
|
||||
ShardSample::ShardSample(int num, int den, int par)
|
||||
ShardSample::ShardSample(int num, int den, int par, int no_of_samples)
|
||||
: numerator_(num),
|
||||
denominator_(den),
|
||||
partition_id_(par),
|
||||
no_of_samples_(0),
|
||||
no_of_samples_(no_of_samples),
|
||||
indices_({}),
|
||||
sampler_type_(kCustomTopPercentSampler) {}
|
||||
|
||||
|
@ -110,8 +110,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
|
|||
new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python
|
||||
}
|
||||
} else {
|
||||
int count = 0;
|
||||
for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
||||
if (no_of_samples_ != 0 && count == no_of_samples_) break;
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start
|
||||
count++;
|
||||
}
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
|
@ -121,8 +124,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
|
|||
return FAILED;
|
||||
}
|
||||
total_no = static_cast<int>(tasks.permutation_.size());
|
||||
int count = 0;
|
||||
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
||||
if (no_of_samples_ != 0 && count == no_of_samples_) break;
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
|
||||
count++;
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
}
|
||||
|
|
|
@ -270,7 +270,9 @@ class DistributedSampler(BuiltinSampler):
|
|||
return c_sampler
|
||||
|
||||
def create_for_minddataset(self):
|
||||
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
|
||||
self.seed, num_samples)
|
||||
c_child_sampler = self.create_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
|
|
@ -238,6 +238,72 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
|
|||
assert partitions(5) == 2
|
||||
assert partitions(9) == 2
|
||||
|
||||
def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
|
||||
def partitions(num_shards):
|
||||
for partition_id in range(num_shards):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
num_shards=num_shards,
|
||||
shard_id=partition_id, num_samples=1)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- partition : {} ------------------------".format(partition_id))
|
||||
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
return num_iter
|
||||
|
||||
assert partitions(4) == 1
|
||||
assert partitions(5) == 1
|
||||
assert partitions(9) == 1
|
||||
|
||||
def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
|
||||
def partitions(num_shards):
|
||||
for partition_id in range(num_shards):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
num_shards=num_shards,
|
||||
shard_id=partition_id, num_samples=2)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- partition : {} ------------------------".format(partition_id))
|
||||
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
return num_iter
|
||||
|
||||
assert partitions(4) == 2
|
||||
assert partitions(5) == 2
|
||||
assert partitions(9) == 2
|
||||
|
||||
def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
|
||||
def partitions(num_shards):
|
||||
for partition_id in range(num_shards):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
num_shards=num_shards,
|
||||
shard_id=partition_id, num_samples=3)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- partition : {} ------------------------".format(partition_id))
|
||||
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
return num_iter
|
||||
|
||||
assert partitions(4) == 3
|
||||
assert partitions(5) == 2
|
||||
assert partitions(9) == 2
|
||||
|
||||
|
||||
def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
|
|
|
@ -228,3 +228,24 @@ def test_minddataset_shard_id_bigger_than_num_shard():
|
|||
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
||||
def test_cv_minddataset_partition_num_samples_equals_0():
|
||||
"""tutorial for cv minddataset."""
|
||||
create_cv_mindrecord(1)
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
|
||||
def partitions(num_shards):
|
||||
for partition_id in range(num_shards):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
|
||||
num_shards=num_shards,
|
||||
shard_id=partition_id, num_samples=0)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
with pytest.raises(Exception) as error_info:
|
||||
partitions(5)
|
||||
assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info)
|
||||
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
|
Loading…
Reference in New Issue