forked from mindspore-Ecosystem/mindspore
fix: when use MindDataset block_reade=True hung
This commit is contained in:
parent
9e17b996c7
commit
c688265671
|
@ -785,6 +785,8 @@ vector<std::string> ShardReader::GetAllColumns() {
|
|||
|
||||
MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
||||
vector<std::string> columns = GetAllColumns();
|
||||
CheckIfColumnInIndex(columns);
|
||||
for (const auto &rg : row_group_summary) {
|
||||
auto shard_id = std::get<0>(rg);
|
||||
auto group_id = std::get<1>(rg);
|
||||
|
|
|
@ -143,6 +143,7 @@ class FileWriter:
|
|||
ParamTypeError: If index field is invalid.
|
||||
MRMDefineIndexError: If index field is not primitive type.
|
||||
MRMAddIndexError: If failed to add index field.
|
||||
MRMGetMetaError: If the schema is not set or get meta failed.
|
||||
"""
|
||||
if not index_fields or not isinstance(index_fields, list):
|
||||
raise ParamTypeError('index_fields', 'list')
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore import log as logger
|
|||
from .cifar100 import Cifar100
|
||||
from ..common.exceptions import PathNotExistsError
|
||||
from ..filewriter import FileWriter
|
||||
from ..shardutils import check_filename
|
||||
from ..shardutils import check_filename, SUCCESS
|
||||
try:
|
||||
cv2 = import_module("cv2")
|
||||
except ModuleNotFoundError:
|
||||
|
@ -98,8 +98,11 @@ class Cifar100ToMR:
|
|||
data_list = _construct_raw_data(images, fine_labels, coarse_labels)
|
||||
test_data_list = _construct_raw_data(test_images, test_fine_labels, test_coarse_labels)
|
||||
|
||||
_generate_mindrecord(self.destination, data_list, fields, "img_train")
|
||||
_generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test")
|
||||
if _generate_mindrecord(self.destination, data_list, fields, "img_train") != SUCCESS:
|
||||
return FAILED
|
||||
if _generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test") != SUCCESS:
|
||||
return FAILED
|
||||
return SUCCESS
|
||||
|
||||
def _construct_raw_data(images, fine_labels, coarse_labels):
|
||||
"""
|
||||
|
|
|
@ -47,7 +47,9 @@ def add_and_remove_cv_file():
|
|||
os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None
|
||||
writer = FileWriter(CV_FILE_NAME, FILES_NUM)
|
||||
data = get_data(CV_DIR_NAME)
|
||||
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
|
||||
cv_schema_json = {"id": {"type": "int32"},
|
||||
"file_name": {"type": "string"},
|
||||
"label": {"type": "int32"},
|
||||
"data": {"type": "bytes"}}
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
|
@ -226,6 +228,24 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file):
|
|||
num_iter += 1
|
||||
assert num_iter == 20
|
||||
|
||||
def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file):
|
||||
"""tutorial for cv minddataset."""
|
||||
columns_list = ["id", "data", "label"]
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False,
|
||||
block_reader=True)
|
||||
assert data_set.get_dataset_size() == 10
|
||||
repeat_num = 2
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter))
|
||||
logger.info("-------------- item[id]: {} ----------------------------".format(item["id"]))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 20
|
||||
|
||||
|
||||
def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
|
@ -359,13 +379,14 @@ def get_data(dir_name):
|
|||
lines = file_reader.readlines()
|
||||
|
||||
data_list = []
|
||||
for line in lines:
|
||||
for i, line in enumerate(lines):
|
||||
try:
|
||||
filename, label = line.split(",")
|
||||
label = label.strip("\n")
|
||||
with open(os.path.join(img_dir, filename), "rb") as file_reader:
|
||||
img = file_reader.read()
|
||||
data_json = {"file_name": filename,
|
||||
data_json = {"id": i,
|
||||
"file_name": filename,
|
||||
"data": img,
|
||||
"label": int(label)}
|
||||
data_list.append(data_json)
|
||||
|
|
|
@ -18,6 +18,7 @@ import pytest
|
|||
from mindspore.mindrecord import Cifar100ToMR
|
||||
from mindspore.mindrecord import FileReader
|
||||
from mindspore.mindrecord import MRMOpenError
|
||||
from mindspore.mindrecord import SUCCESS
|
||||
from mindspore import log as logger
|
||||
|
||||
CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
|
||||
|
@ -26,7 +27,8 @@ MINDRECORD_FILE = "./cifar100.mindrecord"
|
|||
def test_cifar100_to_mindrecord_without_index_fields():
|
||||
"""test transform cifar100 dataset to mindrecord without index fields."""
|
||||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
|
||||
cifar100_transformer.transform()
|
||||
ret = cifar100_transformer.transform()
|
||||
assert ret == SUCCESS, "Failed to tranform from cifar100 to mindrecord"
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import os
|
||||
import pytest
|
||||
from mindspore.mindrecord import FileWriter, FileReader, MindPage
|
||||
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError
|
||||
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError
|
||||
from mindspore import log as logger
|
||||
from utils import get_data
|
||||
|
||||
|
@ -280,3 +280,9 @@ def test_cv_file_writer_shard_num_greater_than_1000():
|
|||
with pytest.raises(ParamValueError) as err:
|
||||
FileWriter(CV_FILE_NAME, 1001)
|
||||
assert 'Shard number should between' in str(err.value)
|
||||
|
||||
def test_add_index_without_add_schema():
|
||||
with pytest.raises(MRMGetMetaError) as err:
|
||||
fw = FileWriter(CV_FILE_NAME)
|
||||
fw.add_index(["label"])
|
||||
assert 'Failed to get meta info' in str(err.value)
|
||||
|
|
Loading…
Reference in New Issue