forked from mindspore-Ecosystem/mindspore

fix num_sample in sequentialSampler and randomSampler

parent eaa3fe98ed
commit ee042b90f7
@@ -35,33 +35,33 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c
   if (per_ > kEpsilon && per_ <= 1.0f) {
     return dataset_size * kEpsilon;
   }
-  return no_of_samples_;
+  return std::min(static_cast<int64_t>(no_of_samples_), dataset_size);
 }
 
 MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
-  int total_no = static_cast<int>(tasks.Size());
-  int taking;
+  int64_t total_no = static_cast<int64_t>(tasks.Size());
+  int64_t taking;
   if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
     taking = total_no;
   } else if (per_ > kEpsilon && per_ <= 1.0f) {
     taking = total_no * kEpsilon;
   } else {
-    taking = no_of_samples_;
+    taking = std::min(static_cast<int64_t>(no_of_samples_), total_no);
   }
 
   if (tasks.permutation_.empty()) {
     ShardTask new_tasks;
-    total_no = static_cast<int>(tasks.Size());
-    for (int i = offset_; i < taking + offset_; ++i) {
+    total_no = static_cast<int64_t>(tasks.Size());
+    for (size_t i = offset_; i < taking + offset_; ++i) {
       new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
     }
     std::swap(tasks, new_tasks);
   } else {  // shuffled
     ShardTask new_tasks;
-    if (taking > static_cast<int>(tasks.permutation_.size())) {
+    if (taking > static_cast<int64_t>(tasks.permutation_.size())) {
       return FAILED;
     }
-    total_no = static_cast<int>(tasks.permutation_.size());
+    total_no = static_cast<int64_t>(tasks.permutation_.size());
     for (size_t i = offset_; i < taking + offset_; ++i) {
       new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
     }
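Note: the ShardSequentialSample hunk widens the counters to int64_t and, more importantly, caps the reported sample count at the dataset size instead of returning no_of_samples_ unchecked. A minimal Python sketch of the clamping (illustration only, not MindSpore API; the percentage branch is omitted):

def get_num_samples(dataset_size, no_of_samples):
    # Before the fix the requested count was returned as-is, so a request
    # larger than the dataset reported samples that did not exist; now the
    # request is clamped to what is actually stored.
    return min(no_of_samples, dataset_size)

assert get_num_samples(5, 10) == 5    # request exceeds dataset -> capped
assert get_num_samples(10, 5) == 5    # request within dataset -> unchanged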
@@ -39,7 +39,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (replacement_) {
     return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
   }
-  return dataset_size;
+  return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
 }
 
 MSRStatus ShardShuffle::Execute(ShardTask &tasks) {

@@ -67,6 +67,14 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
       std::swap(tasks, new_tasks);
     } else {
       std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+      auto total_no = static_cast<int64_t>(tasks.Size());
+      if (no_of_samples_ > 0 && no_of_samples_ < total_no) {
+        ShardTask new_tasks;
+        for (size_t i = 0; i < no_of_samples_; ++i) {
+          new_tasks.InsertTask(tasks.GetTaskByID(i));
+        }
+        std::swap(tasks, new_tasks);
+      }
     }
   } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
     uint32_t individual_size = tasks.Size() / tasks.categories;
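Note: ShardShuffle gets the same cap for the no-replacement case, and the new block in Execute truncates the shuffled ordering to the first no_of_samples_ tasks. A rough Python equivalent of that path (hypothetical helper for illustration; the C++ shuffles an index permutation rather than the task list itself):

import random

def shuffle_without_replacement(tasks, no_of_samples, shuffle_seed):
    # Deterministic shuffle, then keep only the first no_of_samples entries
    # when a positive count below the dataset size was requested, mirroring
    # the added C++ block.
    random.Random(shuffle_seed).shuffle(tasks)
    if 0 < no_of_samples < len(tasks):
        tasks = tasks[:no_of_samples]
    return tasks

assert len(shuffle_without_replacement(list(range(10)), 2, shuffle_seed=1)) == 2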
@@ -311,7 +311,7 @@ class Unique(cde.UniqueOp):
         Call batch op before calling this function.
 
     Examples:
-        >>> import mindspore.dataset.transforms.c_transforms as c_transforms
+        >>> import mindspore.dataset.transforms.c_transforms as c_transforms
         >>>
         >>> # Data before
         >>> # | x |
@@ -208,7 +208,7 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess6) {
   std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
   std::vector<std::string> file_list = {file_path1};
 
-  // Check sequential sampler, output number is 10, with duplicate samples(a little weird, wait to fix)
+  // Check sequential sampler, output number is 5
   std::shared_ptr<Dataset> ds1 = MindData(file_list, {}, SequentialSampler(0, 10));
   EXPECT_NE(ds1, nullptr);
 

@@ -229,7 +229,7 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess6) {
   EXPECT_NE(ds5, nullptr);
 
   std::vector<std::shared_ptr<Dataset>> ds = {ds1, ds2, ds3, ds4, ds5};
-  std::vector<int32_t> expected_samples = {10, 5, 2, 3, 3};
+  std::vector<int32_t> expected_samples = {5, 5, 2, 3, 3};
 
   for (int32_t i = 0; i < ds.size(); i++) {
     // Create an iterator over the result of the above dataset
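Note: the expected count for ds1 drops from 10 to 5 because the test mindrecord holds only 5 images. SequentialSampler(0, 10) used to wrap around and emit duplicate samples (the oddity flagged in the removed comment); with the clamping fix it yields min(10, 5) = 5 rows, matching the updated comment and expected_samples.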
@@ -412,6 +412,46 @@ def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
         num_iter += 1
     assert num_iter == 5
 
+def test_cv_minddataset_random_sampler_replacement_false_1(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler(replacement=False, num_samples=2)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 2
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 2
+
+def test_cv_minddataset_random_sampler_replacement_false_2(add_and_remove_cv_file):
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.RandomSampler(replacement=False, num_samples=20)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 10
+
 
 def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
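Note: together these two tests pin down the no-replacement contract on the 10-record test file: a request smaller than the dataset returns exactly num_samples (2), while a larger request is capped at the dataset size (20 -> 10). In sketch form (hypothetical helper mirroring the ShardShuffle::GetNumSamples change above):

def effective_num_samples(num_samples, dataset_size, replacement=False):
    # num_samples == 0 means "take everything"; without replacement the
    # request can never exceed the dataset size.
    if replacement:
        return dataset_size if num_samples == 0 else num_samples
    return dataset_size if num_samples == 0 else min(num_samples, dataset_size)

assert effective_num_samples(2, 10) == 2      # replacement_false_1
assert effective_num_samples(20, 10) == 10    # replacement_false_2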
@@ -437,7 +477,7 @@ def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
     assert num_iter == 4
 
 
-def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
+def test_cv_minddataset_sequential_sampler_offeset(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
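Note: the rename reflects what this test actually checks, the sampler's offset behavior; the capped-request case gets its own test_cv_minddataset_sequential_sampler_exceed_size in the next hunk.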
@@ -461,6 +501,30 @@ def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
         num_iter += 1
     assert num_iter == 10
 
+def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
+    data = get_data(CV_DIR_NAME, True)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.SequentialSampler(2, 20)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    dataset_size = data_set.get_dataset_size()
+    assert dataset_size == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info(
+            "-------------- item[data]: {} -----------------------------".format(item["data"]))
+        logger.info(
+            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        assert item['file_name'] == np.array(
+            data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
+        num_iter += 1
+    assert num_iter == 10
+
 
 def test_cv_minddataset_split_basic(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME, True)
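Note: this test documents the post-fix behavior for a request that exceeds the dataset: SequentialSampler(2, 20) on the 10-record file is capped to 10 reads that start at the offset and wrap modulo the dataset size, which is exactly what the file_name assertion encodes. The resulting index order, sketched in Python:

dataset_size, offset, num_samples = 10, 2, 20
indices = [(i + offset) % dataset_size for i in range(min(num_samples, dataset_size))]
assert indices == [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]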