forked from mindspore-Ecosystem/mindspore
fix: tfrecord to mindrecord parameter check
parent 6ef1a731db
commit b3346a98b9
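The parameter checks tightened here require bytes_fields to be a list of str whose entries are declared as tf.string in feature_dict, and they match bytes and scalar fields by their casted (special-character-free) names. A minimal usage sketch for orientation; the paths and field names are illustrative, and the public import path mindspore.mindrecord.TFRecordToMR is an assumption, not something shown in this diff:

import tensorflow as tf
from mindspore.mindrecord import TFRecordToMR  # assumed public import path

# Illustrative schema: two scalar string fields and one scalar int64 field.
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
                "image_bytes": tf.io.FixedLenFeature([], tf.string),
                "int64_scalar": tf.io.FixedLenFeature([], tf.int64)}

# "image_bytes" is a scalar tf.string field, so it may be listed in bytes_fields.
transformer = TFRecordToMR("input.tfrecord", "output.mindrecord",
                           feature_dict, ["image_bytes"])
transformer.transform()

# After this change the constructor raises ValueError for, e.g.:
#   bytes_fields="image_bytes"     (not a list of str)
#   bytes_fields=["int64_scalar"]  (listed field is not tf.string)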
@@ -113,7 +113,7 @@ class TFRecordToMR:
             feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string), \
                                         "yyyy": tf.io.VarLenFeature(tf.int64)}, \
                             "sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
-        bytes_fields (list): the bytes fields which are in feature_dict.
+        bytes_fields (list, optional): the bytes fields which are in feature_dict and can be images bytes.

     Raises:
         ValueError: If parameter is invalid.

@@ -147,7 +147,7 @@ class TFRecordToMR:
         self.feature_dict = feature_dict

         bytes_fields_list = []
-        if bytes_fields:
+        if bytes_fields is not None:
             if not isinstance(bytes_fields, list):
                 raise ValueError("Parameter bytes_fields: {} must be list(str).".format(bytes_fields))
             for item in bytes_fields:

@@ -161,6 +161,9 @@ class TFRecordToMR:
                 if not isinstance(self.feature_dict[item].shape, list):
                     raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item))

+                if self.feature_dict[item].dtype != tf.string:
+                    raise ValueError("Parameter bytes_field: {} should be tf.string in feature_dict.".format(item))
+
                 casted_bytes_field = _cast_name(item)
                 bytes_fields_list.append(casted_bytes_field)

@@ -172,7 +175,7 @@ class TFRecordToMR:
         for key, val in self.feature_dict.items():
             if not val.shape:
                 self.scalar_set.add(_cast_name(key))
-                if key in self.bytes_fields_list:
+                if _cast_name(key) in self.bytes_fields_list:
                     mindrecord_schema[_cast_name(key)] = {"type": "bytes"}
                 else:
                     mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype)}

@@ -182,8 +185,8 @@ class TFRecordToMR:
                 if val.shape[0] < 1:
                     raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key))
                 if val.dtype == tf.string:
-                    raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] \
-                        is not None. It is not supported.".format(key))
+                    raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] " \
+                                     "is not None. It is not supported.".format(key))
                 self.list_set.add(_cast_name(key))
                 mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype), "shape": [val.shape[0]]}
         self.mindrecord_schema = mindrecord_schema

@@ -219,12 +222,12 @@ class TFRecordToMR:
             index_id = index_id + 1
             for key, val in features.items():
                 cast_key = _cast_name(key)
-                if key in self.scalar_set:
+                if cast_key in self.scalar_set:
                     self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
                 else:
                     if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list):
-                        raise ValueError("he response key: {}, value: {} from TFRecord should be a ndarray or list."
-                                         .format(key, val))
+                        raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " \
+                                         "list.".format(key, val))
                     # list set
                     ms_dict[cast_key] = \
                         np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
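The test changes below add coverage for these checks and for field names containing special characters. A local cast_name helper mirrors the converter's _cast_name, since columns read back from MindRecord carry the casted names (punctuation and spaces other than '_' are replaced with '_'). A short reading sketch under that assumption; FileReader comes from mindspore.mindrecord, and get_next() yielding dicts keyed by column name is assumed from how verify_data consumes the reader:

from mindspore.mindrecord import FileReader

# "test.mindrecord" is the tests' MINDRECORD_FILE_NAME.
reader = FileReader("test.mindrecord")
for item in reader.get_next():  # assumed to yield one dict per row
    # TFRecord fields "image/class/label" and "image/encoded" are stored
    # under their casted column names:
    label = item["image_class_label"]
    image = item["image_encoded"]
reader.close()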
@@ -15,6 +15,7 @@
 import collections
 from importlib import import_module
 import os
+from string import punctuation

 import numpy as np
 import pytest

@@ -35,6 +36,27 @@ TFRECORD_FILE_NAME = "test.tfrecord"
 MINDRECORD_FILE_NAME = "test.mindrecord"
 PARTITION_NUM = 1

+def cast_name(key):
+    """
+    Cast schema names which containing special characters to valid names.
+
+    Here special characters means any characters in
+    '!"#$%&\'()*+,./:;<=>?@[\\]^`{|}~
+    Valid names can only contain a-z, A-Z, and 0-9 and _
+
+    Args:
+        key (str): original key that might contains special characters.
+
+    Returns:
+        str, casted key that replace the special characters with "_". i.e. if
+        key is "a b" then returns "a_b".
+    """
+    special_symbols = set('{}{}'.format(punctuation, ' '))
+    special_symbols.remove('_')
+    new_key = ['_' if x in special_symbols else x for x in key]
+    casted_key = ''.join(new_key)
+    return casted_key
+
 def verify_data(transformer, reader):
     """Verify the data by read from mindrecord"""
     tf_iter = transformer.tfrecord_iterator()

@@ -43,14 +65,14 @@ def verify_data(transformer, reader):
     count = 0
     for tf_item, mr_item in zip(tf_iter, mr_iter):
         count = count + 1
-        assert len(tf_item) == 6
-        assert len(mr_item) == 6
+        assert len(tf_item) == len(mr_item)
         for key, value in tf_item.items():
-            logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value, mr_item[key]))
+            logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value,
+                                                                                     mr_item[cast_name(key)]))
             if isinstance(value, np.ndarray):
-                assert (value == mr_item[key]).all()
+                assert (value == mr_item[cast_name(key)]).all()
             else:
-                assert value == mr_item[key]
+                assert value == mr_item[cast_name(key)]
     assert count == 10

 def generate_tfrecord():

@@ -102,6 +124,39 @@ def generate_tfrecord():
     writer.close()
     logger.info("Write {} rows in tfrecord.".format(example_count))

+def generate_tfrecord_with_special_field_name():
+    def create_int_feature(values):
+        if isinstance(values, list):
+            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))  # values: [int, int, int]
+        else:
+            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values]))  # values: int
+        return feature
+
+    def create_bytes_feature(values):
+        if isinstance(values, bytes):
+            feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))  # values: bytes
+        else:
+            # values: string
+            feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))
+        return feature
+
+    writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
+
+    example_count = 0
+    for i in range(10):
+        label = i
+        image_bytes = bytes(str("aaaabbbbcccc" + str(i)), encoding="utf-8")
+
+        features = collections.OrderedDict()
+        features["image/class/label"] = create_int_feature(label)
+        features["image/encoded"] = create_bytes_feature(image_bytes)
+
+        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+        writer.write(tf_example.SerializeToString())
+        example_count += 1
+    writer.close()
+    logger.info("Write {} rows in tfrecord.".format(example_count))
+
 def test_tfrecord_to_mindrecord():
     """test transform tfrecord to mindrecord."""
     if not tf or tf.__version__ < SupportedTensorFlowVersion:

@@ -398,3 +453,110 @@ def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception():
         os.remove(MINDRECORD_FILE_NAME + ".db")

     os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
+
+def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type():
+    """test transform tfrecord to mindrecord."""
+    if not tf or tf.__version__ < SupportedTensorFlowVersion:
+        # skip the test
+        logger.warning("Module tensorflow is not found or version wrong, \
+            please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
+        return
+
+    generate_tfrecord()
+    assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
+
+    feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
+                    "image_bytes": tf.io.FixedLenFeature([], tf.string),
+                    "int64_scalar": tf.io.FixedLenFeature([], tf.int64),
+                    "float_scalar": tf.io.FixedLenFeature([], tf.float32),
+                    "int64_list": tf.io.FixedLenFeature([6], tf.int64),
+                    "float_list": tf.io.FixedLenFeature([7], tf.float32),
+                    }
+
+    if os.path.exists(MINDRECORD_FILE_NAME):
+        os.remove(MINDRECORD_FILE_NAME)
+    if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
+        os.remove(MINDRECORD_FILE_NAME + ".db")
+
+    with pytest.raises(ValueError):
+        tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
+                                            MINDRECORD_FILE_NAME, feature_dict, ["int64_list"])
+        tfrecord_transformer.transform()
+
+    if os.path.exists(MINDRECORD_FILE_NAME):
+        os.remove(MINDRECORD_FILE_NAME)
+    if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
+        os.remove(MINDRECORD_FILE_NAME + ".db")
+
+    os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
+
+def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list():
+    """test transform tfrecord to mindrecord."""
+    if not tf or tf.__version__ < SupportedTensorFlowVersion:
+        # skip the test
+        logger.warning("Module tensorflow is not found or version wrong, \
+            please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
+        return
+
+    generate_tfrecord()
+    assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
+
+    feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string),
+                    "image_bytes": tf.io.FixedLenFeature([], tf.string),
+                    "int64_scalar": tf.io.FixedLenFeature([], tf.int64),
+                    "float_scalar": tf.io.FixedLenFeature([], tf.float32),
+                    "int64_list": tf.io.FixedLenFeature([6], tf.int64),
+                    "float_list": tf.io.FixedLenFeature([7], tf.float32),
+                    }
+
+    if os.path.exists(MINDRECORD_FILE_NAME):
+        os.remove(MINDRECORD_FILE_NAME)
+    if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
+        os.remove(MINDRECORD_FILE_NAME + ".db")
+
+    with pytest.raises(ValueError):
+        tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
+                                            MINDRECORD_FILE_NAME, feature_dict, "")
+        tfrecord_transformer.transform()
+
+    if os.path.exists(MINDRECORD_FILE_NAME):
+        os.remove(MINDRECORD_FILE_NAME)
+    if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
+        os.remove(MINDRECORD_FILE_NAME + ".db")
+
+    os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
+
+def test_tfrecord_to_mindrecord_with_special_field_name():
+    """test transform tfrecord to mindrecord."""
+    if not tf or tf.__version__ < SupportedTensorFlowVersion:
+        # skip the test
+        logger.warning("Module tensorflow is not found or version wrong, \
+            please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion))
+        return
+
+    generate_tfrecord_with_special_field_name()
+    assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))
+
+    feature_dict = {"image/class/label": tf.io.FixedLenFeature([], tf.int64),
+                    "image/encoded": tf.io.FixedLenFeature([], tf.string),
+                    }
+
+    if os.path.exists(MINDRECORD_FILE_NAME):
+        os.remove(MINDRECORD_FILE_NAME)
+    if os.path.exists(MINDRECORD_FILE_NAME + ".db"):
+        os.remove(MINDRECORD_FILE_NAME + ".db")
+
+    tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME),
+                                        MINDRECORD_FILE_NAME, feature_dict, ["image/encoded"])
+    tfrecord_transformer.transform()
+
+    assert os.path.exists(MINDRECORD_FILE_NAME)
+    assert os.path.exists(MINDRECORD_FILE_NAME + ".db")
+
+    fr_mindrecord = FileReader(MINDRECORD_FILE_NAME)
+    verify_data(tfrecord_transformer, fr_mindrecord)
+
+    os.remove(MINDRECORD_FILE_NAME)
+    os.remove(MINDRECORD_FILE_NAME + ".db")
+
+    os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME))