回退 'Pull Request !6995 : Support the transition of datarecord from 1.x.x version Tensorflow to Mindspore '

This commit is contained in:
xsmq 2020-10-12 14:16:04 +08:00 committed by Gitee
parent d0e29996ec
commit 6364d56ebe
3 changed files with 16 additions and 78 deletions

@ -1 +1 @@
Subproject commit 14db109491bc81473905a5eb9e82f6234aca419b
Subproject commit 7a75f024d5a70c51b6428008587c4125bc015349

View File

@ -30,7 +30,7 @@ except ModuleNotFoundError:
__all__ = ['TFRecordToMR']
SupportedTensorFlowVersion = '1.13.0-rc1'
SupportedTensorFlowVersion = '2.1.0'
def _cast_type(value):
"""
@ -210,69 +210,17 @@ class TFRecordToMR:
else:
ms_dict[cast_key] = float(val.numpy())
def _get_data_when_scalar_field_oldversion(self, ms_dict, cast_key, key, val):
"""
put data in ms_dict when field type is string
However, we have to make change due to the different structure of old version
"""
if isinstance(val, (bytes, str)):
if isinstance(val, (np.ndarray, list)):
raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))
if self.feature_dict[key].dtype == tf.string:
if cast_key in self.bytes_fields_list:
ms_dict[cast_key] = val
else:
ms_dict[cast_key] = val.decode("utf-8")
else:
ms_dict[cast_key] = val
else:
if _cast_type(self.feature_dict[key].dtype).startswith("int"):
ms_dict[cast_key] = int(val)
else:
ms_dict[cast_key] = float(val)
def tfrecord_iterator_oldversion(self):
"""
Yield a dict with key to be fields in schema, and value to be data.
This function is for old version tensorflow whose version number < 2.1.0
"""
dataset = tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
while True:
try:
ms_dict = {}
sample = iterator.get_next()
sample = sess.run(sample)
for key, val in sample.items():
cast_key = _cast_name(key)
if cast_key in self.scalar_set:
self._get_data_when_scalar_field_oldversion(ms_dict, cast_key, key, val)
else:
if not isinstance(val, np.ndarray) and not isinstance(val, list):
raise ValueError("The response key: {}, value: {} from "
"TFRecord should be a ndarray or "
"list.".format(key, val))
# list set
ms_dict[cast_key] = \
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
yield ms_dict
except tf.errors.OutOfRangeError:
break
except tf.errors.InvalidArgumentError:
raise ValueError("TFRecord feature_dict parameter error.")
def tfrecord_iterator(self):
"""Yield a dict with key to be fields in schema, and value to be data."""
dataset = tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record)
iterator = dataset.__iter__()
while True:
index_id = 0
try:
for features in iterator:
ms_dict = {}
sample = iterator.get_next()
for key, val in sample.items():
index_id = index_id + 1
for key, val in features.items():
cast_key = _cast_name(key)
if cast_key in self.scalar_set:
self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
@ -284,8 +232,6 @@ class TFRecordToMR:
ms_dict[cast_key] = \
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
yield ms_dict
except tf.errors.OutOfRangeError:
break
except tf.errors.InvalidArgumentError:
raise ValueError("TFRecord feature_dict parameter error.")
@ -301,11 +247,10 @@ class TFRecordToMR:
.format(self.mindrecord_schema, self.feature_dict))
writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord")
if tf.__version__ < '2.0.0':
tf_iter = self.tfrecord_iterator_oldversion()
else:
tf_iter = self.tfrecord_iterator()
batch_size = 256
transform_count = 0
while True:
data_list = []

View File

@ -23,7 +23,7 @@ from mindspore import log as logger
from mindspore.mindrecord import FileReader
from mindspore.mindrecord import TFRecordToMR
SupportedTensorFlowVersion = '1.13.0-rc1'
SupportedTensorFlowVersion = '2.1.0'
try:
tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
@ -58,14 +58,7 @@ def cast_name(key):
return casted_key
def verify_data(transformer, reader):
"""
Verify the data by read from mindrecord
If in 1.x.x version, use old version to receive that iteration
"""
if tf.__version__ < '2.0.0':
tf_iter = transformer.tfrecord_iterator_oldversion()
else:
"""Verify the data by read from mindrecord"""
tf_iter = transformer.tfrecord_iterator()
mr_iter = reader.get_next()