fix: tfrecord to mindrecord parameter check

This commit is contained in:
jonyguo 2020-06-28 15:58:12 +08:00
parent d6d93f16b1
commit 3450c35d9b
2 changed files with 178 additions and 13 deletions

View File

@ -113,7 +113,7 @@ class TFRecordToMR:
feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string), \
"yyyy": tf.io.VarLenFeature(tf.int64)}, \
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
bytes_fields (list): the bytes fields which are in feature_dict.
bytes_fields (list, optional): the bytes fields which are in feature_dict and can be images bytes.
Raises:
ValueError: If parameter is invalid.
@ -147,7 +147,7 @@ class TFRecordToMR:
self.feature_dict = feature_dict
bytes_fields_list = []
if bytes_fields:
if bytes_fields is not None:
if not isinstance(bytes_fields, list):
raise ValueError("Parameter bytes_fields: {} must be list(str).".format(bytes_fields))
for item in bytes_fields:
@ -161,6 +161,9 @@ class TFRecordToMR:
if not isinstance(self.feature_dict[item].shape, list):
raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item))
if self.feature_dict[item].dtype != tf.string:
raise ValueError("Parameter bytes_field: {} should be tf.string in feature_dict.".format(item))
casted_bytes_field = _cast_name(item)
bytes_fields_list.append(casted_bytes_field)
@ -172,7 +175,7 @@ class TFRecordToMR:
for key, val in self.feature_dict.items():
if not val.shape:
self.scalar_set.add(_cast_name(key))
if key in self.bytes_fields_list:
if _cast_name(key) in self.bytes_fields_list:
mindrecord_schema[_cast_name(key)] = {"type": "bytes"}
else:
mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype)}
@ -182,8 +185,8 @@ class TFRecordToMR:
if val.shape[0] < 1:
raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key))
if val.dtype == tf.string:
raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] \
is not None. It is not supported.".format(key))
raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] " \
"is not None. It is not supported.".format(key))
self.list_set.add(_cast_name(key))
mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype), "shape": [val.shape[0]]}
self.mindrecord_schema = mindrecord_schema
@ -219,12 +222,12 @@ class TFRecordToMR:
index_id = index_id + 1
for key, val in features.items():
cast_key = _cast_name(key)
if key in self.scalar_set:
if cast_key in self.scalar_set:
self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
else:
if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list):
raise ValueError("he response key: {}, value: {} from TFRecord should be a ndarray or list."
.format(key, val))
raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " \
"list.".format(key, val))
# list set
ms_dict[cast_key] = \
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))

View File

@ -15,6 +15,7 @@
import collections
from importlib import import_module
import os
from string import punctuation
import numpy as np
import pytest
@ -35,6 +36,27 @@ TFRECORD_FILE_NAME = "test.tfrecord"
MINDRECORD_FILE_NAME = "test.mindrecord"
PARTITION_NUM = 1
# Translation table that maps every punctuation character except "_", plus the
# space character, to "_". Built once at import time instead of on every call.
_SPECIAL_SYMBOL_TABLE = str.maketrans(
    {symbol: '_' for symbol in punctuation + ' ' if symbol != '_'})


def cast_name(key):
    """
    Cast a schema name containing special characters to a valid name.

    Special characters are every character in string.punctuation other than
    "_", plus the space character. All other characters (for example a-z,
    A-Z, 0-9 and "_") are left untouched.

    Args:
        key (str): original key that might contain special characters.

    Returns:
        str, casted key in which each special character is replaced with "_".
        i.e. if key is "a b" then returns "a_b".
    """
    # One pass over the string; characters missing from the table are kept.
    return key.translate(_SPECIAL_SYMBOL_TABLE)
def verify_data(transformer, reader):
"""Verify the data by read from mindrecord"""
tf_iter = transformer.tfrecord_iterator()
@ -43,14 +65,14 @@ def verify_data(transformer, reader):
count = 0
for tf_item, mr_item in zip(tf_iter, mr_iter):
count = count + 1
assert len(tf_item) == 6
assert len(mr_item) == 6
assert len(tf_item) == len(mr_item)
for key, value in tf_item.items():
logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value, mr_item[key]))
logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value,
mr_item[cast_name(key)]))
if isinstance(value, np.ndarray):
assert (value == mr_item[key]).all()
assert (value == mr_item[cast_name(key)]).all()
else:
assert value == mr_item[key]
assert value == mr_item[cast_name(key)]
assert count == 10
def generate_tfrecord():
@ -102,6 +124,39 @@ def generate_tfrecord():
writer.close()
logger.info("Write {} rows in tfrecord.".format(example_count))
def generate_tfrecord_with_special_field_name():
    """
    Write ten examples whose feature names contain "/" to the test TFRecord file.

    Each example holds an int64 feature "image/class/label" (the row index)
    and a bytes feature "image/encoded" (b"aaaabbbbcccc" followed by the row
    index). The file is written to TFRECORD_DATA_DIR/TFRECORD_FILE_NAME.
    """
    def create_int_feature(values):
        """Wrap an int or a list of ints in an int64 tf.train.Feature."""
        if not isinstance(values, list):
            values = [values]  # values: int
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    def create_bytes_feature(values):
        """Wrap bytes, or a str encoded as utf-8, in a bytes tf.train.Feature."""
        if not isinstance(values, bytes):
            values = bytes(values, encoding='utf-8')  # values: string
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

    example_count = 0
    # The context manager closes the writer even if building or writing an
    # example raises part-way through.
    with tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) as writer:
        for i in range(10):
            label = i
            image_bytes = bytes("aaaabbbbcccc" + str(i), encoding="utf-8")

            features = collections.OrderedDict()
            features["image/class/label"] = create_int_feature(label)
            features["image/encoded"] = create_bytes_feature(image_bytes)

            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
            writer.write(tf_example.SerializeToString())
            example_count += 1
    logger.info("Write {} rows in tfrecord.".format(example_count))
