forked from mindspore-Ecosystem/mindspore
!1989 fix MindDataset distribute shuffle error
Merge pull request !1989 from guozhijian/fix_distribute_shuffle
This commit is contained in:
commit
251a6667a5
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindrecord/include/shard_distributed_sample.h"
|
||||
#include "mindrecord/include/shard_reader.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
|
@ -1385,9 +1386,18 @@ void ShardReader::Reset() {
|
|||
|
||||
void ShardReader::ShuffleTask() {
|
||||
for (const auto &op : operators_) {
|
||||
if (block_reader_ || !std::dynamic_pointer_cast<ShardShuffle>(op)) continue;
|
||||
if (SUCCESS != (*op)(tasks_)) {
|
||||
MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
|
||||
if (block_reader_) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
|
||||
if (SUCCESS != (*op)(tasks_)) {
|
||||
MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
|
||||
}
|
||||
} else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
|
||||
if (SUCCESS != op->PreExecute(tasks_)) {
|
||||
MS_LOG(WARNING) << "Distribute reshuffle reader tasks failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -238,6 +238,139 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
|
|||
assert partitions(9) == 2
|
||||
|
||||
|
||||
def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
num_shards = 3
|
||||
epoch1 = []
|
||||
epoch2 = []
|
||||
epoch3 = []
|
||||
|
||||
for partition_id in range(num_shards):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
num_shards=num_shards, shard_id=partition_id)
|
||||
|
||||
data_set = data_set.repeat(3)
|
||||
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- partition : {} ------------------------".format(partition_id))
|
||||
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
if num_iter <= 4:
|
||||
epoch1.append(item["file_name"]) # save epoch 1 list
|
||||
elif num_iter <= 8:
|
||||
epoch2.append(item["file_name"]) # save epoch 2 list
|
||||
else:
|
||||
epoch3.append(item["file_name"]) # save epoch 3 list
|
||||
assert num_iter == 12
|
||||
assert len(epoch1) == 4
|
||||
assert len(epoch2) == 4
|
||||
assert len(epoch3) == 4
|
||||
assert epoch1 not in (epoch2, epoch3)
|
||||
assert epoch2 not in (epoch1, epoch3)
|
||||
assert epoch3 not in (epoch1, epoch2)
|
||||
epoch1 = []
|
||||
epoch2 = []
|
||||
epoch3 = []
|
||||
|
||||
|
||||
def test_cv_minddataset_check_shuffle_result(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
|
||||
ds.config.set_seed(54321)
|
||||
epoch1 = []
|
||||
epoch2 = []
|
||||
epoch3 = []
|
||||
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
|
||||
data_set = data_set.repeat(3)
|
||||
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
if num_iter <= 10:
|
||||
epoch1.append(item["file_name"]) # save epoch 1 list
|
||||
elif num_iter <= 20:
|
||||
epoch2.append(item["file_name"]) # save epoch 2 list
|
||||
else:
|
||||
epoch3.append(item["file_name"]) # save epoch 3 list
|
||||
assert num_iter == 30
|
||||
assert len(epoch1) == 10
|
||||
assert len(epoch2) == 10
|
||||
assert len(epoch3) == 10
|
||||
assert epoch1 not in (epoch2, epoch3)
|
||||
assert epoch2 not in (epoch1, epoch3)
|
||||
assert epoch3 not in (epoch1, epoch2)
|
||||
|
||||
epoch1_new_dataset = []
|
||||
epoch2_new_dataset = []
|
||||
epoch3_new_dataset = []
|
||||
|
||||
data_set2 = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
|
||||
data_set2 = data_set2.repeat(3)
|
||||
|
||||
num_iter = 0
|
||||
for item in data_set2.create_dict_iterator():
|
||||
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
if num_iter <= 10:
|
||||
epoch1_new_dataset.append(item["file_name"]) # save epoch 1 list
|
||||
elif num_iter <= 20:
|
||||
epoch2_new_dataset.append(item["file_name"]) # save epoch 2 list
|
||||
else:
|
||||
epoch3_new_dataset.append(item["file_name"]) # save epoch 3 list
|
||||
assert num_iter == 30
|
||||
assert len(epoch1_new_dataset) == 10
|
||||
assert len(epoch2_new_dataset) == 10
|
||||
assert len(epoch3_new_dataset) == 10
|
||||
assert epoch1_new_dataset not in (epoch2_new_dataset, epoch3_new_dataset)
|
||||
assert epoch2_new_dataset not in (epoch1_new_dataset, epoch3_new_dataset)
|
||||
assert epoch3_new_dataset not in (epoch1_new_dataset, epoch2_new_dataset)
|
||||
|
||||
assert epoch1 == epoch1_new_dataset
|
||||
assert epoch2 == epoch2_new_dataset
|
||||
assert epoch3 == epoch3_new_dataset
|
||||
|
||||
ds.config.set_seed(12345)
|
||||
epoch1_new_dataset2 = []
|
||||
epoch2_new_dataset2 = []
|
||||
epoch3_new_dataset2 = []
|
||||
|
||||
data_set3 = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
|
||||
data_set3 = data_set3.repeat(3)
|
||||
|
||||
num_iter = 0
|
||||
for item in data_set3.create_dict_iterator():
|
||||
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
|
||||
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
if num_iter <= 10:
|
||||
epoch1_new_dataset2.append(item["file_name"]) # save epoch 1 list
|
||||
elif num_iter <= 20:
|
||||
epoch2_new_dataset2.append(item["file_name"]) # save epoch 2 list
|
||||
else:
|
||||
epoch3_new_dataset2.append(item["file_name"]) # save epoch 3 list
|
||||
assert num_iter == 30
|
||||
assert len(epoch1_new_dataset2) == 10
|
||||
assert len(epoch2_new_dataset2) == 10
|
||||
assert len(epoch3_new_dataset2) == 10
|
||||
assert epoch1_new_dataset2 not in (epoch2_new_dataset2, epoch3_new_dataset2)
|
||||
assert epoch2_new_dataset2 not in (epoch1_new_dataset2, epoch3_new_dataset2)
|
||||
assert epoch3_new_dataset2 not in (epoch1_new_dataset2, epoch2_new_dataset2)
|
||||
|
||||
assert epoch1 != epoch1_new_dataset2
|
||||
assert epoch2 != epoch2_new_dataset2
|
||||
assert epoch3 != epoch3_new_dataset2
|
||||
|
||||
|
||||
def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
|
|
Loading…
Reference in New Issue