fix num_samples in SequentialSampler and RandomSampler

liyong 2020-10-23 18:29:58 +08:00
parent eaa3fe98ed
commit ee042b90f7
5 changed files with 85 additions and 13 deletions
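
In short: when num_samples is set, both samplers now report and produce min(num_samples, dataset_size) rows instead of over-reporting the count. A minimal pure-Python sketch of that rule (illustrative only, not MindSpore code; the helper name is made up):

    def effective_num_samples(dataset_size, num_samples):
        # num_samples == 0 means "take everything"; a request larger than
        # the dataset is clamped rather than over-reported.
        return dataset_size if num_samples == 0 else min(num_samples, dataset_size)

    assert effective_num_samples(10, 2) == 2    # fewer than the dataset: honored
    assert effective_num_samples(10, 20) == 10  # more than the dataset: clamped
    assert effective_num_samples(10, 0) == 10   # default: whole dataset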


@@ -35,33 +35,33 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (per_ > kEpsilon && per_ <= 1.0f) {
     return dataset_size * kEpsilon;
   }
-  return no_of_samples_;
+  return std::min(static_cast<int64_t>(no_of_samples_), dataset_size);
 }
 
 MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
-  int total_no = static_cast<int>(tasks.Size());
-  int taking;
+  int64_t total_no = static_cast<int64_t>(tasks.Size());
+  int64_t taking;
   if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
     taking = total_no;
   } else if (per_ > kEpsilon && per_ <= 1.0f) {
     taking = total_no * kEpsilon;
   } else {
-    taking = no_of_samples_;
+    taking = std::min(static_cast<int64_t>(no_of_samples_), total_no);
   }
   if (tasks.permutation_.empty()) {
     ShardTask new_tasks;
-    total_no = static_cast<int>(tasks.Size());
-    for (int i = offset_; i < taking + offset_; ++i) {
+    total_no = static_cast<int64_t>(tasks.Size());
+    for (size_t i = offset_; i < taking + offset_; ++i) {
       new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
     }
     std::swap(tasks, new_tasks);
   } else {  // shuffled
     ShardTask new_tasks;
-    if (taking > static_cast<int>(tasks.permutation_.size())) {
+    if (taking > static_cast<int64_t>(tasks.permutation_.size())) {
       return FAILED;
     }
-    total_no = static_cast<int>(tasks.permutation_.size());
+    total_no = static_cast<int64_t>(tasks.permutation_.size());
     for (size_t i = offset_; i < taking + offset_; ++i) {
       new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
     }
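
The GetTaskByID(i % total_no) indexing combined with offset_ gives wrap-around sequential reads, which the new sequential-sampler tests below pin down. The same arithmetic in isolation (a plain-Python sketch with a hypothetical helper):

    def sequential_indices(total, offset, taking):
        # Start at `offset` and wrap modulo the task count, as in
        # ShardSequentialSample::Execute above.
        return [(offset + i) % total for i in range(taking)]

    # SequentialSampler(2, 20) over 10 rows: taking is clamped to 10,
    # and reads wrap from row 2 around to row 1.
    assert sequential_indices(10, 2, 10) == [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]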


@@ -39,7 +39,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (replacement_) {
     return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
   }
-  return dataset_size;
+  return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
 }
 
 MSRStatus ShardShuffle::Execute(ShardTask &tasks) {

@@ -67,6 +67,14 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
       std::swap(tasks, new_tasks);
     } else {
       std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+      auto total_no = static_cast<int64_t>(tasks.Size());
+      if (no_of_samples_ > 0 && no_of_samples_ < total_no) {
+        ShardTask new_tasks;
+        for (size_t i = 0; i < no_of_samples_; ++i) {
+          new_tasks.InsertTask(tasks.GetTaskByID(i));
+        }
+        std::swap(tasks, new_tasks);
+      }
     }
   } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
     uint32_t individual_size = tasks.Size() / tasks.categories;
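
The added branch shuffles the whole permutation first and only then trims to the first no_of_samples_ tasks, so a no-replacement RandomSampler yields at most that many distinct rows. The same idea as a runnable sketch (illustrative, not the MindSpore implementation):

    import random

    def shuffle_take(num_rows, num_samples, seed):
        # Shuffle every index first, then keep the first num_samples when a
        # smaller count was requested -- mirroring the new branch above.
        perm = list(range(num_rows))
        random.Random(seed).shuffle(perm)
        if 0 < num_samples < num_rows:
            perm = perm[:num_samples]
        return perm

    picks = shuffle_take(10, 2, seed=0)
    assert len(picks) == 2 and len(set(picks)) == 2  # distinct rows, no replacement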


@@ -311,7 +311,7 @@ class Unique(cde.UniqueOp):
         Call batch op before calling this function.
 
     Examples:
-        >>> import mindspore.dataset.transforms.c_transforms as c_transforms
+        >>> import mindspore.dataset.transforms.c_transforms as c_transforms
         >>>
         >>> # Data before
         >>> # | x |

@@ -208,7 +208,7 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess6) {
   std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
   std::vector<std::string> file_list = {file_path1};
 
-  // Check sequential sampler, output number is 10, with duplicate samples(a little weird, wait to fix)
+  // Check sequential sampler, output number is 5
   std::shared_ptr<Dataset> ds1 = MindData(file_list, {}, SequentialSampler(0, 10));
   EXPECT_NE(ds1, nullptr);

@@ -229,7 +229,7 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess6) {
   EXPECT_NE(ds5, nullptr);
 
   std::vector<std::shared_ptr<Dataset>> ds = {ds1, ds2, ds3, ds4, ds5};
-  std::vector<int32_t> expected_samples = {10, 5, 2, 3, 3};
+  std::vector<int32_t> expected_samples = {5, 5, 2, 3, 3};
 
   for (int32_t i = 0; i < ds.size(); i++) {
     // Create an iterator over the result of the above dataset


@@ -412,6 +412,46 @@ def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
         num_iter += 1
     assert num_iter == 5
 
 
+def test_cv_minddataset_random_sampler_replacement_false_1(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler(replacement=False, num_samples=2)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 2
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 2
+
+
+def test_cv_minddataset_random_sampler_replacement_false_2(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler(replacement=False, num_samples=20)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 10
+
 
 def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
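
The two new RandomSampler tests above check the clamp from both sides (2 < 10 honored, 20 > 10 clamped to the 10-row test file). The invariant in isolation, using the standard library rather than MindDataset (a sketch; helper name is hypothetical):

    import random

    def draw_without_replacement(num_rows, num_samples, seed=0):
        # Without replacement, at most num_rows distinct rows can come back,
        # so the effective count is min(num_samples, num_rows).
        k = min(num_samples, num_rows)
        return random.Random(seed).sample(range(num_rows), k)

    assert len(draw_without_replacement(10, 2)) == 2    # ..._false_1
    assert len(draw_without_replacement(10, 20)) == 10  # ..._false_2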
@@ -437,7 +477,7 @@ def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
     assert num_iter == 4
 
 
-def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
+def test_cv_minddataset_sequential_sampler_offeset(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
     columns_list = ["data", "file_name", "label"]
     num_readers = 4

@@ -461,6 +501,30 @@ def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
         num_iter += 1
     assert num_iter == 10
 
 
+def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.SequentialSampler(2, 20)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    dataset_size = data_set.get_dataset_size()
+    assert dataset_size == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 10
 
 
 def test_cv_minddataset_split_basic(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)