fix md ut part05

This commit is contained in:
liyong 2021-10-21 19:34:23 +08:00
parent 00bc83db01
commit 7c9f6899a3
7 changed files with 346 additions and 302 deletions

View File

@ -1997,6 +1997,7 @@ def create_multi_mindrecord_files():
os.remove("{}.db".format(filename))
def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
ds.config.set_seed(1)
datas_all = []
index = 0
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
@ -2233,6 +2234,7 @@ def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
assert origin_index != current_index
def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_files):
ds.config.set_seed(1)
datas_all = []
datas_all_samples = []
index = 0
@ -2424,6 +2426,7 @@ def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_fi
assert origin_index != current_index
def test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files):
ds.config.set_seed(1)
datas_all = []
datas_all_samples = []
index = 0

View File

@ -22,47 +22,49 @@ from mindspore.mindrecord import FileReader
from mindspore.mindrecord import SUCCESS
CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
MINDRECORD_FILE = "./cifar100.mindrecord"
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
remove_file(MINDRECORD_FILE)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
remove_file(file_name)
yield "yield_fixture_data"
remove_file(MINDRECORD_FILE)
remove_file(file_name)
def test_cifar100_to_mindrecord_without_index_fields(fixture_file):
"""test transform cifar100 dataset to mindrecord without index fields."""
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, file_name)
ret = cifar100_transformer.transform()
assert ret == SUCCESS, "Failed to transform from cifar100 to mindrecord"
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
assert os.path.exists(file_name)
assert os.path.exists(file_name + "_test")
read(file_name)
def test_cifar100_to_mindrecord(fixture_file):
"""test transform cifar100 dataset to mindrecord."""
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, file_name)
cifar100_transformer.transform(['fine_label', 'coarse_label'])
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
assert os.path.exists(file_name)
assert os.path.exists(file_name + "_test")
read(file_name)
def read():
def read(file_name):
"""test file reader"""
count = 0
reader = FileReader(MINDRECORD_FILE)
reader = FileReader(file_name)
for _, x in enumerate(reader.get_next()):
assert len(x) == 4
count = count + 1
@ -72,7 +74,7 @@ def read():
reader.close()
count = 0
reader = FileReader(MINDRECORD_FILE + "_test")
reader = FileReader(file_name + "_test")
for _, x in enumerate(reader.get_next()):
assert len(x) == 4
count = count + 1
@ -102,16 +104,18 @@ def test_cifar100_to_mindrecord_filename_start_with_space(fixture_file):
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
cifar100_transformer.transform()
def test_cifar100_to_mindrecord_filename_contain_space(fixture_file):
def test_cifar100_to_mindrecord_filename_contain_space():
"""
test transform cifar10 dataset to mindrecord
when file name contains space.
Feature: Cifar100ToMR
Description: test transform cifar100 dataset to mindrecord when file name contains space.
Expectation: generate mindrecord file successfully
"""
filename = "./yes ok"
filename = "./cifar100 ok"
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
cifar100_transformer.transform()
assert os.path.exists(filename)
assert os.path.exists(filename + "_test")
remove_file(filename)
def test_cifar100_to_mindrecord_directory(fixture_file):
"""

View File

@ -22,75 +22,61 @@ from mindspore.mindrecord import FileReader
from mindspore.mindrecord import SUCCESS
CIFAR10_DIR = "../data/mindrecord/testCifar10Data"
MINDRECORD_FILE = "./cifar10.mindrecord"
file_name = "./cifar10.mindrecord"
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
remove_file(MINDRECORD_FILE)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
remove_file(file_name)
yield "yield_fixture_data"
remove_file(MINDRECORD_FILE)
@pytest.fixture
def fixture_space_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
x = "./yes ok"
remove_file(x)
yield "yield_fixture_data"
remove_file(x)
remove_file(file_name)
def test_cifar10_to_mindrecord_without_index_fields(fixture_file):
"""test transform cifar10 dataset to mindrecord without index fields."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, file_name)
cifar10_transformer.transform()
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
assert os.path.exists(file_name)
assert os.path.exists(file_name + "_test")
read(file_name)
def test_cifar10_to_mindrecord(fixture_file):
"""test transform cifar10 dataset to mindrecord."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, file_name)
cifar10_transformer.transform(['label'])
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
assert os.path.exists(file_name)
assert os.path.exists(file_name + "_test")
read(file_name)
def test_cifar10_to_mindrecord_with_return(fixture_file):
"""test transform cifar10 dataset to mindrecord."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, file_name)
ret = cifar10_transformer.transform(['label'])
assert ret == SUCCESS, "commit failed"
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
assert os.path.exists(file_name)
assert os.path.exists(file_name + "_test")
read(file_name)
def read():
def read(file_name):
"""test file reader"""
count = 0
reader = FileReader(MINDRECORD_FILE)
reader = FileReader(file_name)
for _, x in enumerate(reader.get_next()):
assert len(x) == 3
count = count + 1
@ -100,7 +86,7 @@ def read():
reader.close()
count = 0
reader = FileReader(MINDRECORD_FILE + "_test")
reader = FileReader(file_name + "_test")
for _, x in enumerate(reader.get_next()):
assert len(x) == 3
count = count + 1
@ -130,16 +116,18 @@ def test_cifar10_to_mindrecord_filename_start_with_space(fixture_file):
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
cifar10_transformer.transform()
def test_cifar10_to_mindrecord_filename_contain_space(fixture_space_file):
def test_cifar10_to_mindrecord_filename_contain_space():
"""
test transform cifar10 dataset to mindrecord
when file name contains space.
Feature: Cifar10ToMR
Description: test transform cifar10 dataset to mindrecord when file name contains space.
Expectation: generate mindrecord file successfully
"""
filename = "./yes ok"
filename = "./cifar10 ok"
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
cifar10_transformer.transform()
assert os.path.exists(filename)
assert os.path.exists(filename + "_test")
remove_file(filename)
def test_cifar10_to_mindrecord_directory(fixture_file):
"""

View File

@ -27,7 +27,6 @@ except ModuleNotFoundError:
pd = None
CSV_FILE = "../data/mindrecord/testCsv/data.csv"
MINDRECORD_FILE = "../data/mindrecord/testCsv/csv.mindrecord"
PARTITION_NUMBER = 4
@pytest.fixture(name="remove_mindrecord_file")
@ -36,20 +35,21 @@ def fixture_remove():
def remove_one_file(x):
if os.path.exists(x):
os.remove(x)
def remove_file():
x = MINDRECORD_FILE
def remove_file(file_name):
x = file_name
remove_one_file(x)
x = MINDRECORD_FILE + ".db"
x = file_name + ".db"
remove_one_file(x)
for i in range(PARTITION_NUMBER):
x = MINDRECORD_FILE + str(i)
x = file_name + str(i)
remove_one_file(x)
x = MINDRECORD_FILE + str(i) + ".db"
x = file_name + str(i) + ".db"
remove_one_file(x)
remove_file()
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
remove_file(file_name)
yield "yield_fixture_data"
remove_file()
remove_file(file_name)
def read(filename, columns, row_num):
"""test file reade"""
@ -70,26 +70,29 @@ def read(filename, columns, row_num):
def test_csv_to_mindrecord(remove_mindrecord_file):
"""test transform csv to mindrecord."""
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, partition_number=PARTITION_NUMBER)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
csv_trans = CsvToMR(CSV_FILE, file_name, partition_number=PARTITION_NUMBER)
csv_trans.transform()
for i in range(PARTITION_NUMBER):
assert os.path.exists(MINDRECORD_FILE + str(i))
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
read(MINDRECORD_FILE + "0", ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5)
assert os.path.exists(file_name + str(i))
assert os.path.exists(file_name + str(i) + ".db")
read(file_name + "0", ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5)
def test_csv_to_mindrecord_with_columns(remove_mindrecord_file):
"""test transform csv to mindrecord."""
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, columns_list=['Age', 'Sales'], partition_number=PARTITION_NUMBER)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
csv_trans = CsvToMR(CSV_FILE, file_name, columns_list=['Age', 'Sales'], partition_number=PARTITION_NUMBER)
csv_trans.transform()
for i in range(PARTITION_NUMBER):
assert os.path.exists(MINDRECORD_FILE + str(i))
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
read(MINDRECORD_FILE + "0", ["Age", "Sales"], 5)
assert os.path.exists(file_name + str(i))
assert os.path.exists(file_name + str(i) + ".db")
read(file_name + "0", ["Age", "Sales"], 5)
def test_csv_to_mindrecord_with_no_exist_columns(remove_mindrecord_file):
"""test transform csv to mindrecord."""
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
with pytest.raises(Exception, match="The parameter columns_list is illegal, column ssales does not exist."):
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, columns_list=['Age', 'ssales'],
csv_trans = CsvToMR(CSV_FILE, file_name, columns_list=['Age', 'ssales'],
partition_number=PARTITION_NUMBER)
csv_trans.transform()
@ -97,8 +100,9 @@ def test_csv_partition_number_with_illegal_columns(remove_mindrecord_file):
"""
test transform csv to mindrecord
"""
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
with pytest.raises(Exception, match="The parameter columns_list must be list of str."):
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, ["Sales", 2])
csv_trans = CsvToMR(CSV_FILE, file_name, ["Sales", 2])
csv_trans.transform()
@ -107,19 +111,21 @@ def test_csv_to_mindrecord_default_partition_number(remove_mindrecord_file):
test transform csv to mindrecord
when partition number is default.
"""
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
csv_trans = CsvToMR(CSV_FILE, file_name)
csv_trans.transform()
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + ".db")
read(MINDRECORD_FILE, ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5)
assert os.path.exists(file_name)
assert os.path.exists(file_name + ".db")
read(file_name, ["Age", "EmployNumber", "Name", "Sales", "Over18"], 5)
def test_csv_partition_number_0(remove_mindrecord_file):
"""
test transform csv to mindrecord
when partition number is 0.
"""
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
with pytest.raises(Exception, match="Invalid parameter value"):
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, None, 0)
csv_trans = CsvToMR(CSV_FILE, file_name, None, 0)
csv_trans.transform()
def test_csv_to_mindrecord_partition_number_none(remove_mindrecord_file):
@ -127,9 +133,10 @@ def test_csv_to_mindrecord_partition_number_none(remove_mindrecord_file):
test transform csv to mindrecord
when partition number is none.
"""
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
with pytest.raises(Exception,
match="The parameter partition_number must be int"):
csv_trans = CsvToMR(CSV_FILE, MINDRECORD_FILE, None, None)
csv_trans = CsvToMR(CSV_FILE, file_name, None, None)
csv_trans.transform()
def test_csv_to_mindrecord_illegal_filename(remove_mindrecord_file):

View File

@ -22,7 +22,6 @@ from mindspore.mindrecord import ImageNetToMR
IMAGENET_MAP_FILE = "../data/mindrecord/testImageNetDataWhole/labels_map.txt"
IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images"
MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord"
PARTITION_NUMBER = 4
@pytest.fixture
@ -31,20 +30,21 @@ def fixture_file():
def remove_one_file(x):
if os.path.exists(x):
os.remove(x)
def remove_file():
x = MINDRECORD_FILE
def remove_file(file_name):
x = file_name
remove_one_file(x)
x = MINDRECORD_FILE + ".db"
x = file_name + ".db"
remove_one_file(x)
for i in range(PARTITION_NUMBER):
x = MINDRECORD_FILE + str(i)
x = file_name + str(i)
remove_one_file(x)
x = MINDRECORD_FILE + str(i) + ".db"
x = file_name + str(i) + ".db"
remove_one_file(x)
remove_file()
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
remove_file(file_name)
yield "yield_fixture_data"
remove_file()
remove_file(file_name)
def read(filename):
"""test file reade"""
@ -60,35 +60,38 @@ def read(filename):
def test_imagenet_to_mindrecord(fixture_file):
"""test transform imagenet dataset to mindrecord."""
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR,
MINDRECORD_FILE, PARTITION_NUMBER)
file_name, PARTITION_NUMBER)
imagenet_transformer.transform()
for i in range(PARTITION_NUMBER):
assert os.path.exists(MINDRECORD_FILE + str(i))
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
read(MINDRECORD_FILE + "0")
assert os.path.exists(file_name + str(i))
assert os.path.exists(file_name + str(i) + ".db")
read(file_name + "0")
def test_imagenet_to_mindrecord_default_partition_number(fixture_file):
"""
test transform imagenet dataset to mindrecord
when partition number is default.
"""
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR,
MINDRECORD_FILE)
file_name)
imagenet_transformer.transform()
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + ".db")
read(MINDRECORD_FILE)
assert os.path.exists(file_name)
assert os.path.exists(file_name + ".db")
read(file_name)
def test_imagenet_to_mindrecord_partition_number_0(fixture_file):
"""
test transform imagenet dataset to mindrecord
when partition number is 0.
"""
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
with pytest.raises(Exception, match="Invalid parameter value"):
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE,
IMAGENET_IMAGE_DIR,
MINDRECORD_FILE, 0)
file_name, 0)
imagenet_transformer.transform()
def test_imagenet_to_mindrecord_partition_number_none(fixture_file):
@ -96,11 +99,12 @@ def test_imagenet_to_mindrecord_partition_number_none(fixture_file):
test transform imagenet dataset to mindrecord
when partition number is none.
"""
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
with pytest.raises(Exception,
match="The parameter partition_number must be int"):
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE,
IMAGENET_IMAGE_DIR,
MINDRECORD_FILE, None)
file_name, None)
imagenet_transformer.transform()
def test_imagenet_to_mindrecord_illegal_filename(fixture_file):
@ -108,7 +112,7 @@ def test_imagenet_to_mindrecord_illegal_filename(fixture_file):
test transform imagenet dataset to mindrecord
when file name contains illegal character.
"""
filename = "not_*ok"
filename = "imagenet_not_*ok"
with pytest.raises(Exception, match="File name should not contains"):
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE,
IMAGENET_IMAGE_DIR, filename,
@ -120,7 +124,7 @@ def test_imagenet_to_mindrecord_illegal_1_filename(fixture_file):
test transform imagenet dataset to mindrecord
when file name end with '/'.
"""
filename = "test/path/"
filename = "imagenet/path/"
with pytest.raises(Exception, match="File path can not end with '/'"):
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE,
IMAGENET_IMAGE_DIR, filename,

