forked from mindspore-Ecosystem/mindspore
fix mindrecord ut long time
This commit is contained in:
parent
d8b460c780
commit
5a4f17bfb6
|
@ -77,20 +77,20 @@ class MnistToMR:
|
|||
|
||||
self.mnist_schema_json = {"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||
|
||||
def _extract_images(self, filename, num_images):
|
||||
def _extract_images(self, filename):
|
||||
"""Extract the images into a 4D tensor [image index, y, x, channels]."""
|
||||
with gzip.open(filename) as bytestream:
|
||||
bytestream.read(16)
|
||||
buf = bytestream.read(self.image_size * self.image_size * num_images * self.num_channels)
|
||||
buf = bytestream.read()
|
||||
data = np.frombuffer(buf, dtype=np.uint8)
|
||||
data = data.reshape(num_images, self.image_size, self.image_size, self.num_channels)
|
||||
data = data.reshape(-1, self.image_size, self.image_size, self.num_channels)
|
||||
return data
|
||||
|
||||
def _extract_labels(self, filename, num_images):
|
||||
def _extract_labels(self, filename):
|
||||
"""Extract the labels into a vector of int64 label IDs."""
|
||||
with gzip.open(filename) as bytestream:
|
||||
bytestream.read(8)
|
||||
buf = bytestream.read(1 * num_images)
|
||||
buf = bytestream.read()
|
||||
labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
|
||||
return labels
|
||||
|
||||
|
@ -101,8 +101,8 @@ class MnistToMR:
|
|||
Yields:
|
||||
data (dict of list): mnist data list which contains dict.
|
||||
"""
|
||||
train_data = self._extract_images(self.train_data_filename_, 60000)
|
||||
train_labels = self._extract_labels(self.train_labels_filename_, 60000)
|
||||
train_data = self._extract_images(self.train_data_filename_)
|
||||
train_labels = self._extract_labels(self.train_labels_filename_)
|
||||
for data, label in zip(train_data, train_labels):
|
||||
_, img = cv2.imencode(".jpeg", data)
|
||||
yield {"label": int(label), "data": img.tobytes()}
|
||||
|
@ -114,8 +114,8 @@ class MnistToMR:
|
|||
Yields:
|
||||
data (dict of list): mnist data list which contains dict.
|
||||
"""
|
||||
test_data = self._extract_images(self.test_data_filename_, 10000)
|
||||
test_labels = self._extract_labels(self.test_labels_filename_, 10000)
|
||||
test_data = self._extract_images(self.test_data_filename_)
|
||||
test_labels = self._extract_labels(self.test_labels_filename_)
|
||||
for data, label in zip(test_data, test_labels):
|
||||
_, img = cv2.imencode(".jpeg", data)
|
||||
yield {"label": int(label), "data": img.tobytes()}
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -203,9 +203,9 @@ def test_nlp_page_reader_tutorial():
|
|||
os.remove("{}".format(x))
|
||||
os.remove("{}.db".format(x))
|
||||
|
||||
def test_cv_file_writer_shard_num_1000():
|
||||
"""test file writer when shard num equals 1000."""
|
||||
writer = FileWriter(CV_FILE_NAME, 1000)
|
||||
def test_cv_file_writer_shard_num_10():
|
||||
"""test file writer when shard num equals 10."""
|
||||
writer = FileWriter(CV_FILE_NAME, 10)
|
||||
data = get_data("../data/mindrecord/testImageNetData/")
|
||||
cv_schema_json = {"file_name": {"type": "string"},
|
||||
"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||
|
@ -214,8 +214,8 @@ def test_cv_file_writer_shard_num_1000():
|
|||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(3, '0'))
|
||||
for x in range(1000)]
|
||||
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
|
||||
for x in range(10)]
|
||||
for x in paths:
|
||||
os.remove("{}".format(x))
|
||||
os.remove("{}.db".format(x))
|
||||
|
|
|
@ -37,7 +37,7 @@ def read(train_name, test_name):
|
|||
count = count + 1
|
||||
if count == 1:
|
||||
logger.info("data: {}".format(x))
|
||||
assert count == 60000
|
||||
assert count == 20
|
||||
reader.close()
|
||||
|
||||
count = 0
|
||||
|
@ -47,7 +47,7 @@ def read(train_name, test_name):
|
|||
count = count + 1
|
||||
if count == 1:
|
||||
logger.info("data: {}".format(x))
|
||||
assert count == 10000
|
||||
assert count == 10
|
||||
reader.close()
|
||||
|
||||
|
||||
|
@ -102,10 +102,10 @@ def test_mnist_to_mindrecord_compare_data():
|
|||
't10k-images-idx3-ubyte.gz')
|
||||
test_labels_filename_ = os.path.join(MNIST_DIR,
|
||||
't10k-labels-idx1-ubyte.gz')
|
||||
train_data = _extract_images(train_data_filename_, 60000)
|
||||
train_labels = _extract_labels(train_labels_filename_, 60000)
|
||||
test_data = _extract_images(test_data_filename_, 10000)
|
||||
test_labels = _extract_labels(test_labels_filename_, 10000)
|
||||
train_data = _extract_images(train_data_filename_, 20)
|
||||
train_labels = _extract_labels(train_labels_filename_, 20)
|
||||
test_data = _extract_images(test_data_filename_, 10)
|
||||
test_labels = _extract_labels(test_labels_filename_, 10)
|
||||
|
||||
reader = FileReader(train_name)
|
||||
for x, data, label in zip(reader.get_next(), train_data, train_labels):
|
||||
|
|
Loading…
Reference in New Issue