forked from mindspore-Ecosystem/mindspore
fix md ut part05
This commit is contained in:
parent
00bc83db01
commit
7c9f6899a3
|
@ -1997,6 +1997,7 @@ def create_multi_mindrecord_files():
|
|||
os.remove("{}.db".format(filename))
|
||||
|
||||
def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
|
||||
ds.config.set_seed(1)
|
||||
datas_all = []
|
||||
index = 0
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
|
@ -2233,6 +2234,7 @@ def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
|
|||
assert origin_index != current_index
|
||||
|
||||
def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_files):
|
||||
ds.config.set_seed(1)
|
||||
datas_all = []
|
||||
datas_all_samples = []
|
||||
index = 0
|
||||
|
@ -2424,6 +2426,7 @@ def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_fi
|
|||
assert origin_index != current_index
|
||||
|
||||
def test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files):
|
||||
ds.config.set_seed(1)
|
||||
datas_all = []
|
||||
datas_all_samples = []
|
||||
index = 0
|
||||
|
|
|
@ -22,47 +22,49 @@ from mindspore.mindrecord import FileReader
|
|||
from mindspore.mindrecord import SUCCESS
|
||||
|
||||
CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
|
||||
MINDRECORD_FILE = "./cifar100.mindrecord"
|
||||
|
||||
def remove_file(x):
|
||||
if os.path.exists("{}".format(x)):
|
||||
os.remove("{}".format(x))
|
||||
if os.path.exists("{}.db".format(x)):
|
||||
os.remove("{}.db".format(x))
|
||||
if os.path.exists("{}_test".format(x)):
|
||||
os.remove("{}_test".format(x))
|
||||
if os.path.exists("{}_test.db".format(x)):
|
||||
os.remove("{}_test.db".format(x))
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_file():
|
||||
"""add/remove file"""
|
||||
def remove_file(x):
|
||||
if os.path.exists("{}".format(x)):
|
||||
os.remove("{}".format(x))
|
||||
if os.path.exists("{}.db".format(x)):
|
||||
os.remove("{}.db".format(x))
|
||||
if os.path.exists("{}_test".format(x)):
|
||||
os.remove("{}_test".format(x))
|
||||
if os.path.exists("{}_test.db".format(x)):
|
||||
os.remove("{}_test.db".format(x))
|
||||
|
||||
remove_file(MINDRECORD_FILE)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
remove_file(file_name)
|
||||
yield "yield_fixture_data"
|
||||
remove_file(MINDRECORD_FILE)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_cifar100_to_mindrecord_without_index_fields(fixture_file):
|
||||
"""test transform cifar100 dataset to mindrecord without index fields."""
|
||||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, file_name)
|
||||
ret = cifar100_transformer.transform()
|
||||
assert ret == SUCCESS, "Failed to transform from cifar100 to mindrecord"
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
assert os.path.exists(file_name)
|
||||
assert os.path.exists(file_name + "_test")
|
||||
read(file_name)
|
||||
|
||||
def test_cifar100_to_mindrecord(fixture_file):
|
||||
"""test transform cifar100 dataset to mindrecord."""
|
||||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, file_name)
|
||||
cifar100_transformer.transform(['fine_label', 'coarse_label'])
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
assert os.path.exists(file_name)
|
||||
assert os.path.exists(file_name + "_test")
|
||||
read(file_name)
|
||||
|
||||
|
||||
def read():
|
||||
def read(file_name):
|
||||
"""test file reader"""
|
||||
count = 0
|
||||
reader = FileReader(MINDRECORD_FILE)
|
||||
reader = FileReader(file_name)
|
||||
for _, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 4
|
||||
count = count + 1
|
||||
|
@ -72,7 +74,7 @@ def read():
|
|||
reader.close()
|
||||
|
||||
count = 0
|
||||
reader = FileReader(MINDRECORD_FILE + "_test")
|
||||
reader = FileReader(file_name + "_test")
|
||||
for _, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 4
|
||||
count = count + 1
|
||||
|
@ -102,16 +104,18 @@ def test_cifar100_to_mindrecord_filename_start_with_space(fixture_file):
|
|||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
|
||||
cifar100_transformer.transform()
|
||||
|
||||
def test_cifar100_to_mindrecord_filename_contain_space(fixture_file):
|
||||
def test_cifar100_to_mindrecord_filename_contain_space():
|
||||
"""
|
||||
test transform cifar10 dataset to mindrecord
|
||||
when file name contains space.
|
||||
Feature: Cifar100ToMR
|
||||
Description: test transform cifar100 dataset to mindrecord when file name contains space.
|
||||
Expectation: generate mindrecord file successfully
|
||||
"""
|
||||
filename = "./yes ok"
|
||||
filename = "./cifar100 ok"
|
||||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
|
||||
cifar100_transformer.transform()
|
||||
assert os.path.exists(filename)
|
||||
assert os.path.exists(filename + "_test")
|
||||
remove_file(filename)
|
||||
|
||||
def test_cifar100_to_mindrecord_directory(fixture_file):
|
||||
"""
|
||||
|
|
|
@ -22,75 +22,61 @@ from mindspore.mindrecord import FileReader
|
|||
from mindspore.mindrecord import SUCCESS
|
||||
|
||||
CIFAR10_DIR = "../data/mindrecord/testCifar10Data"
|
||||
MINDRECORD_FILE = "./cifar10.mindrecord"
|
||||
file_name = "./cifar10.mindrecord"
|
||||
|
||||
def remove_file(x):
|
||||
if os.path.exists("{}".format(x)):
|
||||
os.remove("{}".format(x))
|
||||
if os.path.exists("{}.db".format(x)):
|
||||
os.remove("{}.db".format(x))
|
||||
if os.path.exists("{}_test".format(x)):
|
||||
os.remove("{}_test".format(x))
|
||||
if os.path.exists("{}_test.db".format(x)):
|
||||
os.remove("{}_test.db".format(x))
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_file():
|
||||
"""add/remove file"""
|
||||
def remove_file(x):
|
||||
if os.path.exists("{}".format(x)):
|
||||
os.remove("{}".format(x))
|
||||
if os.path.exists("{}.db".format(x)):
|
||||
os.remove("{}.db".format(x))
|
||||
if os.path.exists("{}_test".format(x)):
|
||||
os.remove("{}_test".format(x))
|
||||
if os.path.exists("{}_test.db".format(x)):
|
||||
os.remove("{}_test.db".format(x))
|
||||
|
||||
remove_file(MINDRECORD_FILE)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
remove_file(file_name)
|
||||
yield "yield_fixture_data"
|
||||
remove_file(MINDRECORD_FILE)
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_space_file():
|
||||
"""add/remove file"""
|
||||
def remove_file(x):
|
||||
if os.path.exists("{}".format(x)):
|
||||
os.remove("{}".format(x))
|
||||
if os.path.exists("{}.db".format(x)):
|
||||
os.remove("{}.db".format(x))
|
||||
if os.path.exists("{}_test".format(x)):
|
||||
os.remove("{}_test".format(x))
|
||||
if os.path.exists("{}_test.db".format(x)):
|
||||
os.remove("{}_test.db".format(x))
|
||||
|
||||
x = "./yes ok"
|
||||
remove_file(x)
|
||||
yield "yield_fixture_data"
|
||||
remove_file(x)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_cifar10_to_mindrecord_without_index_fields(fixture_file):
|
||||
"""test transform cifar10 dataset to mindrecord without index fields."""
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, file_name)
|
||||
cifar10_transformer.transform()
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
assert os.path.exists(file_name)
|
||||
assert os.path.exists(file_name + "_test")
|
||||
read(file_name)
|
||||
|
||||
|
||||
|
||||
def test_cifar10_to_mindrecord(fixture_file):
|
||||
"""test transform cifar10 dataset to mindrecord."""
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, file_name)
|
||||
cifar10_transformer.transform(['label'])
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
assert os.path.exists(file_name)
|
||||
assert os.path.exists(file_name + "_test")
|
||||
read(file_name)
|
||||
|
||||
def test_cifar10_to_mindrecord_with_return(fixture_file):
|
||||
"""test transform cifar10 dataset to mindrecord."""
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, file_name)
|
||||
ret = cifar10_transformer.transform(['label'])
|
||||
assert ret == SUCCESS, "commit failed"
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + "_test")
|
||||
read()
|
||||
assert os.path.exists(file_name)
|
||||
assert os.path.exists(file_name + "_test")
|
||||
read(file_name)
|
||||
|
||||
|
||||
def read():
|
||||
def read(file_name):
|
||||
"""test file reader"""
|
||||
count = 0
|
||||
reader = FileReader(MINDRECORD_FILE)
|
||||
reader = FileReader(file_name)
|
||||
for _, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
count = count + 1
|
||||
|
@ -100,7 +86,7 @@ def read():
|
|||
reader.close()
|
||||
|
||||
count = 0
|
||||
reader = FileReader(MINDRECORD_FILE + "_test")
|
||||
reader = FileReader(file_name + "_test")
|
||||
for _, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
count = count + 1
|
||||
|
@ -130,16 +116,18 @@ def test_cifar10_to_mindrecord_filename_start_with_space(fixture_file):
|
|||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
|
||||
cifar10_transformer.transform()
|
||||
|
||||
def test_cifar10_to_mindrecord_filename_contain_space(fixture_space_file):
|
||||
def test_cifar10_to_mindrecord_filename_contain_space():
|
||||
"""
|
||||
test transform cifar10 dataset to mindrecord
|
||||
when file name contains space.
|
||||
Feature: Cifar10ToMR
|
||||
Description: test transform cifar10 dataset to mindrecord when file name contains space.
|
||||
Expectation: generate mindrecord file successfully
|
||||
"""
|
||||
filename = "./yes ok"
|
||||
filename = "./cifar10 ok"
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
|
||||
cifar10_transformer.transform()
|
||||
assert os.path.exists(filename)
|
||||
assert os.path.exists(filename + "_test")
|
||||
remove_file(filename)
|
||||
|
||||
def test_cifar10_to_mindrecord_directory(fixture_file):
|
||||
"""
|
||||
|
|
|
@ -27,7 +27,6 @@ except ModuleNotFoundError:
|
|||
pd = None
|
||||
|
||||
CSV_FILE = "../data/mindrecord/testCsv/data.csv"
|
||||
MINDRECORD_FILE = "../data/mindrecord/testCsv/csv.mindrecord"
|
||||
PARTITION_NUMBER = 4
|
||||
|
||||
@pytest.fixture(name="remove_mindrecord_file")
|
||||
|
@ -36,20 +35,21 @@ def fixture_remove():
|
|||
def remove_one_file(x):
|
||||
if os.path.exists(x):
|
||||
os.remove(x)
|
||||
def remove_file():
|
||||
x = MINDRECORD_FILE
|
||||
def remove_file(file_name):
|
||||
x = file_name
|
||||
remove_one_file(x)
|
||||
x = MINDRECORD_FILE + ".db"
|
||||
x = file_name + ".db"
|
||||
remove_one_file(x)
|
||||
for i in range(PARTITION_NUMBER):
|
||||
x = MINDRECORD_FILE + str(i)
|
||||
x = file_name + str(i)
|
||||
remove_one_file(x)
|
||||
x = MINDRECORD_FILE + str(i) + ".db"
|
||||
x = file_name + str(i) + ".db"
|
||||
remove_one_file(x)
|
||||
|
||||
remove_file()
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
remove_file(file_name)
|
||||
yield "yield_fixture_data"
|
||||
remove_file()
|
||||
remove_file(file_name)
|
||||
|
||||
def read(filename, columns, row_num):
|
||||
"""test file reade"""
|
||||
|
@ -70,26 +70,29 @@ def read(filename, columns, row_num):
|
|||
|
||||
def test_csv_to_mindrecord(remove_mindrecord_file):
|
||||
"""test transform csv to mindrecord."""
|
||||
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, partition_number=PARTITION_NUMBER)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
csv_trans = CsvToMR(CSV_FILE, file_name, partition_number=PARTITION_NUMBER)
|
||||
csv_trans.transform()
|
||||
for i in range(PARTITION_NUMBER):
|
||||
assert os.path.exists(MINDRECORD_FILE + str(i))
|
||||
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
|
||||
read(MINDRECORD_FILE + "0", ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5)
|
||||
assert os.path.exists(file_name + str(i))
|
||||
assert os.path.exists(file_name + str(i) + ".db")
|
||||
read(file_name + "0", ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5)
|
||||
|
||||
def test_csv_to_mindrecord_with_columns(remove_mindrecord_file):
|
||||
"""test transform csv to mindrecord."""
|
||||
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, columns_list=['Age', 'Sales'], partition_number=PARTITION_NUMBER)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
csv_trans = CsvToMR(CSV_FILE, file_name, columns_list=['Age', 'Sales'], partition_number=PARTITION_NUMBER)
|
||||
csv_trans.transform()
|
||||
for i in range(PARTITION_NUMBER):
|
||||
assert os.path.exists(MINDRECORD_FILE + str(i))
|
||||
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
|
||||
read(MINDRECORD_FILE + "0", ["Age", "Sales"], 5)
|
||||
assert os.path.exists(file_name + str(i))
|
||||
assert os.path.exists(file_name + str(i) + ".db")
|
||||
read(file_name + "0", ["Age", "Sales"], 5)
|
||||
|
||||
def test_csv_to_mindrecord_with_no_exist_columns(remove_mindrecord_file):
|
||||
"""test transform csv to mindrecord."""
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
with pytest.raises(Exception, match="The parameter columns_list is illegal, column ssales does not exist."):
|
||||
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, columns_list=['Age', 'ssales'],
|
||||
csv_trans = CsvToMR(CSV_FILE, file_name, columns_list=['Age', 'ssales'],
|
||||
partition_number=PARTITION_NUMBER)
|
||||
csv_trans.transform()
|
||||
|
||||
|
@ -97,8 +100,9 @@ def test_csv_partition_number_with_illegal_columns(remove_mindrecord_file):
|
|||
"""
|
||||
test transform csv to mindrecord
|
||||
"""
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
with pytest.raises(Exception, match="The parameter columns_list must be list of str."):
|
||||
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, ["Sales", 2])
|
||||
csv_trans = CsvToMR(CSV_FILE, file_name, ["Sales", 2])
|
||||
csv_trans.transform()
|
||||
|
||||
|
||||
|
@ -107,19 +111,21 @@ def test_csv_to_mindrecord_default_partition_number(remove_mindrecord_file):
|
|||
test transform csv to mindrecord
|
||||
when partition number is default.
|
||||
"""
|
||||
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
csv_trans = CsvToMR(CSV_FILE, file_name)
|
||||
csv_trans.transform()
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + ".db")
|
||||
read(MINDRECORD_FILE, ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5)
|
||||
assert os.path.exists(file_name)
|
||||
assert os.path.exists(file_name + ".db")
|
||||
read(file_name, ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5)
|
||||
|
||||
def test_csv_partition_number_0(remove_mindrecord_file):
|
||||
"""
|
||||
test transform csv to mindrecord
|
||||
when partition number is 0.
|
||||
"""
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
with pytest.raises(Exception, match="Invalid parameter value"):
|
||||
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, None, 0)
|
||||
csv_trans = CsvToMR(CSV_FILE, file_name, None, 0)
|
||||
csv_trans.transform()
|
||||
|
||||
def test_csv_to_mindrecord_partition_number_none(remove_mindrecord_file):
|
||||
|
@ -127,9 +133,10 @@ def test_csv_to_mindrecord_partition_number_none(remove_mindrecord_file):
|
|||
test transform csv to mindrecord
|
||||
when partition number is none.
|
||||
"""
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
with pytest.raises(Exception,
|
||||
match="The parameter partition_number must be int"):
|
||||
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, None, None)
|
||||
csv_trans = CsvToMR(CSV_FILE, file_name, None, None)
|
||||
csv_trans.transform()
|
||||
|
||||
def test_csv_to_mindrecord_illegal_filename(remove_mindrecord_file):
|
||||
|
|
|
@ -22,7 +22,6 @@ from mindspore.mindrecord import ImageNetToMR
|
|||
|
||||
IMAGENET_MAP_FILE = "../data/mindrecord/testImageNetDataWhole/labels_map.txt"
|
||||
IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images"
|
||||
MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord"
|
||||
PARTITION_NUMBER = 4
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -31,20 +30,21 @@ def fixture_file():
|
|||
def remove_one_file(x):
|
||||
if os.path.exists(x):
|
||||
os.remove(x)
|
||||
def remove_file():
|
||||
x = MINDRECORD_FILE
|
||||
def remove_file(file_name):
|
||||
x = file_name
|
||||
remove_one_file(x)
|
||||
x = MINDRECORD_FILE + ".db"
|
||||
x = file_name + ".db"
|
||||
remove_one_file(x)
|
||||
for i in range(PARTITION_NUMBER):
|
||||
x = MINDRECORD_FILE + str(i)
|
||||
x = file_name + str(i)
|
||||
remove_one_file(x)
|
||||
x = MINDRECORD_FILE + str(i) + ".db"
|
||||
x = file_name + str(i) + ".db"
|
||||
remove_one_file(x)
|
||||
|
||||
remove_file()
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
remove_file(file_name)
|
||||
yield "yield_fixture_data"
|
||||
remove_file()
|
||||
remove_file(file_name)
|
||||
|
||||
def read(filename):
|
||||
"""test file reade"""
|
||||
|
@ -60,35 +60,38 @@ def read(filename):
|
|||
|
||||
def test_imagenet_to_mindrecord(fixture_file):
|
||||
"""test transform imagenet dataset to mindrecord."""
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR,
|
||||
MINDRECORD_FILE, PARTITION_NUMBER)
|
||||
file_name, PARTITION_NUMBER)
|
||||
imagenet_transformer.transform()
|
||||
for i in range(PARTITION_NUMBER):
|
||||
assert os.path.exists(MINDRECORD_FILE + str(i))
|
||||
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
|
||||
read(MINDRECORD_FILE + "0")
|
||||
assert os.path.exists(file_name + str(i))
|
||||
assert os.path.exists(file_name + str(i) + ".db")
|
||||
read(file_name + "0")
|
||||
|
||||
def test_imagenet_to_mindrecord_default_partition_number(fixture_file):
|
||||
"""
|
||||
test transform imagenet dataset to mindrecord
|
||||
when partition number is default.
|
||||
"""
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR,
|
||||
MINDRECORD_FILE)
|
||||
file_name)
|
||||
imagenet_transformer.transform()
|
||||
assert os.path.exists(MINDRECORD_FILE)
|
||||
assert os.path.exists(MINDRECORD_FILE + ".db")
|
||||
read(MINDRECORD_FILE)
|
||||
assert os.path.exists(file_name)
|
||||
assert os.path.exists(file_name + ".db")
|
||||
read(file_name)
|
||||
|
||||
def test_imagenet_to_mindrecord_partition_number_0(fixture_file):
|
||||
"""
|
||||
test transform imagenet dataset to mindrecord
|
||||
when partition number is 0.
|
||||
"""
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
with pytest.raises(Exception, match="Invalid parameter value"):
|
||||
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE,
|
||||
IMAGENET_IMAGE_DIR,
|
||||
MINDRECORD_FILE, 0)
|
||||
file_name, 0)
|
||||
imagenet_transformer.transform()
|
||||
|
||||
def test_imagenet_to_mindrecord_partition_number_none(fixture_file):
|
||||
|
@ -96,11 +99,12 @@ def test_imagenet_to_mindrecord_partition_number_none(fixture_file):
|
|||
test transform imagenet dataset to mindrecord
|
||||
when partition number is none.
|
||||
"""
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
with pytest.raises(Exception,
|
||||
match="The parameter partition_number must be int"):
|
||||
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE,
|
||||
IMAGENET_IMAGE_DIR,
|
||||
MINDRECORD_FILE, None)
|
||||
file_name, None)
|
||||
imagenet_transformer.transform()
|
||||
|
||||
def test_imagenet_to_mindrecord_illegal_filename(fixture_file):
|
||||
|
@ -108,7 +112,7 @@ def test_imagenet_to_mindrecord_illegal_filename(fixture_file):
|
|||
test transform imagenet dataset to mindrecord
|
||||
when file name contains illegal character.
|
||||
"""
|
||||
filename = "not_*ok"
|
||||
filename = "imagenet_not_*ok"
|
||||
with pytest.raises(Exception, match="File name should not contains"):
|
||||
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE,
|
||||
IMAGENET_IMAGE_DIR, filename,
|
||||
|
@ -120,7 +124,7 @@ def test_imagenet_to_mindrecord_illegal_1_filename(fixture_file):
|
|||
test transform imagenet dataset to mindrecord
|
||||
when file name end with '/'.
|
||||
"""
|
||||
filename = "test/path/"
|
||||
filename = "imagenet/path/"
|
||||
with pytest.raises(Exception, match="File path can not end with '/'"):
|
||||
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE,
|
||||
IMAGENET_IMAGE_DIR, filename,
|
||||
|
|
|
@ -23,7 +23,6 @@ from mindspore.mindrecord import FileReader
|
|||
from mindspore.mindrecord import MnistToMR
|
||||
|
||||
MNIST_DIR = "../data/mindrecord/testMnistData"
|
||||
FILE_NAME = "mnist"
|
||||
PARTITION_NUM = 4
|
||||
IMAGE_SIZE = 28
|
||||
NUM_CHANNELS = 1
|
||||
|
@ -34,32 +33,35 @@ def fixture_file():
|
|||
def remove_one_file(x):
|
||||
if os.path.exists(x):
|
||||
os.remove(x)
|
||||
def remove_file():
|
||||
x = "mnist_train.mindrecord"
|
||||
remove_one_file(x)
|
||||
x = "mnist_train.mindrecord.db"
|
||||
remove_one_file(x)
|
||||
x = "mnist_test.mindrecord"
|
||||
remove_one_file(x)
|
||||
x = "mnist_test.mindrecord.db"
|
||||
remove_one_file(x)
|
||||
def remove_file(file_name):
|
||||
remove_one_file(file_name + '_train.mindrecord')
|
||||
remove_one_file(file_name + '_train.mindrecord.db')
|
||||
remove_one_file(file_name + '_test.mindrecord')
|
||||
remove_one_file(file_name + '_test.mindrecord.db')
|
||||
for i in range(PARTITION_NUM):
|
||||
x = "mnist_train.mindrecord" + str(i)
|
||||
x = file_name + "_train.mindrecord" + str(i)
|
||||
remove_one_file(x)
|
||||
x = "mnist_train.mindrecord" + str(i) + ".db"
|
||||
x = file_name + "_train.mindrecord" + str(i) + ".db"
|
||||
remove_one_file(x)
|
||||
x = "mnist_test.mindrecord" + str(i)
|
||||
x = file_name + "_test.mindrecord" + str(i)
|
||||
remove_one_file(x)
|
||||
x = "mnist_test.mindrecord" + str(i) + ".db"
|
||||
x = file_name + "_test.mindrecord" + str(i) + ".db"
|
||||
remove_one_file(x)
|
||||
|
||||
remove_file()
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
remove_file(file_name)
|
||||
yield "yield_fixture_data"
|
||||
remove_file()
|
||||
remove_file(file_name)
|
||||
|
||||
def read(train_name, test_name):
|
||||
def read(file_name, partition=False):
|
||||
"""test file reader"""
|
||||
count = 0
|
||||
if partition:
|
||||
train_name = file_name + "_train.mindrecord0"
|
||||
test_name = file_name + "_test.mindrecord0"
|
||||
else:
|
||||
train_name = file_name + "_train.mindrecord"
|
||||
test_name = file_name + "_test.mindrecord"
|
||||
reader = FileReader(train_name)
|
||||
for _, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 2
|
||||
|
@ -82,21 +84,24 @@ def read(train_name, test_name):
|
|||
|
||||
def test_mnist_to_mindrecord(fixture_file):
|
||||
"""test transform mnist dataset to mindrecord."""
|
||||
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mnist_transformer = MnistToMR(MNIST_DIR, file_name)
|
||||
mnist_transformer.transform()
|
||||
assert os.path.exists("mnist_train.mindrecord")
|
||||
assert os.path.exists("mnist_test.mindrecord")
|
||||
assert os.path.exists(file_name + "_train.mindrecord")
|
||||
assert os.path.exists(file_name + "_test.mindrecord")
|
||||
|
||||
read("mnist_train.mindrecord", "mnist_test.mindrecord")
|
||||
read(file_name)
|
||||
|
||||
def test_mnist_to_mindrecord_compare_data(fixture_file):
|
||||
"""test transform mnist dataset to mindrecord and compare data."""
|
||||
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mnist_transformer = MnistToMR(MNIST_DIR, file_name)
|
||||
mnist_transformer.transform()
|
||||
assert os.path.exists("mnist_train.mindrecord")
|
||||
assert os.path.exists("mnist_test.mindrecord")
|
||||
assert os.path.exists(file_name + "_train.mindrecord")
|
||||
assert os.path.exists(file_name + "_test.mindrecord")
|
||||
|
||||
train_name, test_name = "mnist_train.mindrecord", "mnist_test.mindrecord"
|
||||
train_name = file_name + "_train.mindrecord"
|
||||
test_name = file_name + "_test.mindrecord"
|
||||
|
||||
def _extract_images(filename, num_images):
|
||||
"""Extract the images into a 4D tensor [image index, y, x, channels]."""
|
||||
|
@ -147,7 +152,8 @@ def test_mnist_to_mindrecord_compare_data(fixture_file):
|
|||
|
||||
def test_mnist_to_mindrecord_multi_partition(fixture_file):
|
||||
"""test transform mnist dataset to multiple mindrecord files."""
|
||||
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
|
||||
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mnist_transformer = MnistToMR(MNIST_DIR, file_name, PARTITION_NUM)
|
||||
mnist_transformer.transform()
|
||||
|
||||
read("mnist_train.mindrecord0", "mnist_test.mindrecord0")
|
||||
read(file_name, partition=True)
|
||||
|
|
|
@ -32,8 +32,6 @@ except ModuleNotFoundError:
|
|||
tf = None
|
||||
|
||||
TFRECORD_DATA_DIR = "../data/mindrecord/testTFRecordData"
|
||||
TFRECORD_FILE_NAME = "test.tfrecord"
|
||||
MINDRECORD_FILE_NAME = "test.mindrecord"
|
||||
PARTITION_NUM = 1
|
||||
|
||||
def cast_name(key):
|
||||
|
@ -82,7 +80,7 @@ def verify_data(transformer, reader):
|
|||
assert value == mr_item[cast_name(key)]
|
||||
assert count == 10
|
||||
|
||||
def generate_tfrecord():
|
||||
def generate_tfrecord(tfrecord_file_name):
|
||||
def create_int_feature(values):
|
||||
if isinstance(values, list):
|
||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) # values: [int, int, int]
|
||||
|
@ -105,7 +103,7 @@ def generate_tfrecord():
|
|||
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))
|
||||
return feature
|
||||
|
||||
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
example_count = 0
|
||||
for i in range(10):
|
||||
|
@ -131,7 +129,7 @@ def generate_tfrecord():
|
|||
writer.close()
|
||||
logger.info("Write {} rows in tfrecord.".format(example_count))
|
||||
|
||||
def generate_tfrecord_with_special_field_name():
|
||||
def generate_tfrecord_with_special_field_name(tfrecord_file_name):
|
||||
def create_int_feature(values):
|
||||
if isinstance(values, list):
|
||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) # values: [int, int, int]
|
||||
|
@ -147,7 +145,7 @@ def generate_tfrecord_with_special_field_name():
|
|||
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))
|
||||
return feature
|
||||
|
||||
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
example_count = 0
|
||||
for i in range(10):
|
||||
|
@ -172,8 +170,12 @@ def test_tfrecord_to_mindrecord():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord(tfrecord_file_name)
|
||||
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
|
@ -183,25 +185,25 @@ def test_tfrecord_to_mindrecord():
|
|||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME)
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
|
||||
assert os.path.exists(mindrecord_file_name)
|
||||
assert os.path.exists(mindrecord_file_name + ".db")
|
||||
|
||||
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
|
||||
fr_mindrecord = FileReader(mindrecord_file_name)
|
||||
verify_data(tfrecord_transformer, fr_mindrecord)
|
||||
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
os.remove(mindrecord_file_name)
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
def test_tfrecord_to_mindrecord_scalar_with_1():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
|
@ -211,8 +213,11 @@ def test_tfrecord_to_mindrecord_scalar_with_1():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord(tfrecord_file_name)
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
|
@ -222,25 +227,25 @@ def test_tfrecord_to_mindrecord_scalar_with_1():
|
|||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME)
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
|
||||
assert os.path.exists(mindrecord_file_name)
|
||||
assert os.path.exists(mindrecord_file_name + ".db")
|
||||
|
||||
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
|
||||
fr_mindrecord = FileReader(mindrecord_file_name)
|
||||
verify_data(tfrecord_transformer, fr_mindrecord)
|
||||
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
os.remove(mindrecord_file_name)
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
|
@ -250,8 +255,11 @@ def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord(tfrecord_file_name)
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
|
@ -261,22 +269,22 @@ def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception():
|
|||
"float_list": tf.io.FixedLenFeature([2], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
def test_tfrecord_to_mindrecord_list_with_diff_type_exception():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
|
@ -286,8 +294,11 @@ def test_tfrecord_to_mindrecord_list_with_diff_type_exception():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord(tfrecord_file_name)
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
|
@ -297,22 +308,22 @@ def test_tfrecord_to_mindrecord_list_with_diff_type_exception():
|
|||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
def test_tfrecord_to_mindrecord_list_without_bytes_type():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
|
@ -322,8 +333,11 @@ def test_tfrecord_to_mindrecord_list_without_bytes_type():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord(tfrecord_file_name)
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
|
@ -333,25 +347,25 @@ def test_tfrecord_to_mindrecord_list_without_bytes_type():
|
|||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict)
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict)
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME)
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
|
||||
assert os.path.exists(mindrecord_file_name)
|
||||
assert os.path.exists(mindrecord_file_name + ".db")
|
||||
|
||||
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
|
||||
fr_mindrecord = FileReader(mindrecord_file_name)
|
||||
verify_data(tfrecord_transformer, fr_mindrecord)
|
||||
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
os.remove(mindrecord_file_name)
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
def test_tfrecord_to_mindrecord_scalar_with_2_exception():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
|
@ -361,8 +375,11 @@ def test_tfrecord_to_mindrecord_scalar_with_2_exception():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord(tfrecord_file_name)
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
|
@ -372,22 +389,22 @@ def test_tfrecord_to_mindrecord_scalar_with_2_exception():
|
|||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict, ["image_bytes"])
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
def test_tfrecord_to_mindrecord_scalar_string_with_1_exception():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
|
@ -397,8 +414,11 @@ def test_tfrecord_to_mindrecord_scalar_string_with_1_exception():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord(tfrecord_file_name)
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([1], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
|
@ -408,22 +428,22 @@ def test_tfrecord_to_mindrecord_scalar_string_with_1_exception():
|
|||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
|
@ -433,8 +453,11 @@ def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord(tfrecord_file_name)
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([10], tf.string),
|
||||
|
@ -444,22 +467,22 @@ def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
|
|||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
|
@ -469,8 +492,11 @@ def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord(tfrecord_file_name)
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
|
@ -480,22 +506,22 @@ def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type():
|
|||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["int64_list"])
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict, ["int64_list"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
|
@ -505,8 +531,11 @@ def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord(tfrecord_file_name)
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
|
@ -516,22 +545,22 @@ def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list():
|
|||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, "")
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict, "")
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
def test_tfrecord_to_mindrecord_with_special_field_name():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
|
@ -541,29 +570,32 @@ def test_tfrecord_to_mindrecord_with_special_field_name():
|
|||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord_with_special_field_name()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
|
||||
mindrecord_file_name = file_name_ + '.mindrecord'
|
||||
tfrecord_file_name = file_name_ + '.tfrecord'
|
||||
generate_tfrecord_with_special_field_name(tfrecord_file_name)
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
||||
feature_dict = {"image/class/label": tf.io.FixedLenFeature([], tf.int64),
|
||||
"image/encoded": tf.io.FixedLenFeature([], tf.string),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
if os.path.exists(mindrecord_file_name):
|
||||
os.remove(mindrecord_file_name)
|
||||
if os.path.exists(mindrecord_file_name + ".db"):
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image/encoded"])
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
|
||||
mindrecord_file_name, feature_dict, ["image/encoded"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME)
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
|
||||
assert os.path.exists(mindrecord_file_name)
|
||||
assert os.path.exists(mindrecord_file_name + ".db")
|
||||
|
||||
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
|
||||
fr_mindrecord = FileReader(mindrecord_file_name)
|
||||
verify_data(tfrecord_transformer, fr_mindrecord)
|
||||
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
os.remove(mindrecord_file_name)
|
||||
os.remove(mindrecord_file_name + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
|
||||
|
|
Loading…
Reference in New Issue