forked from mindspore-Ecosystem/mindspore
!7195 Transition tool for TensorRecord to MindRecord, support 1.13.0-rc1 and higher version.
Merge pull request !7195 from lizhenglong1992/TR_to_MR
This commit is contained in:
commit
2cb27c59e0
|
@ -30,7 +30,7 @@ except ModuleNotFoundError:
|
|||
|
||||
__all__ = ['TFRecordToMR']
|
||||
|
||||
SupportedTensorFlowVersion = '2.1.0'
|
||||
SupportedTensorFlowVersion = '1.13.0-rc1'
|
||||
|
||||
def _cast_type(value):
|
||||
"""
|
||||
|
@ -210,17 +210,69 @@ class TFRecordToMR:
|
|||
else:
|
||||
ms_dict[cast_key] = float(val.numpy())
|
||||
|
||||
def _get_data_when_scalar_field_oldversion(self, ms_dict, cast_key, key, val):
|
||||
"""
|
||||
put data in ms_dict when field type is string
|
||||
However, we have to make change due to the different structure of old version
|
||||
"""
|
||||
if isinstance(val, (bytes, str)):
|
||||
if isinstance(val, (np.ndarray, list)):
|
||||
raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))
|
||||
if self.feature_dict[key].dtype == tf.string:
|
||||
if cast_key in self.bytes_fields_list:
|
||||
ms_dict[cast_key] = val
|
||||
else:
|
||||
ms_dict[cast_key] = val.decode("utf-8")
|
||||
else:
|
||||
ms_dict[cast_key] = val
|
||||
else:
|
||||
if _cast_type(self.feature_dict[key].dtype).startswith("int"):
|
||||
ms_dict[cast_key] = int(val)
|
||||
else:
|
||||
ms_dict[cast_key] = float(val)
|
||||
|
||||
def tfrecord_iterator_oldversion(self):
|
||||
"""
|
||||
Yield a dict with key to be fields in schema, and value to be data.
|
||||
This function is for old version tensorflow whose version number < 2.1.0
|
||||
"""
|
||||
dataset = tf.data.TFRecordDataset(self.source)
|
||||
dataset = dataset.map(self._parse_record)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
with tf.Session() as sess:
|
||||
while True:
|
||||
try:
|
||||
ms_dict = {}
|
||||
sample = iterator.get_next()
|
||||
sample = sess.run(sample)
|
||||
for key, val in sample.items():
|
||||
cast_key = _cast_name(key)
|
||||
if cast_key in self.scalar_set:
|
||||
self._get_data_when_scalar_field_oldversion(ms_dict, cast_key, key, val)
|
||||
else:
|
||||
if not isinstance(val, np.ndarray) and not isinstance(val, list):
|
||||
raise ValueError("The response key: {}, value: {} from "
|
||||
"TFRecord should be a ndarray or "
|
||||
"list.".format(key, val))
|
||||
# list set
|
||||
ms_dict[cast_key] = \
|
||||
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
|
||||
yield ms_dict
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
except tf.errors.InvalidArgumentError:
|
||||
raise ValueError("TFRecord feature_dict parameter error.")
|
||||
|
||||
def tfrecord_iterator(self):
|
||||
"""Yield a dict with key to be fields in schema, and value to be data."""
|
||||
dataset = tf.data.TFRecordDataset(self.source)
|
||||
dataset = dataset.map(self._parse_record)
|
||||
iterator = dataset.__iter__()
|
||||
index_id = 0
|
||||
while True:
|
||||
try:
|
||||
for features in iterator:
|
||||
ms_dict = {}
|
||||
index_id = index_id + 1
|
||||
for key, val in features.items():
|
||||
sample = iterator.get_next()
|
||||
for key, val in sample.items():
|
||||
cast_key = _cast_name(key)
|
||||
if cast_key in self.scalar_set:
|
||||
self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
|
||||
|
@ -232,6 +284,8 @@ class TFRecordToMR:
|
|||
ms_dict[cast_key] = \
|
||||
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
|
||||
yield ms_dict
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
except tf.errors.InvalidArgumentError:
|
||||
raise ValueError("TFRecord feature_dict parameter error.")
|
||||
|
||||
|
@ -247,10 +301,11 @@ class TFRecordToMR:
|
|||
.format(self.mindrecord_schema, self.feature_dict))
|
||||
|
||||
writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord")
|
||||
|
||||
if tf.__version__ < '2.0.0':
|
||||
tf_iter = self.tfrecord_iterator_oldversion()
|
||||
else:
|
||||
tf_iter = self.tfrecord_iterator()
|
||||
batch_size = 256
|
||||
|
||||
transform_count = 0
|
||||
while True:
|
||||
data_list = []
|
||||
|
|
|
@ -23,7 +23,7 @@ from mindspore import log as logger
|
|||
from mindspore.mindrecord import FileReader
|
||||
from mindspore.mindrecord import TFRecordToMR
|
||||
|
||||
SupportedTensorFlowVersion = '2.1.0'
|
||||
SupportedTensorFlowVersion = '1.13.0-rc1'
|
||||
|
||||
try:
|
||||
tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
|
||||
|
@ -58,7 +58,14 @@ def cast_name(key):
|
|||
return casted_key
|
||||
|
||||
def verify_data(transformer, reader):
|
||||
"""Verify the data by read from mindrecord"""
|
||||
"""
|
||||
Verify the data by read from mindrecord
|
||||
If in 1.x.x version, use old version to receive that iteration
|
||||
"""
|
||||
|
||||
if tf.__version__ < '2.0.0':
|
||||
tf_iter = transformer.tfrecord_iterator_oldversion()
|
||||
else:
|
||||
tf_iter = transformer.tfrecord_iterator()
|
||||
mr_iter = reader.get_next()
|
||||
|
||||
|
|
Loading…
Reference in New Issue