!158 fix: resolve MindDataset hang when a field is not in the index when using block_reader

Merge pull request !158 from guozhijian/fix_block_reader_hung
This commit is contained in:
mindspore-ci-bot 2020-04-07 21:49:33 +08:00 committed by Gitee
commit 475e858474
6 changed files with 43 additions and 8 deletions

View File

@ -785,6 +785,8 @@ vector<std::string> ShardReader::GetAllColumns() {
MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
vector<std::string> columns = GetAllColumns();
CheckIfColumnInIndex(columns);
for (const auto &rg : row_group_summary) {
auto shard_id = std::get<0>(rg);
auto group_id = std::get<1>(rg);

View File

@ -143,6 +143,7 @@ class FileWriter:
ParamTypeError: If index field is invalid.
MRMDefineIndexError: If index field is not primitive type.
MRMAddIndexError: If failed to add index field.
MRMGetMetaError: If the schema is not set or get meta failed.
"""
if not index_fields or not isinstance(index_fields, list):
raise ParamTypeError('index_fields', 'list')

View File

@ -24,7 +24,7 @@ from mindspore import log as logger
from .cifar100 import Cifar100
from ..common.exceptions import PathNotExistsError
from ..filewriter import FileWriter
from ..shardutils import check_filename
from ..shardutils import check_filename, SUCCESS
try:
cv2 = import_module("cv2")
except ModuleNotFoundError:
@ -98,8 +98,11 @@ class Cifar100ToMR:
data_list = _construct_raw_data(images, fine_labels, coarse_labels)
test_data_list = _construct_raw_data(test_images, test_fine_labels, test_coarse_labels)
_generate_mindrecord(self.destination, data_list, fields, "img_train")
_generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test")
if _generate_mindrecord(self.destination, data_list, fields, "img_train") != SUCCESS:
return FAILED
if _generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test") != SUCCESS:
return FAILED
return SUCCESS
def _construct_raw_data(images, fine_labels, coarse_labels):
"""

View File

@ -47,7 +47,9 @@ def add_and_remove_cv_file():
os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
writer = FileWriter(CV_FILE_NAME, FILES_NUM)
data = get_data(CV_DIR_NAME)
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
cv_schema_json = {"id": {"type": "int32"},
"file_name": {"type": "string"},
"label": {"type": "int32"},
"data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
@ -226,6 +228,24 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file):
num_iter += 1
assert num_iter == 20
def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file):
    """Tutorial for cv minddataset: block_reader with a column ("id") that is not in the index.

    Regression test for the hang fixed by this commit: reading with
    block_reader=True while requesting a column ("id") that was not added
    via add_index must still iterate all rows instead of hanging.
    """
    columns_list = ["id", "data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False,
                              block_reader=True)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        # Fixed typo in log text: "tow" -> "two".
        logger.info("-------------- block reader repeat two {} -----------------".format(num_iter))
        logger.info("-------------- item[id]: {} ----------------------------".format(item["id"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        num_iter += 1
    # 10 rows in the dataset x 2 repeats.
    assert num_iter == 20
def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
@ -359,13 +379,14 @@ def get_data(dir_name):
lines = file_reader.readlines()
data_list = []
for line in lines:
for i, line in enumerate(lines):
try:
filename, label = line.split(",")
label = label.strip("\n")
with open(os.path.join(img_dir, filename), "rb") as file_reader:
img = file_reader.read()
data_json = {"file_name": filename,
data_json = {"id": i,
"file_name": filename,
"data": img,
"label": int(label)}
data_list.append(data_json)

View File

@ -18,6 +18,7 @@ import pytest
from mindspore.mindrecord import Cifar100ToMR
from mindspore.mindrecord import FileReader
from mindspore.mindrecord import MRMOpenError
from mindspore.mindrecord import SUCCESS
from mindspore import log as logger
CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
@ -26,7 +27,8 @@ MINDRECORD_FILE = "./cifar100.mindrecord"
def test_cifar100_to_mindrecord_without_index_fields():
    """Test transform cifar100 dataset to mindrecord without index fields.

    transform() must report SUCCESS, and both the train and test mindrecord
    output files must exist afterwards; read() then verifies the contents.
    """
    cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
    # Call transform() exactly once and check its return code, so a failed
    # conversion fails the test instead of passing silently.  (The extracted
    # diff left a stale duplicate, unchecked transform() call here.)
    ret = cifar100_transformer.transform()
    assert ret == SUCCESS, "Failed to transform from cifar100 to mindrecord"
    assert os.path.exists(MINDRECORD_FILE)
    assert os.path.exists(MINDRECORD_FILE + "_test")
    read()

View File

@ -16,7 +16,7 @@
import os
import pytest
from mindspore.mindrecord import FileWriter, FileReader, MindPage
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError
from mindspore import log as logger
from utils import get_data
@ -280,3 +280,9 @@ def test_cv_file_writer_shard_num_greater_than_1000():
with pytest.raises(ParamValueError) as err:
FileWriter(CV_FILE_NAME, 1001)
assert 'Shard number should between' in str(err.value)
def test_add_index_without_add_schema():
    """Calling add_index before any schema has been set must raise MRMGetMetaError."""
    with pytest.raises(MRMGetMetaError) as err:
        writer = FileWriter(CV_FILE_NAME)
        writer.add_index(["label"])
    assert 'Failed to get meta info' in str(err.value)