!7195 Transition tool for TensorRecord to MindRecord, support 1.13.0-rc1 and higher version.

Merge pull request !7195 from lizhenglong1992/TR_to_MR
This commit is contained in:
mindspore-ci-bot 2020-10-13 10:14:31 +08:00 committed by Gitee
commit 2cb27c59e0
2 changed files with 77 additions and 15 deletions

View File

@ -30,7 +30,7 @@ except ModuleNotFoundError:
__all__ = ['TFRecordToMR']
SupportedTensorFlowVersion = '2.1.0'
SupportedTensorFlowVersion = '1.13.0-rc1'
def _cast_type(value):
"""
@ -210,17 +210,69 @@ class TFRecordToMR:
else:
ms_dict[cast_key] = float(val.numpy())
def _get_data_when_scalar_field_oldversion(self, ms_dict, cast_key, key, val):
"""
put data in ms_dict when field type is string
However, we have to make change due to the different structure of old version
"""
if isinstance(val, (bytes, str)):
if isinstance(val, (np.ndarray, list)):
raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))
if self.feature_dict[key].dtype == tf.string:
if cast_key in self.bytes_fields_list:
ms_dict[cast_key] = val
else:
ms_dict[cast_key] = val.decode("utf-8")
else:
ms_dict[cast_key] = val
else:
if _cast_type(self.feature_dict[key].dtype).startswith("int"):
ms_dict[cast_key] = int(val)
else:
ms_dict[cast_key] = float(val)
def tfrecord_iterator_oldversion(self):
"""
Yield a dict with key to be fields in schema, and value to be data.
This function is for old version tensorflow whose version number < 2.1.0
"""
dataset = tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
while True:
try:
ms_dict = {}
sample = iterator.get_next()
sample = sess.run(sample)
for key, val in sample.items():
cast_key = _cast_name(key)
if cast_key in self.scalar_set:
self._get_data_when_scalar_field_oldversion(ms_dict, cast_key, key, val)
else:
if not isinstance(val, np.ndarray) and not isinstance(val, list):
raise ValueError("The response key: {}, value: {} from "
"TFRecord should be a ndarray or "
"list.".format(key, val))
# list set
ms_dict[cast_key] = \
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
yield ms_dict
except tf.errors.OutOfRangeError:
break
except tf.errors.InvalidArgumentError:
raise ValueError("TFRecord feature_dict parameter error.")
def tfrecord_iterator(self):
"""Yield a dict with key to be fields in schema, and value to be data."""
dataset = tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record)
iterator = dataset.__iter__()
index_id = 0
while True:
try:
for features in iterator:
ms_dict = {}
index_id = index_id + 1
for key, val in features.items():
sample = iterator.get_next()
for key, val in sample.items():
cast_key = _cast_name(key)
if cast_key in self.scalar_set:
self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
@ -232,6 +284,8 @@ class TFRecordToMR:
ms_dict[cast_key] = \
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
yield ms_dict
except tf.errors.OutOfRangeError:
break
except tf.errors.InvalidArgumentError:
raise ValueError("TFRecord feature_dict parameter error.")
@ -247,10 +301,11 @@ class TFRecordToMR:
.format(self.mindrecord_schema, self.feature_dict))
writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord")
if tf.__version__ < '2.0.0':
tf_iter = self.tfrecord_iterator_oldversion()
else:
tf_iter = self.tfrecord_iterator()
batch_size = 256
transform_count = 0
while True:
data_list = []

View File

@ -23,7 +23,7 @@ from mindspore import log as logger
from mindspore.mindrecord import FileReader
from mindspore.mindrecord import TFRecordToMR
SupportedTensorFlowVersion = '2.1.0'
SupportedTensorFlowVersion = '1.13.0-rc1'
try:
tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
@ -58,7 +58,14 @@ def cast_name(key):
return casted_key
def verify_data(transformer, reader):
"""Verify the data by read from mindrecord"""
"""
Verify the data by read from mindrecord
If in 1.x.x version, use old version to receive that iteration
"""
if tf.__version__ < '2.0.0':
tf_iter = transformer.tfrecord_iterator_oldversion()
else:
tf_iter = transformer.tfrecord_iterator()
mr_iter = reader.get_next()