forked from mindspore-Ecosystem/mindspore
!14562 [MD] Fix face quality performance issue
From: @xiefangqi Reviewed-by: @liucunwei,@oacjiewen Signed-off-by: @liucunwei
This commit is contained in:
commit
04c9f1df84
|
@ -20,7 +20,6 @@ from PIL import Image, ImageFile
|
|||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.py_transforms as F
|
||||
import mindspore.dataset.transforms.py_transforms as F2
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
@ -72,7 +71,7 @@ class MdFaceDataset():
|
|||
landmarks = self._trans_cor(path_label_info[4:14], x_length, y_length)
|
||||
eulers = np.array([e / 90. for e in list(map(float, path_label_info[1:4]))])
|
||||
labels = np.concatenate([eulers, landmarks], axis=0)
|
||||
sample = image
|
||||
sample = F.ToTensor()(image)
|
||||
|
||||
return sample, labels
|
||||
|
||||
|
@ -107,14 +106,10 @@ class DistributedSampler():
|
|||
|
||||
def faceqa_dataset(imlist, per_batch_size, local_rank, world_size):
|
||||
'''faceqa dataset'''
|
||||
transform_img = F2.Compose([F.ToTensor()])
|
||||
dataset = MdFaceDataset(imlist)
|
||||
sampler = DistributedSampler(dataset, local_rank, world_size)
|
||||
de_dataset = ds.GeneratorDataset(dataset, ["image", "label"], sampler=sampler, num_parallel_workers=8,
|
||||
de_dataset = ds.GeneratorDataset(dataset, ["image", "label"], sampler=sampler, num_parallel_workers=16,
|
||||
python_multiprocessing=True)
|
||||
|
||||
de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=8,
|
||||
python_multiprocessing=True)
|
||||
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=True)
|
||||
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=True, num_parallel_workers=4)
|
||||
|
||||
return de_dataset
|
||||
|
|
Loading…
Reference in New Issue