def test_tfrecord_to_mindrecord():
"""test transform tfrecord to mindrecord."""
if not tf or tf.__version__ < SupportedTensorFlowVersion:
@ -398,3 +453,110 @@ def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
os.remove(MINDRECORD_FILE_NAME + ".db")
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type():
    """
    Expect ValueError when bytes_fields names a feature that is not tf.string.

    "int64_list" is declared as tf.int64 in feature_dict, so passing it in
    bytes_fields must be rejected.
    """
    # NOTE(review): the version check compares version strings directly —
    # confirm this ordering is what is intended for multi-digit components.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        # skip the test
        logger.warning("Module tensorflow is not found or version wrong, "
                       "please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
        return

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_paths = (MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

    generate_tfrecord()
    try:
        assert os.path.exists(tfrecord_path)

        feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
                        "image_bytes": tf.io.FixedLenFeature([], tf.string),
                        "int64_scalar": tf.io.FixedLenFeature([], tf.int64),
                        "float_scalar": tf.io.FixedLenFeature([], tf.float32),
                        "int64_list": tf.io.FixedLenFeature([6], tf.int64),
                        "float_list": tf.io.FixedLenFeature([7], tf.float32),
                        }

        # Start from a clean slate: drop outputs left over from earlier runs.
        for path in mindrecord_paths:
            if os.path.exists(path):
                os.remove(path)

        with pytest.raises(ValueError):
            tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME,
                                                feature_dict, ["int64_list"])
            tfrecord_transformer.transform()
    finally:
        # Always clean up so a failing assertion does not leave files behind
        # for the other tests, which share the same file names.
        for path in mindrecord_paths + (tfrecord_path,):
            if os.path.exists(path):
                os.remove(path)
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list():
    """
    Expect ValueError when bytes_fields is not a list.

    An empty str is passed as bytes_fields instead of a list of field names.
    """
    # NOTE(review): the version check compares version strings directly —
    # confirm this ordering is what is intended for multi-digit components.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        # skip the test
        logger.warning("Module tensorflow is not found or version wrong, "
                       "please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
        return

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_paths = (MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

    generate_tfrecord()
    try:
        assert os.path.exists(tfrecord_path)

        feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
                        "image_bytes": tf.io.FixedLenFeature([], tf.string),
                        "int64_scalar": tf.io.FixedLenFeature([], tf.int64),
                        "float_scalar": tf.io.FixedLenFeature([], tf.float32),
                        "int64_list": tf.io.FixedLenFeature([6], tf.int64),
                        "float_list": tf.io.FixedLenFeature([7], tf.float32),
                        }

        # Start from a clean slate: drop outputs left over from earlier runs.
        for path in mindrecord_paths:
            if os.path.exists(path):
                os.remove(path)

        with pytest.raises(ValueError):
            tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME,
                                                feature_dict, "")
            tfrecord_transformer.transform()
    finally:
        # Always clean up so a failing assertion does not leave files behind
        # for the other tests, which share the same file names.
        for path in mindrecord_paths + (tfrecord_path,):
            if os.path.exists(path):
                os.remove(path)
def test_tfrecord_to_mindrecord_with_special_field_name():
    """
    Transform a TFRecord whose feature names contain "/" and verify the result.

    The features "image/class/label" and "image/encoded" are converted, with
    "image/encoded" passed as a bytes field, and the written MindRecord is
    compared against the TFRecord content via verify_data.
    """
    # NOTE(review): the version check compares version strings directly —
    # confirm this ordering is what is intended for multi-digit components.
    if not tf or tf.__version__ < SupportedTensorFlowVersion:
        # skip the test
        logger.warning("Module tensorflow is not found or version wrong, "
                       "please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
        return

    tfrecord_path = os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)
    mindrecord_paths = (MINDRECORD_FILE_NAME, MINDRECORD_FILE_NAME + ".db")

    generate_tfrecord_with_special_field_name()
    try:
        assert os.path.exists(tfrecord_path)

        feature_dict = {"image/class/label": tf.io.FixedLenFeature([], tf.int64),
                        "image/encoded": tf.io.FixedLenFeature([], tf.string),
                        }

        # Start from a clean slate: drop outputs left over from earlier runs.
        for path in mindrecord_paths:
            if os.path.exists(path):
                os.remove(path)

        tfrecord_transformer = TFRecordToMR(tfrecord_path, MINDRECORD_FILE_NAME,
                                            feature_dict, ["image/encoded"])
        tfrecord_transformer.transform()

        for path in mindrecord_paths:
            assert os.path.exists(path)

        fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
        verify_data(tfrecord_transformer, fr_mindrecord)
    finally:
        # Always clean up so a failing assertion does not leave files behind
        # for the other tests, which share the same file names.
        for path in mindrecord_paths + (tfrecord_path,):
            if os.path.exists(path):
                os.remove(path)