opotimize face quality performance

This commit is contained in:
xiefangqi 2021-04-01 18:41:22 +08:00
parent aac165a8e5
commit a665470cf5
1 changed files with 3 additions and 8 deletions

View File

@ -20,7 +20,6 @@ from PIL import Image, ImageFile
import mindspore.dataset as ds
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.transforms.py_transforms as F2
warnings.filterwarnings('ignore')
ImageFile.LOAD_TRUNCATED_IMAGES = True
@ -72,7 +71,7 @@ class MdFaceDataset():
landmarks = self._trans_cor(path_label_info[4:14], x_length, y_length)
eulers = np.array([e / 90. for e in list(map(float, path_label_info[1:4]))])
labels = np.concatenate([eulers, landmarks], axis=0)
sample = image
sample = F.ToTensor()(image)
return sample, labels
@ -107,14 +106,10 @@ class DistributedSampler():
def faceqa_dataset(imlist, per_batch_size, local_rank, world_size):
'''faceqa dataset'''
transform_img = F2.Compose([F.ToTensor()])
dataset = MdFaceDataset(imlist)
sampler = DistributedSampler(dataset, local_rank, world_size)
de_dataset = ds.GeneratorDataset(dataset, ["image", "label"], sampler=sampler, num_parallel_workers=8,
de_dataset = ds.GeneratorDataset(dataset, ["image", "label"], sampler=sampler, num_parallel_workers=16,
python_multiprocessing=True)
de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=8,
python_multiprocessing=True)
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=True)
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=True, num_parallel_workers=4)
return de_dataset