!17208 Fix mindrecord UTs: existing files cause write exception

From: @luoyang42
Reviewed-by: @liucunwei,@jonyguo
Signed-off-by: @liucunwei
mindspore-ci-bot 2021-05-29 09:10:55 +08:00 committed by Gitee
commit 89e9f1c99f
1 changed file with 114 additions and 56 deletions
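
The change applies one pattern throughout the file: every test removes any leftover MindRecord shards and their ".db" index files before it constructs a FileWriter, and the reader tests regenerate their own input by calling the writer tutorials with remove_file=False instead of relying on files left behind by earlier tests. Below is a minimal sketch of that cleanup-before-write pattern; the file name and helper are hypothetical, not the ones defined in the test file.

import os

from mindspore.mindrecord import FileWriter

SKETCH_FILE = "sketch.mindrecord"  # hypothetical name, not one of the files used by the tests

def cleanup(file_name):
    """Remove a MindRecord file and its .db index file if they exist."""
    for path in (file_name, file_name + ".db"):
        if os.path.exists(path):
            os.remove(path)

def write_sketch():
    # FileWriter raises a write exception if the target files already exist,
    # so clear any leftovers from a previous (possibly failed) run first.
    cleanup(SKETCH_FILE)
    writer = FileWriter(SKETCH_FILE, shard_num=1)
    writer.add_schema({"label": {"type": "int64"}}, "sketch schema")
    writer.write_raw_data([{"label": i} for i in range(5)])
    writer.commit()
    # Remove the files again so the next test starts from a clean state.
    cleanup(SKETCH_FILE)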


@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,8 +29,24 @@ CV4_FILE_NAME = "/tmp/imagenet_append.mindrecord"
 NLP_FILE_NAME = "./aclImdb.mindrecord"
+def remove_one_file(file):
+    if os.path.exists(file):
+        os.remove(file)
+def remove_multi_files(file_name, file_num):
+    paths = ["{}{}".format(file_name, str(x).rjust(1, '0'))
+             for x in range(file_num)]
+    for x in paths:
+        remove_one_file("{}".format(x))
+        remove_one_file("{}.db".format(x))
 def test_write_read_process():
     mindrecord_file_name = "test.mindrecord"
+    remove_one_file(mindrecord_file_name)
+    remove_one_file(mindrecord_file_name + ".db")
     data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
             "data": bytes("image bytes abc", encoding='UTF-8')},
@@ -75,12 +91,15 @@ def test_write_read_process():
     assert count == 6
     reader.close()
-    os.remove("{}".format(mindrecord_file_name))
-    os.remove("{}.db".format(mindrecord_file_name))
+    remove_one_file("{}".format(mindrecord_file_name))
+    remove_one_file("{}.db".format(mindrecord_file_name))
 def test_write_read_process_with_define_index_field():
     mindrecord_file_name = "test.mindrecord"
+    remove_one_file(mindrecord_file_name)
+    remove_one_file(mindrecord_file_name + ".db")
     data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
             "data": bytes("image bytes abc", encoding='UTF-8')},
@@ -126,12 +145,13 @@ def test_write_read_process_with_define_index_field():
     assert count == 6
     reader.close()
-    os.remove("{}".format(mindrecord_file_name))
-    os.remove("{}.db".format(mindrecord_file_name))
+    remove_one_file("{}".format(mindrecord_file_name))
+    remove_one_file("{}.db".format(mindrecord_file_name))
-def test_cv_file_writer_tutorial():
+def test_cv_file_writer_tutorial(remove_file=True):
     """tutorial for cv dataset writer."""
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
     writer = FileWriter(CV_FILE_NAME, FILES_NUM)
     data = get_data("../data/mindrecord/testImageNetData/")
     cv_schema_json = {"file_name": {"type": "string"},
@@ -140,10 +160,13 @@ def test_cv_file_writer_tutorial():
     writer.add_index(["file_name", "label"])
     writer.write_raw_data(data)
     writer.commit()
+    if remove_file:
+        remove_multi_files(CV_FILE_NAME, FILES_NUM)
 def test_cv_file_append_writer():
     """tutorial for cv dataset append writer."""
+    remove_multi_files(CV3_FILE_NAME, 4)
     writer = FileWriter(CV3_FILE_NAME, 4)
     data = get_data("../data/mindrecord/testImageNetData/")
     cv_schema_json = {"file_name": {"type": "string"},
@@ -164,15 +187,12 @@ def test_cv_file_append_writer():
     assert count == 10
     reader.close()
-    paths = ["{}{}".format(CV3_FILE_NAME, str(x).rjust(1, '0'))
-             for x in range(4)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
+    remove_multi_files(CV3_FILE_NAME, 4)
 def test_cv_file_append_writer_absolute_path():
     """tutorial for cv dataset append writer."""
+    remove_multi_files(CV4_FILE_NAME, 4)
     writer = FileWriter(CV4_FILE_NAME, 4)
     data = get_data("../data/mindrecord/testImageNetData/")
     cv_schema_json = {"file_name": {"type": "string"},
@@ -193,15 +213,12 @@ def test_cv_file_append_writer_absolute_path():
     assert count == 10
     reader.close()
-    paths = ["{}{}".format(CV4_FILE_NAME, str(x).rjust(1, '0'))
-             for x in range(4)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
+    remove_multi_files(CV4_FILE_NAME, 4)
 def test_cv_file_writer_loop_and_read():
     """tutorial for cv dataset loop writer."""
+    remove_multi_files(CV2_FILE_NAME, FILES_NUM)
     writer = FileWriter(CV2_FILE_NAME, FILES_NUM)
     data = get_data("../data/mindrecord/testImageNetData/")
     cv_schema_json = {"file_name": {"type": "string"},
@@ -221,15 +238,14 @@ def test_cv_file_writer_loop_and_read():
     assert count == 10
     reader.close()
-    paths = ["{}{}".format(CV2_FILE_NAME, str(x).rjust(1, '0'))
-             for x in range(FILES_NUM)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
+    remove_multi_files(CV2_FILE_NAME, FILES_NUM)
 def test_cv_file_reader_tutorial():
     """tutorial for cv file reader."""
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
+    test_cv_file_writer_tutorial(remove_file=False)
     reader = FileReader(CV_FILE_NAME + "0")
     count = 0
     for index, x in enumerate(reader.get_next()):
@@ -239,9 +255,14 @@ def test_cv_file_reader_tutorial():
     assert count == 10
     reader.close()
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
 def test_cv_file_reader_file_list():
     """tutorial for cv file partial reader."""
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
+    test_cv_file_writer_tutorial(remove_file=False)
     reader = FileReader([CV_FILE_NAME + str(x) for x in range(FILES_NUM)])
     count = 0
     for index, x in enumerate(reader.get_next()):
@@ -250,9 +271,14 @@ def test_cv_file_reader_file_list():
         logger.info("#item{}: {}".format(index, x))
     assert count == 10
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
 def test_cv_file_reader_partial_tutorial():
     """tutorial for cv file partial reader."""
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
+    test_cv_file_writer_tutorial(remove_file=False)
     reader = FileReader(CV_FILE_NAME + "0")
     count = 0
     for index, x in enumerate(reader.get_next()):
@@ -263,9 +289,14 @@ def test_cv_file_reader_partial_tutorial():
     reader.close()
     assert count == 5
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
 def test_cv_page_reader_tutorial():
     """tutorial for cv page reader."""
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
+    test_cv_file_writer_tutorial(remove_file=False)
     reader = MindPage(CV_FILE_NAME + "0")
     fields = reader.get_category_fields()
     assert fields == ['file_name', 'label'], \
@@ -287,9 +318,14 @@ def test_cv_page_reader_tutorial():
     assert len(row1[0]) == 3
     assert row1[0]['label'] == 822
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
 def test_cv_page_reader_tutorial_by_file_name():
     """tutorial for cv page reader."""
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
+    test_cv_file_writer_tutorial(remove_file=False)
     reader = MindPage(CV_FILE_NAME + "0")
     fields = reader.get_category_fields()
     assert fields == ['file_name', 'label'], \
@@ -311,9 +347,14 @@ def test_cv_page_reader_tutorial_by_file_name():
     assert len(row1[0]) == 3
     assert row1[0]['label'] == 13
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
 def test_cv_page_reader_tutorial_new_api():
     """tutorial for cv page reader."""
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
+    test_cv_file_writer_tutorial(remove_file=False)
     reader = MindPage(CV_FILE_NAME + "0")
     fields = reader.candidate_fields
     assert fields == ['file_name', 'label'], \
@@ -334,15 +375,12 @@ def test_cv_page_reader_tutorial_new_api():
     assert len(row1[0]) == 3
     assert row1[0]['label'] == 13
-    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
-             for x in range(FILES_NUM)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
+    remove_multi_files(CV_FILE_NAME, FILES_NUM)
-def test_nlp_file_writer_tutorial():
+def test_nlp_file_writer_tutorial(remove_file=True):
     """tutorial for nlp file writer."""
+    remove_multi_files(NLP_FILE_NAME, FILES_NUM)
     writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
     data = list(get_nlp_data("../data/mindrecord/testAclImdbData/pos",
                              "../data/mindrecord/testAclImdbData/vocab.txt",
@@ -360,10 +398,14 @@ def test_nlp_file_writer_tutorial():
     writer.add_index(["id", "rating"])
     writer.write_raw_data(data)
     writer.commit()
+    if remove_file:
+        remove_multi_files(NLP_FILE_NAME, FILES_NUM)
 def test_nlp_file_reader_tutorial():
     """tutorial for nlp file reader."""
+    remove_multi_files(NLP_FILE_NAME, FILES_NUM)
+    test_nlp_file_writer_tutorial(remove_file=False)
     reader = FileReader(NLP_FILE_NAME + "0")
     count = 0
     for index, x in enumerate(reader.get_next()):
@@ -372,10 +414,14 @@ def test_nlp_file_reader_tutorial():
         logger.info("#item{}: {}".format(index, x))
     assert count == 10
     reader.close()
+    remove_multi_files(NLP_FILE_NAME, FILES_NUM)
 def test_nlp_page_reader_tutorial():
     """tutorial for nlp page reader."""
+    remove_multi_files(NLP_FILE_NAME, FILES_NUM)
+    test_nlp_file_writer_tutorial(remove_file=False)
     reader = MindPage(NLP_FILE_NAME + "0")
     fields = reader.get_category_fields()
     assert fields == ['id', 'rating'], \
@@ -397,15 +443,12 @@ def test_nlp_page_reader_tutorial():
     assert len(row1[0]) == 6
     logger.info("row1[0]: {}".format(row1[0]))
-    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
-             for x in range(FILES_NUM)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
+    remove_multi_files(NLP_FILE_NAME, FILES_NUM)
 def test_cv_file_writer_shard_num_10():
     """test file writer when shard num equals 10."""
+    remove_multi_files(CV_FILE_NAME, 10)
     writer = FileWriter(CV_FILE_NAME, 10)
     data = get_data("../data/mindrecord/testImageNetData/")
     cv_schema_json = {"file_name": {"type": "string"},
@@ -415,16 +458,13 @@ def test_cv_file_writer_shard_num_10():
     writer.write_raw_data(data)
     writer.commit()
-    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
-             for x in range(10)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
+    remove_multi_files(CV_FILE_NAME, 10)
 def test_cv_file_writer_absolute_path():
     """test cv file writer when file name is absolute path."""
     file_name = "/tmp/" + str(uuid.uuid4())
+    remove_multi_files(file_name, FILES_NUM)
     writer = FileWriter(file_name, FILES_NUM)
     data = get_data("../data/mindrecord/testImageNetData/")
     cv_schema_json = {"file_name": {"type": "string"},
@@ -434,15 +474,14 @@ def test_cv_file_writer_absolute_path():
     writer.write_raw_data(data)
     writer.commit()
-    paths = ["{}{}".format(file_name, str(x).rjust(1, '0'))
-             for x in range(FILES_NUM)]
-    for x in paths:
-        os.remove("{}".format(x))
-        os.remove("{}.db".format(x))
+    remove_multi_files(file_name, FILES_NUM)
 def test_cv_file_writer_without_data():
     """test cv file writer without data."""
+    remove_one_file(CV_FILE_NAME)
+    remove_one_file(CV_FILE_NAME + ".db")
     writer = FileWriter(CV_FILE_NAME, 1)
     cv_schema_json = {"file_name": {"type": "string"},
                       "label": {"type": "int64"}, "data": {"type": "bytes"}}
@@ -456,12 +495,15 @@ def test_cv_file_writer_without_data():
         logger.info("#item{}: {}".format(index, x))
     assert count == 0
     reader.close()
-    os.remove(CV_FILE_NAME)
-    os.remove("{}.db".format(CV_FILE_NAME))
+    remove_one_file(CV_FILE_NAME)
+    remove_one_file(CV_FILE_NAME + ".db")
 def test_cv_file_writer_no_blob():
     """test cv file writer without blob data."""
+    remove_one_file(CV_FILE_NAME)
+    remove_one_file(CV_FILE_NAME + ".db")
     writer = FileWriter(CV_FILE_NAME, 1)
     data = get_data("../data/mindrecord/testImageNetData/")
     cv_schema_json = {"file_name": {"type": "string"},
@@ -478,12 +520,15 @@ def test_cv_file_writer_no_blob():
         logger.info("#item{}: {}".format(index, x))
     assert count == 10
     reader.close()
-    os.remove(CV_FILE_NAME)
-    os.remove("{}.db".format(CV_FILE_NAME))
+    remove_one_file(CV_FILE_NAME)
+    remove_one_file(CV_FILE_NAME + ".db")
 def test_cv_file_writer_no_raw():
     """test cv file writer without raw data."""
+    remove_one_file(NLP_FILE_NAME)
+    remove_one_file(NLP_FILE_NAME + ".db")
     writer = FileWriter(NLP_FILE_NAME)
     data = list(get_nlp_data("../data/mindrecord/testAclImdbData/pos",
                              "../data/mindrecord/testAclImdbData/vocab.txt",
@@ -506,12 +551,15 @@ def test_cv_file_writer_no_raw():
         logger.info("#item{}: {}".format(index, x))
     assert count == 10
     reader.close()
-    os.remove(NLP_FILE_NAME)
-    os.remove("{}.db".format(NLP_FILE_NAME))
+    remove_one_file(NLP_FILE_NAME)
+    remove_one_file(NLP_FILE_NAME + ".db")
 def test_write_read_process_with_multi_bytes():
     mindrecord_file_name = "test.mindrecord"
+    remove_one_file(mindrecord_file_name)
+    remove_one_file(mindrecord_file_name + ".db")
     data = [{"file_name": "001.jpg", "label": 43,
             "image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
@@ -631,12 +679,15 @@ def test_write_read_process_with_multi_bytes():
     assert count == 6
     reader5.close()
-    os.remove("{}".format(mindrecord_file_name))
-    os.remove("{}.db".format(mindrecord_file_name))
+    remove_one_file(mindrecord_file_name)
+    remove_one_file(mindrecord_file_name + ".db")
 def test_write_read_process_with_multi_array():
     mindrecord_file_name = "test.mindrecord"
+    remove_one_file(mindrecord_file_name)
+    remove_one_file(mindrecord_file_name + ".db")
     data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64),
@@ -775,12 +826,15 @@ def test_write_read_process_with_multi_array():
     assert count == 6
     reader.close()
-    os.remove("{}".format(mindrecord_file_name))
-    os.remove("{}.db".format(mindrecord_file_name))
+    remove_one_file(mindrecord_file_name)
+    remove_one_file(mindrecord_file_name + ".db")
 def test_write_read_process_with_multi_bytes_and_array():
     mindrecord_file_name = "test.mindrecord"
+    remove_one_file(mindrecord_file_name)
+    remove_one_file(mindrecord_file_name + ".db")
     data = [{"file_name": "001.jpg", "label": 4,
             "image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
@@ -962,11 +1016,15 @@ def test_write_read_process_with_multi_bytes_and_array():
     assert count == 6
     reader.close()
-    os.remove("{}".format(mindrecord_file_name))
-    os.remove("{}.db".format(mindrecord_file_name))
+    remove_one_file(mindrecord_file_name)
+    remove_one_file(mindrecord_file_name + ".db")
 def test_write_read_process_without_ndarray_type():
     mindrecord_file_name = "test.mindrecord"
+    remove_one_file(mindrecord_file_name)
+    remove_one_file(mindrecord_file_name + ".db")
     # field: mask derivation type is int64, but schema type is int32
     data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9]),
             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
@@ -998,5 +1056,5 @@ def test_write_read_process_without_ndarray_type():
     assert count == 1
     reader.close()
-    os.remove("{}".format(mindrecord_file_name))
-    os.remove("{}.db".format(mindrecord_file_name))
+    remove_one_file(mindrecord_file_name)
+    remove_one_file(mindrecord_file_name + ".db")