fix: hang when using MindDataset with block_reader=True

This commit is contained in:
jonyguo 2020-04-03 16:53:45 +08:00
parent 9e17b996c7
commit c688265671
6 changed files with 43 additions and 8 deletions

View File

@ -785,6 +785,8 @@ vector<std::string> ShardReader::GetAllColumns() {
MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
vector<std::string> columns = GetAllColumns();
CheckIfColumnInIndex(columns);
for (const auto &rg : row_group_summary) {
auto shard_id = std::get<0>(rg);
auto group_id = std::get<1>(rg);

View File

@ -143,6 +143,7 @@ class FileWriter:
ParamTypeError: If index field is invalid.
MRMDefineIndexError: If index field is not primitive type.
MRMAddIndexError: If failed to add index field.
MRMGetMetaError: If the schema is not set or get meta failed.
"""
if not index_fields or not isinstance(index_fields, list):
raise ParamTypeError('index_fields', 'list')

View File

@ -24,7 +24,7 @@ from mindspore import log as logger
from .cifar100 import Cifar100
from ..common.exceptions import PathNotExistsError
from ..filewriter import FileWriter
from ..shardutils import check_filename
from ..shardutils import check_filename, SUCCESS
try:
cv2 = import_module("cv2")
except ModuleNotFoundError:
@ -98,8 +98,11 @@ class Cifar100ToMR:
data_list = _construct_raw_data(images, fine_labels, coarse_labels)
test_data_list = _construct_raw_data(test_images, test_fine_labels, test_coarse_labels)
_generate_mindrecord(self.destination, data_list, fields, "img_train")
_generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test")
if _generate_mindrecord(self.destination, data_list, fields, "img_train") != SUCCESS:
return FAILED
if _generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test") != SUCCESS:
return FAILED
return SUCCESS
def _construct_raw_data(images, fine_labels, coarse_labels):
"""

View File

@ -47,7 +47,9 @@ def add_and_remove_cv_file():
os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
writer = FileWriter(CV_FILE_NAME, FILES_NUM)
data = get_data(CV_DIR_NAME)
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
cv_schema_json = {"id": {"type": "int32"},
"file_name": {"type": "string"},
"label": {"type": "int32"},
"data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
@ -226,6 +228,24 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file):
num_iter += 1
assert num_iter == 20
def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file):
    """Read a MindRecord with block_reader=True where some requested columns
    ("id", "data") are not index fields.

    Regression test for the hang that occurred when block_reader=True was used
    with columns outside the index; verifies every row is yielded.
    """
    columns_list = ["id", "data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False,
                              block_reader=True)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        # fixed typo in the log message: "tow" -> "two"
        logger.info("-------------- block reader repeat two {} -----------------".format(num_iter))
        logger.info("-------------- item[id]: {} ----------------------------".format(item["id"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        num_iter += 1
    # 10 rows x repeat(2) = 20 iterations expected
    assert num_iter == 20
def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
@ -359,13 +379,14 @@ def get_data(dir_name):
lines = file_reader.readlines()
data_list = []
for line in lines:
for i, line in enumerate(lines):
try:
filename, label = line.split(",")
label = label.strip("\n")
with open(os.path.join(img_dir, filename), "rb") as file_reader:
img = file_reader.read()
data_json = {"file_name": filename,
data_json = {"id": i,
"file_name": filename,
"data": img,
"label": int(label)}
data_list.append(data_json)

View File

@ -18,6 +18,7 @@ import pytest
from mindspore.mindrecord import Cifar100ToMR
from mindspore.mindrecord import FileReader
from mindspore.mindrecord import MRMOpenError
from mindspore.mindrecord import SUCCESS
from mindspore import log as logger
CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
@ -26,7 +27,8 @@ MINDRECORD_FILE = "./cifar100.mindrecord"
def test_cifar100_to_mindrecord_without_index_fields():
    """Transform the cifar100 dataset to mindrecord without declaring index fields.

    Checks that transform() reports SUCCESS and that both the train and test
    output files exist, then reads them back via read().
    """
    cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
    # transform() now returns SUCCESS/FAILED instead of silently ignoring errors
    ret = cifar100_transformer.transform()
    # fixed typo in the failure message: "tranform" -> "transform"
    assert ret == SUCCESS, "Failed to transform from cifar100 to mindrecord"
    assert os.path.exists(MINDRECORD_FILE)
    assert os.path.exists(MINDRECORD_FILE + "_test")
    read()

View File

@ -16,7 +16,7 @@
import os
import pytest
from mindspore.mindrecord import FileWriter, FileReader, MindPage
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError
from mindspore import log as logger
from utils import get_data
@ -280,3 +280,9 @@ def test_cv_file_writer_shard_num_greater_than_1000():
with pytest.raises(ParamValueError) as err:
FileWriter(CV_FILE_NAME, 1001)
assert 'Shard number should between' in str(err.value)
def test_add_index_without_add_schema():
    """add_index() must raise MRMGetMetaError when no schema has been set.

    The writer is constructed OUTSIDE the raises-block so that only the
    add_index() call can satisfy the expected-exception check; otherwise an
    unexpected MRMGetMetaError from the constructor would falsely pass.
    """
    fw = FileWriter(CV_FILE_NAME)
    with pytest.raises(MRMGetMetaError) as err:
        fw.add_index(["label"])
    assert 'Failed to get meta info' in str(err.value)