Add np.frombuffer conversion so that dataset sources can return bytes

This commit is contained in:
ms_yan 2021-08-20 20:36:06 +08:00
parent 4661b47b52
commit 6ecfc34524
2 changed files with 52 additions and 4 deletions

View File

@ -3502,11 +3502,11 @@ def _iter_fn(dataset, num_samples):
except StopIteration:
return
# convert output tensors to ndarrays
yield tuple([np.array(x, copy=False) for x in val])
yield _convert_row(val)
else:
for val in dataset:
# convert output tensors to ndarrays
yield tuple([np.array(x, copy=False) for x in val])
yield _convert_row(val)
def _generator_fn(generator, num_samples):
@ -3539,7 +3539,7 @@ def _cpp_sampler_fn(sample_ids, dataset):
for i in sample_ids:
val = dataset[i]
# convert output tensors to ndarrays
yield tuple([np.array(x, copy=False) for x in val])
yield _convert_row(val)
def _cpp_sampler_fn_mp(sample_ids, sample_fn):
@ -3606,6 +3606,17 @@ def _watch_dog(pids, eof):
os.kill(os.getpid(), signal.SIGTERM)
def _convert_row(row):
    """
    Convert every column of a row into a numpy array.

    Columns that are ``bytes`` objects are exposed as one-dimensional
    ``uint8`` arrays through ``np.frombuffer``; every other column is
    converted with ``np.asarray``.

    Args:
        row (Iterable): Column values of a single sample.

    Returns:
        tuple, one ``numpy.ndarray`` per column, in the original column order.
    """
    value = []
    for x in row:
        if isinstance(x, bytes):
            # View the raw bytes as uint8 instead of letting numpy build a
            # zero-dimensional bytes-string scalar array out of them.
            value.append(np.frombuffer(x, dtype=np.uint8))
        else:
            # np.asarray copies only when it has to. np.array(x, copy=False)
            # behaves the same on NumPy 1.x but raises ValueError on NumPy 2.x
            # whenever a copy cannot be avoided (e.g. for lists or scalars).
            value.append(np.asarray(x))
    return tuple(value)
class SamplerFn:
"""
Multiprocessing or multithread generator function wrapper master process.
@ -3704,7 +3715,7 @@ class SamplerFn:
return
if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
yield tuple([np.array(x, copy=False) for x in result])
yield _convert_row(result)
def _stop_subprocess(self):
# Only the main process can call join

View File

@ -16,9 +16,11 @@
Testing Decode op in DE
"""
import cv2
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from util import diff_mse
@ -76,6 +78,41 @@ def test_decode_op_tf_file_dataset():
assert mse == 0
class ImageDataset:
    """
    Random-access source that yields the raw content of one image file.

    Args:
        data_path (str): Path of the image file to read.
        data_type (str): "numpy" to return the file content as a uint8
            numpy array; any other value returns the raw bytes object.
            Default: "numpy".
    """

    def __init__(self, data_path, data_type="numpy"):
        self.data = [data_path]
        # A single random label, matching the single data entry.
        self.label = np.random.sample((1, 1))
        self.data_type = data_type

    def __getitem__(self, index):
        """Return a ``(image_content, label)`` pair for the given index."""
        # The context manager closes the file even if read() raises.
        with open(self.data[index], 'rb') as f:
            img_bytes = f.read()
        if self.data_type == "numpy":
            img_bytes = np.frombuffer(img_bytes, dtype=np.uint8)
        # Otherwise the bytes object is returned directly.
        return (img_bytes, self.label[index])

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
def test_read_image_decode_op():
    """
    Check that Decode gives the same result whether the generator source
    returns the image file content as a uint8 numpy array or as raw bytes.
    """
    image_file = "../data/dataset/testPK/data/class1/0.jpg"
    numpy_source = ImageDataset(image_file, data_type="numpy")
    numpy_dataset = ds.GeneratorDataset(numpy_source, ["data", "label"])
    bytes_source = ImageDataset(image_file, data_type="bytes")
    bytes_dataset = ds.GeneratorDataset(bytes_source, ["data", "label"])

    decode = py_vision.Decode()
    as_tensor = py_vision.ToTensor(output_type=np.int32)
    numpy_dataset = numpy_dataset.map(operations=[decode, as_tensor], input_columns=["data"])
    bytes_dataset = bytes_dataset.map(operations=[decode, as_tensor], input_columns=["data"])

    # Both pipelines read the same file, so the decoded images must match
    # element for element.
    for numpy_row, bytes_row in zip(numpy_dataset, bytes_dataset):
        difference = numpy_row[0].asnumpy() - bytes_row[0].asnumpy()
        assert np.count_nonzero(difference) == 0
if __name__ == "__main__":
    # Run every test of this module, in declaration order, when the file is
    # executed as a script.
    for run_test in (test_decode_op, test_decode_op_tf_file_dataset, test_read_image_decode_op):
        run_test()