View File

@ -23,7 +23,6 @@ from mindspore.mindrecord import FileReader
from mindspore.mindrecord import MnistToMR
MNIST_DIR = "../data/mindrecord/testMnistData"
FILE_NAME = "mnist"
PARTITION_NUM = 4
IMAGE_SIZE = 28
NUM_CHANNELS = 1
@ -34,32 +33,35 @@ def fixture_file():
def remove_one_file(x):
if os.path.exists(x):
os.remove(x)
def remove_file():
x = "mnist_train.mindrecord"
remove_one_file(x)
x = "mnist_train.mindrecord.db"
remove_one_file(x)
x = "mnist_test.mindrecord"
remove_one_file(x)
x = "mnist_test.mindrecord.db"
remove_one_file(x)
def remove_file(file_name):
remove_one_file(file_name + '_train.mindrecord')
remove_one_file(file_name + '_train.mindrecord.db')
remove_one_file(file_name + '_test.mindrecord')
remove_one_file(file_name + '_test.mindrecord.db')
for i in range(PARTITION_NUM):
x = "mnist_train.mindrecord" + str(i)
x = file_name + "_train.mindrecord" + str(i)
remove_one_file(x)
x = "mnist_train.mindrecord" + str(i) + ".db"
x = file_name + "_train.mindrecord" + str(i) + ".db"
remove_one_file(x)
x = "mnist_test.mindrecord" + str(i)
x = file_name + "_test.mindrecord" + str(i)
remove_one_file(x)
x = "mnist_test.mindrecord" + str(i) + ".db"
x = file_name + "_test.mindrecord" + str(i) + ".db"
remove_one_file(x)
remove_file()
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
remove_file(file_name)
yield "yield_fixture_data"
remove_file()
remove_file(file_name)
def read(train_name, test_name):
def read(file_name, partition=False):
"""test file reader"""
count = 0
if partition:
train_name = file_name + "_train.mindrecord0"
test_name = file_name + "_test.mindrecord0"
else:
train_name = file_name + "_train.mindrecord"
test_name = file_name + "_test.mindrecord"
reader = FileReader(train_name)
for _, x in enumerate(reader.get_next()):
assert len(x) == 2
@ -82,21 +84,24 @@ def read(train_name, test_name):
def test_mnist_to_mindrecord(fixture_file):
"""test transform mnist dataset to mindrecord."""
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mnist_transformer = MnistToMR(MNIST_DIR, file_name)
mnist_transformer.transform()
assert os.path.exists("mnist_train.mindrecord")
assert os.path.exists("mnist_test.mindrecord")
assert os.path.exists(file_name + "_train.mindrecord")
assert os.path.exists(file_name + "_test.mindrecord")
read("mnist_train.mindrecord", "mnist_test.mindrecord")
read(file_name)
def test_mnist_to_mindrecord_compare_data(fixture_file):
"""test transform mnist dataset to mindrecord and compare data."""
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mnist_transformer = MnistToMR(MNIST_DIR, file_name)
mnist_transformer.transform()
assert os.path.exists("mnist_train.mindrecord")
assert os.path.exists("mnist_test.mindrecord")
assert os.path.exists(file_name + "_train.mindrecord")
assert os.path.exists(file_name + "_test.mindrecord")
train_name, test_name = "mnist_train.mindrecord", "mnist_test.mindrecord"
train_name = file_name + "_train.mindrecord"
test_name = file_name + "_test.mindrecord"
def _extract_images(filename, num_images):
"""Extract the images into a 4D tensor [image index, y, x, channels]."""
@ -147,7 +152,8 @@ def test_mnist_to_mindrecord_compare_data(fixture_file):
def test_mnist_to_mindrecord_multi_partition(fixture_file):
"""test transform mnist dataset to multiple mindrecord files."""
mnist_transformer = MnistToMR(MNIST_DIR, FILE_NAME, PARTITION_NUM)
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mnist_transformer = MnistToMR(MNIST_DIR, file_name, PARTITION_NUM)
mnist_transformer.transform()
read("mnist_train.mindrecord0", "mnist_test.mindrecord0")
read(file_name, partition=True)

