forked from mindspore-Ecosystem/mindspore
!94 enhance: reduce execution time for mindrecord test case
Merge pull request !94 from yanzhenxiang2020/fix_mindrecord_ut_long_time
This commit is contained in:
commit
d245792842
|
@ -77,20 +77,20 @@ class MnistToMR:
|
||||||
|
|
||||||
self.mnist_schema_json = {"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
self.mnist_schema_json = {"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||||
|
|
||||||
def _extract_images(self, filename, num_images):
|
def _extract_images(self, filename):
|
||||||
"""Extract the images into a 4D tensor [image index, y, x, channels]."""
|
"""Extract the images into a 4D tensor [image index, y, x, channels]."""
|
||||||
with gzip.open(filename) as bytestream:
|
with gzip.open(filename) as bytestream:
|
||||||
bytestream.read(16)
|
bytestream.read(16)
|
||||||
buf = bytestream.read(self.image_size * self.image_size * num_images * self.num_channels)
|
buf = bytestream.read()
|
||||||
data = np.frombuffer(buf, dtype=np.uint8)
|
data = np.frombuffer(buf, dtype=np.uint8)
|
||||||
data = data.reshape(num_images, self.image_size, self.image_size, self.num_channels)
|
data = data.reshape(-1, self.image_size, self.image_size, self.num_channels)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _extract_labels(self, filename, num_images):
|
def _extract_labels(self, filename):
|
||||||
"""Extract the labels into a vector of int64 label IDs."""
|
"""Extract the labels into a vector of int64 label IDs."""
|
||||||
with gzip.open(filename) as bytestream:
|
with gzip.open(filename) as bytestream:
|
||||||
bytestream.read(8)
|
bytestream.read(8)
|
||||||
buf = bytestream.read(1 * num_images)
|
buf = bytestream.read()
|
||||||
labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
|
labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
@ -101,8 +101,8 @@ class MnistToMR:
|
||||||
Yields:
|
Yields:
|
||||||
data (dict of list): mnist data list which contains dict.
|
data (dict of list): mnist data list which contains dict.
|
||||||
"""
|
"""
|
||||||
train_data = self._extract_images(self.train_data_filename_, 60000)
|
train_data = self._extract_images(self.train_data_filename_)
|
||||||
train_labels = self._extract_labels(self.train_labels_filename_, 60000)
|
train_labels = self._extract_labels(self.train_labels_filename_)
|
||||||
for data, label in zip(train_data, train_labels):
|
for data, label in zip(train_data, train_labels):
|
||||||
_, img = cv2.imencode(".jpeg", data)
|
_, img = cv2.imencode(".jpeg", data)
|
||||||
yield {"label": int(label), "data": img.tobytes()}
|
yield {"label": int(label), "data": img.tobytes()}
|
||||||
|
@ -114,8 +114,8 @@ class MnistToMR:
|
||||||
Yields:
|
Yields:
|
||||||
data (dict of list): mnist data list which contains dict.
|
data (dict of list): mnist data list which contains dict.
|
||||||
"""
|
"""
|
||||||
test_data = self._extract_images(self.test_data_filename_, 10000)
|
test_data = self._extract_images(self.test_data_filename_)
|
||||||
test_labels = self._extract_labels(self.test_labels_filename_, 10000)
|
test_labels = self._extract_labels(self.test_labels_filename_)
|
||||||
for data, label in zip(test_data, test_labels):
|
for data, label in zip(test_data, test_labels):
|
||||||
_, img = cv2.imencode(".jpeg", data)
|
_, img = cv2.imencode(".jpeg", data)
|
||||||
yield {"label": int(label), "data": img.tobytes()}
|
yield {"label": int(label), "data": img.tobytes()}
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -203,9 +203,9 @@ def test_nlp_page_reader_tutorial():
|
||||||
os.remove("{}".format(x))
|
os.remove("{}".format(x))
|
||||||
os.remove("{}.db".format(x))
|
os.remove("{}.db".format(x))
|
||||||
|
|
||||||
def test_cv_file_writer_shard_num_1000():
|
def test_cv_file_writer_shard_num_10():
|
||||||
"""test file writer when shard num equals 1000."""
|
"""test file writer when shard num equals 10."""
|
||||||
writer = FileWriter(CV_FILE_NAME, 1000)
|
writer = FileWriter(CV_FILE_NAME, 10)
|
||||||
data = get_data("../data/mindrecord/testImageNetData/")
|
data = get_data("../data/mindrecord/testImageNetData/")
|
||||||
cv_schema_json = {"file_name": {"type": "string"},
|
cv_schema_json = {"file_name": {"type": "string"},
|
||||||
"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
"label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||||
|
@ -214,8 +214,8 @@ def test_cv_file_writer_shard_num_1000():
|
||||||
writer.write_raw_data(data)
|
writer.write_raw_data(data)
|
||||||
writer.commit()
|
writer.commit()
|
||||||
|
|
||||||
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(3, '0'))
|
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
|
||||||
for x in range(1000)]
|
for x in range(10)]
|
||||||
for x in paths:
|
for x in paths:
|
||||||
os.remove("{}".format(x))
|
os.remove("{}".format(x))
|
||||||
os.remove("{}.db".format(x))
|
os.remove("{}.db".format(x))
|
||||||
|
|
|
@ -37,7 +37,7 @@ def read(train_name, test_name):
|
||||||
count = count + 1
|
count = count + 1
|
||||||
if count == 1:
|
if count == 1:
|
||||||
logger.info("data: {}".format(x))
|
logger.info("data: {}".format(x))
|
||||||
assert count == 60000
|
assert count == 20
|
||||||
reader.close()
|
reader.close()
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
|
@ -47,7 +47,7 @@ def read(train_name, test_name):
|
||||||
count = count + 1
|
count = count + 1
|
||||||
if count == 1:
|
if count == 1:
|
||||||
logger.info("data: {}".format(x))
|
logger.info("data: {}".format(x))
|
||||||
assert count == 10000
|
assert count == 10
|
||||||
reader.close()
|
reader.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,10 +102,10 @@ def test_mnist_to_mindrecord_compare_data():
|
||||||
't10k-images-idx3-ubyte.gz')
|
't10k-images-idx3-ubyte.gz')
|
||||||
test_labels_filename_ = os.path.join(MNIST_DIR,
|
test_labels_filename_ = os.path.join(MNIST_DIR,
|
||||||
't10k-labels-idx1-ubyte.gz')
|
't10k-labels-idx1-ubyte.gz')
|
||||||
train_data = _extract_images(train_data_filename_, 60000)
|
train_data = _extract_images(train_data_filename_, 20)
|
||||||
train_labels = _extract_labels(train_labels_filename_, 60000)
|
train_labels = _extract_labels(train_labels_filename_, 20)
|
||||||
test_data = _extract_images(test_data_filename_, 10000)
|
test_data = _extract_images(test_data_filename_, 10)
|
||||||
test_labels = _extract_labels(test_labels_filename_, 10000)
|
test_labels = _extract_labels(test_labels_filename_, 10)
|
||||||
|
|
||||||
reader = FileReader(train_name)
|
reader = FileReader(train_name)
|
||||||
for x, data, label in zip(reader.get_next(), train_data, train_labels):
|
for x, data, label in zip(reader.get_next(), train_data, train_labels):
|
||||||
|
|
Loading…
Reference in New Issue