diff --git a/mindspore/mindrecord/tools/mnist_to_mr.py b/mindspore/mindrecord/tools/mnist_to_mr.py index 462ab7fb53b..046788535dc 100644 --- a/mindspore/mindrecord/tools/mnist_to_mr.py +++ b/mindspore/mindrecord/tools/mnist_to_mr.py @@ -77,20 +77,20 @@ class MnistToMR: self.mnist_schema_json = {"label": {"type": "int64"}, "data": {"type": "bytes"}} - def _extract_images(self, filename, num_images): + def _extract_images(self, filename): """Extract the images into a 4D tensor [image index, y, x, channels].""" with gzip.open(filename) as bytestream: bytestream.read(16) - buf = bytestream.read(self.image_size * self.image_size * num_images * self.num_channels) + buf = bytestream.read() data = np.frombuffer(buf, dtype=np.uint8) - data = data.reshape(num_images, self.image_size, self.image_size, self.num_channels) + data = data.reshape(-1, self.image_size, self.image_size, self.num_channels) return data - def _extract_labels(self, filename, num_images): + def _extract_labels(self, filename): """Extract the labels into a vector of int64 label IDs.""" with gzip.open(filename) as bytestream: bytestream.read(8) - buf = bytestream.read(1 * num_images) + buf = bytestream.read() labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64) return labels @@ -101,8 +101,8 @@ class MnistToMR: Yields: data (dict of list): mnist data list which contains dict. """ - train_data = self._extract_images(self.train_data_filename_, 60000) - train_labels = self._extract_labels(self.train_labels_filename_, 60000) + train_data = self._extract_images(self.train_data_filename_) + train_labels = self._extract_labels(self.train_labels_filename_) for data, label in zip(train_data, train_labels): _, img = cv2.imencode(".jpeg", data) yield {"label": int(label), "data": img.tobytes()} @@ -114,8 +114,8 @@ class MnistToMR: Yields: data (dict of list): mnist data list which contains dict. """ - test_data = self._extract_images(self.test_data_filename_, 10000) - test_labels = self._extract_labels(self.test_labels_filename_, 10000) + test_data = self._extract_images(self.test_data_filename_) + test_labels = self._extract_labels(self.test_labels_filename_) for data, label in zip(test_data, test_labels): _, img = cv2.imencode(".jpeg", data) yield {"label": int(label), "data": img.tobytes()} diff --git a/tests/ut/data/mindrecord/testMnistData/t10k-images-idx3-ubyte.gz b/tests/ut/data/mindrecord/testMnistData/t10k-images-idx3-ubyte.gz index 9fdddeebe95..d7a2ea5de2a 100644 Binary files a/tests/ut/data/mindrecord/testMnistData/t10k-images-idx3-ubyte.gz and b/tests/ut/data/mindrecord/testMnistData/t10k-images-idx3-ubyte.gz differ diff --git a/tests/ut/data/mindrecord/testMnistData/t10k-labels-idx1-ubyte.gz b/tests/ut/data/mindrecord/testMnistData/t10k-labels-idx1-ubyte.gz index c8a68516600..6925ee8ce4a 100644 Binary files a/tests/ut/data/mindrecord/testMnistData/t10k-labels-idx1-ubyte.gz and b/tests/ut/data/mindrecord/testMnistData/t10k-labels-idx1-ubyte.gz differ diff --git a/tests/ut/data/mindrecord/testMnistData/train-images-idx3-ubyte.gz b/tests/ut/data/mindrecord/testMnistData/train-images-idx3-ubyte.gz index 4f27a302031..80c13bf9a37 100644 Binary files a/tests/ut/data/mindrecord/testMnistData/train-images-idx3-ubyte.gz and b/tests/ut/data/mindrecord/testMnistData/train-images-idx3-ubyte.gz differ diff --git a/tests/ut/data/mindrecord/testMnistData/train-labels-idx1-ubyte.gz b/tests/ut/data/mindrecord/testMnistData/train-labels-idx1-ubyte.gz index abc30a7c68f..61cc616763c 100644 Binary files a/tests/ut/data/mindrecord/testMnistData/train-labels-idx1-ubyte.gz and b/tests/ut/data/mindrecord/testMnistData/train-labels-idx1-ubyte.gz differ diff --git a/tests/ut/python/mindrecord/test_mindrecord_base.py b/tests/ut/python/mindrecord/test_mindrecord_base.py index 576063295ae..7fdf1f0f94b 100644 --- a/tests/ut/python/mindrecord/test_mindrecord_base.py +++ b/tests/ut/python/mindrecord/test_mindrecord_base.py @@ -203,9 +203,9 @@ def test_nlp_page_reader_tutorial(): os.remove("{}".format(x)) os.remove("{}.db".format(x)) -def test_cv_file_writer_shard_num_1000(): - """test file writer when shard num equals 1000.""" - writer = FileWriter(CV_FILE_NAME, 1000) +def test_cv_file_writer_shard_num_10(): + """test file writer when shard num equals 10.""" + writer = FileWriter(CV_FILE_NAME, 10) data = get_data("../data/mindrecord/testImageNetData/") cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int64"}, "data": {"type": "bytes"}} @@ -214,8 +214,8 @@ def test_cv_file_writer_shard_num_1000(): writer.write_raw_data(data) writer.commit() - paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(3, '0')) - for x in range(1000)] + paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) + for x in range(10)] for x in paths: os.remove("{}".format(x)) os.remove("{}.db".format(x)) diff --git a/tests/ut/python/mindrecord/test_mnist_to_mr.py b/tests/ut/python/mindrecord/test_mnist_to_mr.py index c299a1f7192..505b0d6b432 100644 --- a/tests/ut/python/mindrecord/test_mnist_to_mr.py +++ b/tests/ut/python/mindrecord/test_mnist_to_mr.py @@ -37,7 +37,7 @@ def read(train_name, test_name): count = count + 1 if count == 1: logger.info("data: {}".format(x)) - assert count == 60000 + assert count == 20 reader.close() count = 0 @@ -47,7 +47,7 @@ def read(train_name, test_name): count = count + 1 if count == 1: logger.info("data: {}".format(x)) - assert count == 10000 + assert count == 10 reader.close() @@ -102,10 +102,10 @@ def test_mnist_to_mindrecord_compare_data(): 't10k-images-idx3-ubyte.gz') test_labels_filename_ = os.path.join(MNIST_DIR, 't10k-labels-idx1-ubyte.gz') - train_data = _extract_images(train_data_filename_, 60000) - train_labels = _extract_labels(train_labels_filename_, 60000) - test_data = _extract_images(test_data_filename_, 10000) - test_labels = _extract_labels(test_labels_filename_, 10000) + train_data = _extract_images(train_data_filename_, 20) + train_labels = _extract_labels(train_labels_filename_, 20) + test_data = _extract_images(test_data_filename_, 10) + test_labels = _extract_labels(test_labels_filename_, 10) reader = FileReader(train_name) for x, data, label in zip(reader.get_next(), train_data, train_labels):