View File

@ -32,8 +32,6 @@ except ModuleNotFoundError:
tf = None
TFRECORD_DATA_DIR = "../data/mindrecord/testTFRecordData"
TFRECORD_FILE_NAME = "test.tfrecord"
MINDRECORD_FILE_NAME = "test.mindrecord"
PARTITION_NUM = 1
def cast_name(key):
@ -82,7 +80,7 @@ def verify_data(transformer, reader):
assert value == mr_item[cast_name(key)]
assert count == 10
def generate_tfrecord():
def generate_tfrecord(tfrecord_file_name):
def create_int_feature(values):
if isinstance(values, list):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) # values: [int, int, int]
@ -105,7 +103,7 @@ def generate_tfrecord():
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))
return feature
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
example_count = 0
for i in range(10):
@ -131,7 +129,7 @@ def generate_tfrecord():
writer.close()
logger.info("Write {} rows in tfrecord.".format(example_count))
def generate_tfrecord_with_special_field_name():
def generate_tfrecord_with_special_field_name(tfrecord_file_name):
def create_int_feature(values):
if isinstance(values, list):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) # values: [int, int, int]
@ -147,7 +145,7 @@ def generate_tfrecord_with_special_field_name():
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))
return feature
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
example_count = 0
for i in range(10):
@ -172,8 +170,12 @@ def test_tfrecord_to_mindrecord():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
@ -183,25 +185,25 @@ def test_tfrecord_to_mindrecord():
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
assert os.path.exists(MINDRECORD_FILE_NAME)
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
assert os.path.exists(mindrecord_file_name)
assert os.path.exists(mindrecord_file_name + ".db")
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
fr_mindrecord = FileReader(mindrecord_file_name)
verify_data(tfrecord_transformer, fr_mindrecord)
os.remove(MINDRECORD_FILE_NAME)
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(mindrecord_file_name)
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
def test_tfrecord_to_mindrecord_scalar_with_1():
"""test transform tfrecord to mindrecord."""
@ -211,8 +213,11 @@ def test_tfrecord_to_mindrecord_scalar_with_1():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
@ -222,25 +227,25 @@ def test_tfrecord_to_mindrecord_scalar_with_1():
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
assert os.path.exists(MINDRECORD_FILE_NAME)
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
assert os.path.exists(mindrecord_file_name)
assert os.path.exists(mindrecord_file_name + ".db")
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
fr_mindrecord = FileReader(mindrecord_file_name)
verify_data(tfrecord_transformer, fr_mindrecord)
os.remove(MINDRECORD_FILE_NAME)
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(mindrecord_file_name)
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception():
"""test transform tfrecord to mindrecord."""
@ -250,8 +255,11 @@ def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
@ -261,22 +269,22 @@ def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception():
"float_list": tf.io.FixedLenFeature([2], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
with pytest.raises(ValueError):
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
def test_tfrecord_to_mindrecord_list_with_diff_type_exception():
"""test transform tfrecord to mindrecord."""
@ -286,8 +294,11 @@ def test_tfrecord_to_mindrecord_list_with_diff_type_exception():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
@ -297,22 +308,22 @@ def test_tfrecord_to_mindrecord_list_with_diff_type_exception():
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
with pytest.raises(ValueError):
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
def test_tfrecord_to_mindrecord_list_without_bytes_type():
"""test transform tfrecord to mindrecord."""
@ -322,8 +333,11 @@ def test_tfrecord_to_mindrecord_list_without_bytes_type():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
@ -333,25 +347,25 @@ def test_tfrecord_to_mindrecord_list_without_bytes_type():
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict)
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict)
tfrecord_transformer.transform()
assert os.path.exists(MINDRECORD_FILE_NAME)
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
assert os.path.exists(mindrecord_file_name)
assert os.path.exists(mindrecord_file_name + ".db")
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
fr_mindrecord = FileReader(mindrecord_file_name)
verify_data(tfrecord_transformer, fr_mindrecord)
os.remove(MINDRECORD_FILE_NAME)
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(mindrecord_file_name)
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
def test_tfrecord_to_mindrecord_scalar_with_2_exception():
"""test transform tfrecord to mindrecord."""
@ -361,8 +375,11 @@ def test_tfrecord_to_mindrecord_scalar_with_2_exception():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
@ -372,22 +389,22 @@ def test_tfrecord_to_mindrecord_scalar_with_2_exception():
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict, ["image_bytes"])
with pytest.raises(ValueError):
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
def test_tfrecord_to_mindrecord_scalar_string_with_1_exception():
"""test transform tfrecord to mindrecord."""
@ -397,8 +414,11 @@ def test_tfrecord_to_mindrecord_scalar_string_with_1_exception():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"file_name": tf.io.FixedLenFeature([1], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
@ -408,22 +428,22 @@ def test_tfrecord_to_mindrecord_scalar_string_with_1_exception():
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
with pytest.raises(ValueError):
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
"""test transform tfrecord to mindrecord."""
@ -433,8 +453,11 @@ def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([10], tf.string),
@ -444,22 +467,22 @@ def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
with pytest.raises(ValueError):
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict, ["image_bytes"])
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type():
"""test transform tfrecord to mindrecord."""
@ -469,8 +492,11 @@ def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
@ -480,22 +506,22 @@ def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type():
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
with pytest.raises(ValueError):
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["int64_list"])
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict, ["int64_list"])
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list():
"""test transform tfrecord to mindrecord."""
@ -505,8 +531,11 @@ def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
"image_bytes": tf.io.FixedLenFeature([], tf.string),
@ -516,22 +545,22 @@ def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list():
"float_list": tf.io.FixedLenFeature([7], tf.float32),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
with pytest.raises(ValueError):
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, "")
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict, "")
tfrecord_transformer.transform()
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
def test_tfrecord_to_mindrecord_with_special_field_name():
"""test transform tfrecord to mindrecord."""
@ -541,29 +570,32 @@ def test_tfrecord_to_mindrecord_with_special_field_name():
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
return
generate_tfrecord_with_special_field_name()
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
file_name_ = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
mindrecord_file_name = file_name_ + '.mindrecord'
tfrecord_file_name = file_name_ + '.tfrecord'
generate_tfrecord_with_special_field_name(tfrecord_file_name)
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))
feature_dict = {"image/class/label": tf.io.FixedLenFeature([], tf.int64),
"image/encoded": tf.io.FixedLenFeature([], tf.string),
}
if os.path.exists(MINDRECORD_FILE_NAME):
os.remove(MINDRECORD_FILE_NAME)
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
os.remove(MINDRECORD_FILE_NAME + ".db")
if os.path.exists(mindrecord_file_name):
os.remove(mindrecord_file_name)
if os.path.exists(mindrecord_file_name + ".db"):
os.remove(mindrecord_file_name + ".db")
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
MINDRECORD_FILE_NAME, feature_dict, ["image/encoded"])
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name),
mindrecord_file_name, feature_dict, ["image/encoded"])
tfrecord_transformer.transform()
assert os.path.exists(MINDRECORD_FILE_NAME)
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
assert os.path.exists(mindrecord_file_name)
assert os.path.exists(mindrecord_file_name + ".db")
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
fr_mindrecord = FileReader(mindrecord_file_name)
verify_data(tfrecord_transformer, fr_mindrecord)
os.remove(MINDRECORD_FILE_NAME)
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(mindrecord_file_name)
os.remove(mindrecord_file_name + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
os.remove(os.path.join(TFRECORD_DATA_DIR, tfrecord_file_name))