diff --git a/mindspore/mindrecord/tools/tfrecord_to_mr.py b/mindspore/mindrecord/tools/tfrecord_to_mr.py
index ce5b8400f01..5aad3b075d2 100644
--- a/mindspore/mindrecord/tools/tfrecord_to_mr.py
+++ b/mindspore/mindrecord/tools/tfrecord_to_mr.py
@@ -30,7 +30,7 @@ except ModuleNotFoundError:
 
 __all__ = ['TFRecordToMR']
 
-SupportedTensorFlowVersion = '2.1.0'
+SupportedTensorFlowVersion = '1.13.0-rc1'
 
 def _cast_type(value):
     """
@@ -210,30 +210,91 @@ class TFRecordToMR:
         else:
             ms_dict[cast_key] = float(val.numpy())
 
+    def _get_data_when_scalar_field_oldversion(self, ms_dict, cast_key, key, val):
+        """
+        Put one scalar field into ms_dict, for the old TensorFlow version.
+
+        Here val is used directly instead of through val.numpy(), because of the
+        different structure of the values in the old version.
+        """
+        # reject non-scalar data before dispatching on the scalar type
+        if isinstance(val, (np.ndarray, list)):
+            raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))
+        if isinstance(val, (bytes, str)):
+            if self.feature_dict[key].dtype == tf.string:
+                if cast_key in self.bytes_fields_list or isinstance(val, str):
+                    # raw bytes fields are stored as is, and a str needs no decoding
+                    ms_dict[cast_key] = val
+                else:
+                    ms_dict[cast_key] = val.decode("utf-8")
+            else:
+                ms_dict[cast_key] = val
+        elif _cast_type(self.feature_dict[key].dtype).startswith("int"):
+            ms_dict[cast_key] = int(val)
+        else:
+            ms_dict[cast_key] = float(val)
+
+    def tfrecord_iterator_oldversion(self):
+        """
+        Yield a dict with key to be fields in schema, and value to be data.
+
+        This function is for TensorFlow 1.x (version number < 2.0.0), where each
+        record is evaluated through a tf.Session.
+        """
+        dataset = tf.data.TFRecordDataset(self.source)
+        dataset = dataset.map(self._parse_record)
+        iterator = dataset.make_one_shot_iterator()
+        # Create the get_next op once: calling get_next() inside the loop would
+        # add a new op to the graph for every record.
+        next_sample = iterator.get_next()
+        with tf.Session() as sess:
+            while True:
+                try:
+                    ms_dict = {}
+                    sample = sess.run(next_sample)
+                    for key, val in sample.items():
+                        cast_key = _cast_name(key)
+                        if cast_key in self.scalar_set:
+                            self._get_data_when_scalar_field_oldversion(ms_dict, cast_key, key, val)
+                        else:
+                            if not isinstance(val, np.ndarray) and not isinstance(val, list):
+                                raise ValueError("The response key: {}, value: {} from "
+                                                 "TFRecord should be a ndarray or "
+                                                 "list.".format(key, val))
+                            # list set
+                            ms_dict[cast_key] = \
+                                np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
+                    yield ms_dict
+                except tf.errors.OutOfRangeError:
+                    break
+                except tf.errors.InvalidArgumentError:
+                    raise ValueError("TFRecord feature_dict parameter error.")
+
     def tfrecord_iterator(self):
         """Yield a dict with key to be fields in schema, and value to be data."""
         dataset = tf.data.TFRecordDataset(self.source)
         dataset = dataset.map(self._parse_record)
         iterator = dataset.__iter__()
-        index_id = 0
-        try:
-            for features in iterator:
+        while True:
+            try:
                 ms_dict = {}
-                index_id = index_id + 1
-                for key, val in features.items():
+                sample = iterator.get_next()
+                for key, val in sample.items():
                     cast_key = _cast_name(key)
                     if cast_key in self.scalar_set:
                         self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
                     else:
                         if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list):
                             raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " \
-                                    "list.".format(key, val))
+                                             "list.".format(key, val))
                         # list set
                         ms_dict[cast_key] = \
                             np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
                 yield ms_dict
-        except tf.errors.InvalidArgumentError:
-            raise ValueError("TFRecord feature_dict parameter error.")
+            except tf.errors.OutOfRangeError:
+                break
+            except tf.errors.InvalidArgumentError:
+                raise ValueError("TFRecord feature_dict parameter error.")
 
     def run(self):
         """
@@ -247,10 +308,12 @@ class TFRecordToMR:
                     .format(self.mindrecord_schema, self.feature_dict))
 
         writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord")
-
-        tf_iter = self.tfrecord_iterator()
+        # compare the major version as a number; as strings '10.0.0' < '2.0.0'
+        if int(tf.__version__.split('.')[0]) < 2:
+            tf_iter = self.tfrecord_iterator_oldversion()
+        else:
+            tf_iter = self.tfrecord_iterator()
         batch_size = 256
-
         transform_count = 0
         while True:
             data_list = []
diff --git a/tests/ut/python/mindrecord/test_tfrecord_to_mr.py b/tests/ut/python/mindrecord/test_tfrecord_to_mr.py
index cfd0d53a492..e9f03a9cca6 100644
--- a/tests/ut/python/mindrecord/test_tfrecord_to_mr.py
+++ b/tests/ut/python/mindrecord/test_tfrecord_to_mr.py
@@ -23,7 +23,7 @@ from mindspore import log as logger
 from mindspore.mindrecord import FileReader
 from mindspore.mindrecord import TFRecordToMR
 
-SupportedTensorFlowVersion = '2.1.0'
+SupportedTensorFlowVersion = '1.13.0-rc1'
 
 try:
     tf = import_module("tensorflow")    # just used to convert tfrecord to mindrecord
@@ -58,8 +58,16 @@ def cast_name(key):
     return casted_key
 
 def verify_data(transformer, reader):
-    """Verify the data by read from mindrecord"""
-    tf_iter = transformer.tfrecord_iterator()
+    """
+    Verify the data by reading it back from mindrecord.
+    With TensorFlow 1.x.x, use the old version iterator to receive the records.
+    """
+
+    # compare the major version as a number, not as a string
+    if int(tf.__version__.split('.')[0]) < 2:
+        tf_iter = transformer.tfrecord_iterator_oldversion()
+    else:
+        tf_iter = transformer.tfrecord_iterator()
     mr_iter = reader.get_next()
 
     count = 0