forked from mindspore-Ecosystem/mindspore
fix distributedSampler reshuffle and fix random_device failed
This commit is contained in:
parent
23d0497df6
commit
2412ee09ce
|
@ -19,13 +19,16 @@
|
|||
#if defined(_WIN32) || defined(_WIN64)
|
||||
#include <stdlib.h>
|
||||
#endif
|
||||
#include <chrono>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/core/global_context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -35,6 +38,17 @@ inline std::mt19937 GetRandomDevice() {
|
|||
rand_s(&number);
|
||||
std::mt19937 random_device{static_cast<uint32_t>(number)};
|
||||
#else
|
||||
int i = 0;
|
||||
while (i < 5) {
|
||||
try {
|
||||
std::mt19937 random_device{std::random_device("/dev/urandom")()};
|
||||
return random_device;
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(WARNING) << "Get std::random_device failed, retry: " << i << ", error: " << e.what();
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
i++;
|
||||
}
|
||||
}
|
||||
std::mt19937 random_device{std::random_device("/dev/urandom")()};
|
||||
#endif
|
||||
return random_device;
|
||||
|
|
|
@ -44,8 +44,8 @@ class ShardDistributedSample : public ShardSample {
|
|||
private:
|
||||
bool shuffle_;
|
||||
int no_of_padded_samples_;
|
||||
|
||||
bool init_judgment_; // we should judge the (num_sample + num_padded) % num_shards == 0 in first time
|
||||
bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch
|
||||
ShardTask task_; // maintain the input tasks in first epoch
|
||||
};
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDRECORD_INCLUDE_SHARD_TASK_H_
|
||||
#define MINDRECORD_INCLUDE_SHARD_TASK_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
@ -27,6 +28,14 @@ namespace mindspore {
|
|||
namespace mindrecord {
|
||||
class ShardTask {
|
||||
public:
|
||||
ShardTask();
|
||||
|
||||
ShardTask(const ShardTask &task); // copy construction
|
||||
|
||||
ShardTask &operator=(const ShardTask &task); // assignment operator
|
||||
|
||||
~ShardTask() = default;
|
||||
|
||||
void MakePerm();
|
||||
|
||||
void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
|
||||
|
@ -46,10 +55,11 @@ class ShardTask {
|
|||
|
||||
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements);
|
||||
|
||||
uint32_t categories = 1;
|
||||
uint32_t categories;
|
||||
|
||||
std::vector<int> permutation_;
|
||||
|
||||
std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
|
||||
std::vector<int> permutation_;
|
||||
};
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1434,14 +1434,15 @@ void ShardReader::ShuffleTask() {
|
|||
for (const auto &op : operators_) {
|
||||
if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
|
||||
if (SUCCESS != (*op)(tasks_)) {
|
||||
MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
|
||||
MS_LOG(WARNING) << "Redo randomSampler failed.";
|
||||
}
|
||||
} else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
|
||||
if (SUCCESS != op->PreExecute(tasks_)) {
|
||||
MS_LOG(WARNING) << "Distribute reshuffle reader tasks failed.";
|
||||
if (SUCCESS != (*op)(tasks_)) {
|
||||
MS_LOG(WARNING) << "Redo distributeSampler failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
if (tasks_.permutation_.empty()) tasks_.MakePerm();
|
||||
}
|
||||
|
||||
} // namespace mindrecord
|
||||
|
|
|
@ -27,7 +27,7 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int
|
|||
: ShardSample(1, num_shards, shard_id),
|
||||
shuffle_(shuffle),
|
||||
no_of_padded_samples_(no_of_padded_samples),
|
||||
init_judgment_(false) {
|
||||
first_epoch_(true) {
|
||||
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
|
||||
}
|
||||
|
||||
|
@ -54,8 +54,7 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
|
|||
|
||||
MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) {
|
||||
auto total_no = tasks.Size();
|
||||
if (no_of_padded_samples_ > 0 && init_judgment_ == false) { // we only judge this in first time
|
||||
init_judgment_ = true;
|
||||
if (no_of_padded_samples_ > 0 && first_epoch_) {
|
||||
if (total_no % denominator_ != 0) {
|
||||
MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. "
|
||||
<< "task size: " << total_no << ", number padded: " << no_of_padded_samples_
|
||||
|
@ -63,6 +62,12 @@ MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) {
|
|||
return FAILED;
|
||||
}
|
||||
}
|
||||
if (first_epoch_) {
|
||||
first_epoch_ = false;
|
||||
task_ = tasks;
|
||||
} else {
|
||||
tasks = task_;
|
||||
}
|
||||
if (shuffle_ == true) {
|
||||
if (SUCCESS != (*shuffle_op_)(tasks)) {
|
||||
return FAILED;
|
||||
|
|
|
@ -43,6 +43,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
|||
}
|
||||
|
||||
MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
|
||||
if (reshuffle_each_epoch_) shuffle_seed_++;
|
||||
if (tasks.categories < 1) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -81,7 +82,6 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
|
|||
}
|
||||
}
|
||||
}
|
||||
if (reshuffle_each_epoch_) shuffle_seed_++;
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace mindrecord
|
||||
|
|
|
@ -24,6 +24,19 @@ using mindspore::MsLogLevel::DEBUG;
|
|||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
ShardTask::ShardTask() : categories(1) {}
|
||||
|
||||
ShardTask::ShardTask(const ShardTask &other)
|
||||
: categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {}
|
||||
|
||||
ShardTask &ShardTask::operator=(const ShardTask &other) {
|
||||
ShardTask tmp(other);
|
||||
std::swap(categories, tmp.categories);
|
||||
permutation_.swap(tmp.permutation_);
|
||||
task_list_.swap(tmp.task_list_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
void ShardTask::MakePerm() {
|
||||
permutation_ = std::vector<int>(task_list_.size());
|
||||
for (uint32_t i = 0; i < task_list_.size(); i++) {
|
||||
|
|
|
@ -278,6 +278,41 @@ def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_c
|
|||
epoch3 = []
|
||||
|
||||
|
||||
def test_cv_minddataset_partition_tutorial_check_whole_reshuffle_result_per_epoch(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
num_shards = 3
|
||||
epoch_result = [[["", "", "", ""], ["", "", "", ""], ["", "", "", ""]], # save partition 0 result
|
||||
[["", "", "", ""], ["", "", "", ""], ["", "", "", ""]], # save partition 1 result
|
||||
[["", "", "", ""], ["", "", "", ""], ["", "", "", ""]]] # svae partition 2 result
|
||||
|
||||
for partition_id in range(num_shards):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
num_shards=num_shards, shard_id=partition_id)
|
||||
|
||||
data_set = data_set.repeat(3)
|
||||
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- partition : {} ------------------------".format(partition_id))
|
||||
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
|
||||
# total 3 partition, 4 result per epoch, total 12 result
|
||||
epoch_result[partition_id][int(num_iter / 4)][num_iter % 4] = item["file_name"] # save epoch result
|
||||
num_iter += 1
|
||||
assert num_iter == 12
|
||||
assert epoch_result[partition_id][0] not in (epoch_result[partition_id][1], epoch_result[partition_id][2])
|
||||
assert epoch_result[partition_id][1] not in (epoch_result[partition_id][0], epoch_result[partition_id][2])
|
||||
assert epoch_result[partition_id][2] not in (epoch_result[partition_id][1], epoch_result[partition_id][0])
|
||||
epoch_result[partition_id][0].sort()
|
||||
epoch_result[partition_id][1].sort()
|
||||
epoch_result[partition_id][2].sort()
|
||||
assert epoch_result[partition_id][0] != epoch_result[partition_id][1]
|
||||
assert epoch_result[partition_id][1] != epoch_result[partition_id][2]
|
||||
assert epoch_result[partition_id][2] != epoch_result[partition_id][0]
|
||||
|
||||
|
||||
def test_cv_minddataset_check_shuffle_result(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
|
|
|
@ -468,6 +468,64 @@ def test_nlp_minddataset_reader_basic_padded_samples_multi_epoch(add_and_remove_
|
|||
partitions(5, 5, 3)
|
||||
partitions(9, 8, 2)
|
||||
|
||||
|
||||
def test_nlp_minddataset_reader_basic_padded_samples_check_whole_reshuffle_result_per_epoch(add_and_remove_nlp_file):
|
||||
columns_list = ["input_ids", "id", "rating"]
|
||||
|
||||
padded_sample = {}
|
||||
padded_sample['id'] = "-1"
|
||||
padded_sample['input_ids'] = np.array([-1,-1,-1,-1], dtype=np.int64)
|
||||
padded_sample['rating'] = 1.0
|
||||
num_readers = 4
|
||||
repeat_size = 3
|
||||
|
||||
def partitions(num_shards, num_padded, dataset_size):
|
||||
num_padded_iter = 0
|
||||
num_iter = 0
|
||||
|
||||
epoch_result = [[["" for i in range(dataset_size)] for i in range(repeat_size)] for i in range(num_shards)]
|
||||
|
||||
for partition_id in range(num_shards):
|
||||
data_set = ds.MindDataset(NLP_FILE_NAME + "0", columns_list, num_readers,
|
||||
num_shards=num_shards,
|
||||
shard_id=partition_id,
|
||||
padded_sample=padded_sample,
|
||||
num_padded=num_padded)
|
||||
assert data_set.get_dataset_size() == dataset_size
|
||||
data_set = data_set.repeat(repeat_size)
|
||||
inner_num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
|
||||
logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
|
||||
logger.info("-------------- item[input_ids]: {}, shape: {} -----------------"
|
||||
.format(item["input_ids"], item["input_ids"].shape))
|
||||
if item['id'] == bytes('-1', encoding='utf-8'):
|
||||
num_padded_iter += 1
|
||||
assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
|
||||
assert (item['input_ids'] == padded_sample['input_ids']).all()
|
||||
assert (item['rating'] == padded_sample['rating']).all()
|
||||
# save epoch result
|
||||
epoch_result[partition_id][int(inner_num_iter / dataset_size)][inner_num_iter % dataset_size] = item["id"]
|
||||
num_iter += 1
|
||||
inner_num_iter += 1
|
||||
assert epoch_result[partition_id][0] not in (epoch_result[partition_id][1], epoch_result[partition_id][2])
|
||||
assert epoch_result[partition_id][1] not in (epoch_result[partition_id][0], epoch_result[partition_id][2])
|
||||
assert epoch_result[partition_id][2] not in (epoch_result[partition_id][1], epoch_result[partition_id][0])
|
||||
if dataset_size > 2:
|
||||
epoch_result[partition_id][0].sort()
|
||||
epoch_result[partition_id][1].sort()
|
||||
epoch_result[partition_id][2].sort()
|
||||
assert epoch_result[partition_id][0] != epoch_result[partition_id][1]
|
||||
assert epoch_result[partition_id][1] != epoch_result[partition_id][2]
|
||||
assert epoch_result[partition_id][2] != epoch_result[partition_id][0]
|
||||
assert num_padded_iter == num_padded * repeat_size
|
||||
assert num_iter == dataset_size * num_shards * repeat_size
|
||||
|
||||
partitions(4, 6, 4)
|
||||
partitions(5, 5, 3)
|
||||
partitions(9, 8, 2)
|
||||
|
||||
|
||||
def get_data(dir_name):
|
||||
"""
|
||||
usage: get data from imagenet dataset
|
||||
|
|
|
@ -586,6 +586,13 @@ def test_cv_minddataset_split_sharding(add_and_remove_cv_file):
|
|||
assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
|
||||
assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
|
||||
|
||||
epoch1_dataset.sort()
|
||||
epoch2_dataset.sort()
|
||||
epoch3_dataset.sort()
|
||||
assert epoch1_dataset != epoch2_dataset
|
||||
assert epoch2_dataset != epoch3_dataset
|
||||
assert epoch3_dataset != epoch1_dataset
|
||||
|
||||
|
||||
def get_data(dir_name, sampler=False):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue