forked from mindspore-Ecosystem/mindspore
add tool: tfrecord to mindrecord
This commit is contained in:
parent
cb706951c1
commit
3353a20d8b
|
@ -31,7 +31,8 @@ from .tools.cifar10_to_mr import Cifar10ToMR
|
|||
from .tools.cifar100_to_mr import Cifar100ToMR
|
||||
from .tools.imagenet_to_mr import ImageNetToMR
|
||||
from .tools.mnist_to_mr import MnistToMR
|
||||
from .tools.tfrecord_to_mr import TFRecordToMR
|
||||
|
||||
__all__ = ['FileWriter', 'FileReader', 'MindPage',
|
||||
'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR',
|
||||
'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR', 'TFRecordToMR',
|
||||
'SUCCESS', 'FAILED']
|
||||
|
|
|
@ -0,0 +1,268 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
TFRecord convert tool for MindRecord
|
||||
"""
|
||||
|
||||
from importlib import import_module
|
||||
from string import punctuation
|
||||
import numpy as np
|
||||
|
||||
from mindspore import log as logger
|
||||
from ..filewriter import FileWriter
|
||||
from ..shardutils import check_filename
|
||||
|
||||
try:
|
||||
tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
|
||||
except ModuleNotFoundError:
|
||||
tf = None
|
||||
|
||||
__all__ = ['TFRecordToMR']
|
||||
|
||||
SupportedTensorFlowVersion = '2.1.0'
|
||||
|
||||
def _cast_type(value):
|
||||
"""
|
||||
Cast complex data type to basic datatype for MindRecord to recognize.
|
||||
|
||||
Args:
|
||||
value: the TFRecord data type
|
||||
|
||||
Returns:
|
||||
str, which is MindRecord field type.
|
||||
"""
|
||||
tf_type_to_mr_type = {tf.string: "string",
|
||||
tf.int8: "int32",
|
||||
tf.int16: "int32",
|
||||
tf.int32: "int32",
|
||||
tf.int64: "int64",
|
||||
tf.uint8: "int32",
|
||||
tf.uint16: "int32",
|
||||
tf.uint32: "int64",
|
||||
tf.uint64: "int64",
|
||||
tf.float16: "float32",
|
||||
tf.float32: "float32",
|
||||
tf.float64: "float64",
|
||||
tf.double: "float64",
|
||||
tf.bool: "int32"}
|
||||
unsupport_tf_type_to_mr_type = {tf.complex64: "None",
|
||||
tf.complex128: "None"}
|
||||
|
||||
if value in tf_type_to_mr_type:
|
||||
return tf_type_to_mr_type[value]
|
||||
|
||||
raise ValueError("Type " + value + " is not supported in MindRecord.")
|
||||
|
||||
def _cast_string_type_to_np_type(value):
|
||||
"""Cast string type like: int32/int64/float32/float64 to np.int32/np.int64/np.float32/np.float64"""
|
||||
string_type_to_np_type = {"int32": np.int32,
|
||||
"int64": np.int64,
|
||||
"float32": np.float32,
|
||||
"float64": np.float64}
|
||||
|
||||
if value in string_type_to_np_type:
|
||||
return string_type_to_np_type[value]
|
||||
|
||||
raise ValueError("Type " + value + " is not supported cast to numpy type in MindRecord.")
|
||||
|
||||
def _cast_name(key):
|
||||
"""
|
||||
Cast schema names which containing special characters to valid names.
|
||||
|
||||
Here special characters means any characters in
|
||||
'!"#$%&\'()*+,./:;<=>?@[\\]^`{|}~
|
||||
Valid names can only contain a-z, A-Z, and 0-9 and _
|
||||
|
||||
Args:
|
||||
key (str): original key that might contains special characters.
|
||||
|
||||
Returns:
|
||||
str, casted key that replace the special characters with "_". i.e. if
|
||||
key is "a b" then returns "a_b".
|
||||
"""
|
||||
special_symbols = set('{}{}'.format(punctuation, ' '))
|
||||
special_symbols.remove('_')
|
||||
new_key = ['_' if x in special_symbols else x for x in key]
|
||||
casted_key = ''.join(new_key)
|
||||
return casted_key
|
||||
|
||||
class TFRecordToMR:
|
||||
"""
|
||||
Class is for tranformation from TFRecord to MindRecord.
|
||||
|
||||
Args:
|
||||
source (str): the TFRecord file to be transformed.
|
||||
destination (str): the MindRecord file path to tranform into.
|
||||
feature_dict (dict): a dictionary than states the feature type, i.e.
|
||||
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string),
|
||||
"yyyy": tf.io.FixedLenFeature([], tf.int64)}
|
||||
****** follow case which uses VarLenFeature not support ******
|
||||
feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string),
|
||||
"yyyy": tf.io.VarLenFeature(tf.int64)},
|
||||
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
|
||||
bytes_fields (list): the bytes fields which are in feature_dict.
|
||||
|
||||
Rasies:
|
||||
ValueError, when:
|
||||
1) parameter TFRecord is not string.
|
||||
2) parameter MindRecord is not string.
|
||||
3) feature_dict is not FixedLenFeature.
|
||||
4) parameter bytes_field is not list(str) or not in feature_dict
|
||||
Exception, when tensorflow module not found or version is not correct.
|
||||
"""
|
||||
def __init__(self, source, destination, feature_dict, bytes_fields=None):
|
||||
if not tf:
|
||||
raise Exception("Module tensorflow is not found, please use pip install it.")
|
||||
|
||||
if tf.__version__ < SupportedTensorFlowVersion:
|
||||
raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion))
|
||||
|
||||
if not isinstance(source, str):
|
||||
raise ValueError("Parameter source must be string.")
|
||||
check_filename(source)
|
||||
|
||||
if not isinstance(destination, str):
|
||||
raise ValueError("Parameter destination must be string.")
|
||||
check_filename(destination)
|
||||
|
||||
self.source = source
|
||||
self.destination = destination
|
||||
|
||||
if feature_dict is None or not isinstance(feature_dict, dict):
|
||||
raise ValueError("Parameter feature_dict is None or not dict.")
|
||||
|
||||
for key, val in feature_dict.items():
|
||||
if not isinstance(val, tf.io.FixedLenFeature):
|
||||
raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict))
|
||||
|
||||
self.feature_dict = feature_dict
|
||||
|
||||
bytes_fields_list = []
|
||||
if bytes_fields:
|
||||
if not isinstance(bytes_fields, list):
|
||||
raise ValueError("Parameter bytes_fields: {} must be list(str).".format(bytes_fields))
|
||||
for item in bytes_fields:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError("Parameter bytes_fields's item: {} is not str.".format(item))
|
||||
|
||||
if item not in self.feature_dict:
|
||||
raise ValueError("Parameter bytes_fields's item: {} is not in feature_dict: {}."
|
||||
.format(item, self.feature_dict))
|
||||
|
||||
if not isinstance(self.feature_dict[item].shape, list):
|
||||
raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item))
|
||||
|
||||
casted_bytes_field = _cast_name(item)
|
||||
bytes_fields_list.append(casted_bytes_field)
|
||||
|
||||
self.bytes_fields_list = bytes_fields_list
|
||||
self.scalar_set = set()
|
||||
self.list_set = set()
|
||||
|
||||
mindrecord_schema = {}
|
||||
for key, val in self.feature_dict.items():
|
||||
if not val.shape:
|
||||
self.scalar_set.add(_cast_name(key))
|
||||
if key in self.bytes_fields_list:
|
||||
mindrecord_schema[_cast_name(key)] = {"type": "bytes"}
|
||||
else:
|
||||
mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype)}
|
||||
else:
|
||||
if len(val.shape) != 1:
|
||||
raise ValueError("Parameter len(feature_dict[{}].shape) should be 1.")
|
||||
if val.shape[0] < 1:
|
||||
raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key))
|
||||
if val.dtype == tf.string:
|
||||
raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] \
|
||||
is not None. It is not supported.".format(key))
|
||||
self.list_set.add(_cast_name(key))
|
||||
mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype), "shape": [val.shape[0]]}
|
||||
self.mindrecord_schema = mindrecord_schema
|
||||
|
||||
def _parse_record(self, example):
|
||||
"""Returns features for a single example"""
|
||||
features = tf.io.parse_single_example(example, features=self.feature_dict)
|
||||
return features
|
||||
|
||||
def _get_data_when_scalar_field(self, ms_dict, cast_key, key, val):
|
||||
"""put data in ms_dict when field type is string"""
|
||||
if isinstance(val.numpy(), (np.ndarray, list)):
|
||||
raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))
|
||||
if self.feature_dict[key].dtype == tf.string:
|
||||
if cast_key in self.bytes_fields_list:
|
||||
ms_dict[cast_key] = val.numpy()
|
||||
else:
|
||||
ms_dict[cast_key] = str(val.numpy(), encoding="utf-8")
|
||||
elif _cast_type(self.feature_dict[key].dtype).startswith("int"):
|
||||
ms_dict[cast_key] = int(val.numpy())
|
||||
else:
|
||||
ms_dict[cast_key] = float(val.numpy())
|
||||
|
||||
def tfrecord_iterator(self):
|
||||
"""Yield a dict with key to be fields in schema, and value to be data."""
|
||||
dataset = tf.data.TFRecordDataset(self.source)
|
||||
dataset = dataset.map(self._parse_record)
|
||||
iterator = dataset.__iter__()
|
||||
index_id = 0
|
||||
try:
|
||||
for features in iterator:
|
||||
ms_dict = {}
|
||||
index_id = index_id + 1
|
||||
for key, val in features.items():
|
||||
cast_key = _cast_name(key)
|
||||
if key in self.scalar_set:
|
||||
self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
|
||||
else:
|
||||
if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list):
|
||||
raise ValueError("he response key: {}, value: {} from TFRecord should be a ndarray or list."
|
||||
.format(key, val))
|
||||
# list set
|
||||
ms_dict[cast_key] = \
|
||||
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
|
||||
yield ms_dict
|
||||
except tf.errors.InvalidArgumentError:
|
||||
raise ValueError("TFRecord feature_dict parameter error.")
|
||||
|
||||
def transform(self):
|
||||
"""
|
||||
Executes transform from TFRecord to MindRecord.
|
||||
|
||||
Returns:
|
||||
SUCCESS/FAILED, whether successfuly written into MindRecord.
|
||||
"""
|
||||
writer = FileWriter(self.destination)
|
||||
logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"
|
||||
.format(self.mindrecord_schema, self.feature_dict))
|
||||
|
||||
writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord")
|
||||
|
||||
tf_iter = self.tfrecord_iterator()
|
||||
batch_size = 256
|
||||
|
||||
transform_count = 0
|
||||
while True:
|
||||
data_list = []
|
||||
try:
|
||||
for _ in range(batch_size):
|
||||
data_list.append(tf_iter.__next__())
|
||||
transform_count += 1
|
||||
|
||||
writer.write_raw_data(data_list)
|
||||
logger.info("Transformed {} records...".format(transform_count))
|
||||
except StopIteration:
|
||||
if data_list:
|
||||
writer.write_raw_data(data_list)
|
||||
logger.info("Transformed {} records...".format(transform_count))
|
||||
break
|
||||
return writer.commit()
|
|
@ -0,0 +1 @@
|
|||
## tfrecord file dir
|
|
@ -0,0 +1,400 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""test tfrecord to mindrecord tool"""
|
||||
import collections
|
||||
from importlib import import_module
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore import log as logger
|
||||
from mindspore.mindrecord import FileReader
|
||||
from mindspore.mindrecord import TFRecordToMR
|
||||
|
||||
SupportedTensorFlowVersion = '2.1.0'
|
||||
|
||||
try:
|
||||
tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
|
||||
except ModuleNotFoundError:
|
||||
logger.warning("tensorflow module not found.")
|
||||
tf = None
|
||||
|
||||
TFRECORD_DATA_DIR = "../data/mindrecord/testTFRecordData"
|
||||
TFRECORD_FILE_NAME = "test.tfrecord"
|
||||
MINDRECORD_FILE_NAME = "test.mindrecord"
|
||||
PARTITION_NUM = 1
|
||||
|
||||
def verify_data(transformer, reader):
|
||||
"""Verify the data by read from mindrecord"""
|
||||
tf_iter = transformer.tfrecord_iterator()
|
||||
mr_iter = reader.get_next()
|
||||
|
||||
count = 0
|
||||
for tf_item, mr_item in zip(tf_iter, mr_iter):
|
||||
count = count + 1
|
||||
assert len(tf_item) == 6
|
||||
assert len(mr_item) == 6
|
||||
for key, value in tf_item.items():
|
||||
logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value, mr_item[key]))
|
||||
if isinstance(value, np.ndarray):
|
||||
assert (value == mr_item[key]).all()
|
||||
else:
|
||||
assert value == mr_item[key]
|
||||
assert count == 10
|
||||
|
||||
def generate_tfrecord():
|
||||
def create_int_feature(values):
|
||||
if isinstance(values, list):
|
||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) # values: [int, int, int]
|
||||
else:
|
||||
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values])) # values: int
|
||||
return feature
|
||||
|
||||
def create_float_feature(values):
|
||||
if isinstance(values, list):
|
||||
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) # values: [float, float]
|
||||
else:
|
||||
feature = tf.train.Feature(float_list=tf.train.FloatList(value=[values])) # values: float
|
||||
return feature
|
||||
|
||||
def create_bytes_feature(values):
|
||||
if isinstance(values, bytes):
|
||||
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) # values: bytes
|
||||
else:
|
||||
# values: string
|
||||
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))
|
||||
return feature
|
||||
|
||||
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
example_count = 0
|
||||
for i in range(10):
|
||||
file_name = "000" + str(i) + ".jpg"
|
||||
image_bytes = bytes(str("aaaabbbbcccc" + str(i)), encoding="utf-8")
|
||||
int64_scalar = i
|
||||
float_scalar = float(i)
|
||||
int64_list = [i, i+1, i+2, i+3, i+4, i+1234567890]
|
||||
float_list = [float(i), float(i+1), float(i+2.8), float(i+3.2),
|
||||
float(i+4.4), float(i+123456.9), float(i+98765432.1)]
|
||||
|
||||
features = collections.OrderedDict()
|
||||
features["file_name"] = create_bytes_feature(file_name)
|
||||
features["image_bytes"] = create_bytes_feature(image_bytes)
|
||||
features["int64_scalar"] = create_int_feature(int64_scalar)
|
||||
features["float_scalar"] = create_float_feature(float_scalar)
|
||||
features["int64_list"] = create_int_feature(int64_list)
|
||||
features["float_list"] = create_float_feature(float_list)
|
||||
|
||||
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
writer.write(tf_example.SerializeToString())
|
||||
example_count += 1
|
||||
writer.close()
|
||||
logger.info("Write {} rows in tfrecord.".format(example_count))
|
||||
|
||||
def test_tfrecord_to_mindrecord():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
if not tf or tf.__version__ < SupportedTensorFlowVersion:
|
||||
# skip the test
|
||||
logger.warning("Module tensorflow is not found or version wrong, \
|
||||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
"int64_scalar": tf.io.FixedLenFeature([], tf.int64),
|
||||
"float_scalar": tf.io.FixedLenFeature([], tf.float32),
|
||||
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
|
||||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME)
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
|
||||
verify_data(tfrecord_transformer, fr_mindrecord)
|
||||
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
def test_tfrecord_to_mindrecord_scalar_with_1():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
if not tf or tf.__version__ < SupportedTensorFlowVersion:
|
||||
# skip the test
|
||||
logger.warning("Module tensorflow is not found or version wrong, \
|
||||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
|
||||
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
|
||||
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
|
||||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME)
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
|
||||
verify_data(tfrecord_transformer, fr_mindrecord)
|
||||
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
def test_tfrecord_to_mindrecord_scalar_with_1_list_small_len_exception():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
if not tf or tf.__version__ < SupportedTensorFlowVersion:
|
||||
# skip the test
|
||||
logger.warning("Module tensorflow is not found or version wrong, \
|
||||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
|
||||
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
|
||||
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
|
||||
"float_list": tf.io.FixedLenFeature([2], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
def test_tfrecord_to_mindrecord_list_with_diff_type_exception():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
if not tf or tf.__version__ < SupportedTensorFlowVersion:
|
||||
# skip the test
|
||||
logger.warning("Module tensorflow is not found or version wrong, \
|
||||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
|
||||
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
|
||||
"int64_list": tf.io.FixedLenFeature([6], tf.float32),
|
||||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
def test_tfrecord_to_mindrecord_list_without_bytes_type():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
if not tf or tf.__version__ < SupportedTensorFlowVersion:
|
||||
# skip the test
|
||||
logger.warning("Module tensorflow is not found or version wrong, \
|
||||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
|
||||
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
|
||||
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
|
||||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict)
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME)
|
||||
assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
|
||||
verify_data(tfrecord_transformer, fr_mindrecord)
|
||||
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
def test_tfrecord_to_mindrecord_scalar_with_2_exception():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
if not tf or tf.__version__ < SupportedTensorFlowVersion:
|
||||
# skip the test
|
||||
logger.warning("Module tensorflow is not found or version wrong, \
|
||||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
"int64_scalar": tf.io.FixedLenFeature([2], tf.int64),
|
||||
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
|
||||
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
|
||||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
def test_tfrecord_to_mindrecord_scalar_string_with_1_exception():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
if not tf or tf.__version__ < SupportedTensorFlowVersion:
|
||||
# skip the test
|
||||
logger.warning("Module tensorflow is not found or version wrong, \
|
||||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([1], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([], tf.string),
|
||||
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
|
||||
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
|
||||
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
|
||||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
|
||||
"""test transform tfrecord to mindrecord."""
|
||||
if not tf or tf.__version__ < SupportedTensorFlowVersion:
|
||||
# skip the test
|
||||
logger.warning("Module tensorflow is not found or version wrong, \
|
||||
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
|
||||
return
|
||||
|
||||
generate_tfrecord()
|
||||
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
||||
|
||||
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
|
||||
"image_bytes": tf.io.FixedLenFeature([10], tf.string),
|
||||
"int64_scalar": tf.io.FixedLenFeature([1], tf.int64),
|
||||
"float_scalar": tf.io.FixedLenFeature([1], tf.float32),
|
||||
"int64_list": tf.io.FixedLenFeature([6], tf.int64),
|
||||
"float_list": tf.io.FixedLenFeature([7], tf.float32),
|
||||
}
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
|
||||
MINDRECORD_FILE_NAME, feature_dict, ["image_bytes"])
|
||||
tfrecord_transformer.transform()
|
||||
|
||||
if os.path.exists(MINDRECORD_FILE_NAME):
|
||||
os.remove(MINDRECORD_FILE_NAME)
|
||||
if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
|
||||
os.remove(MINDRECORD_FILE_NAME + ".db")
|
||||
|
||||
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
|
Loading…
Reference in New Issue