forked from mindspore-Ecosystem/mindspore
optimize face quality performance
parent aac165a8e5
commit a665470cf5
@@ -20,7 +20,6 @@ from PIL import Image, ImageFile
 
 import mindspore.dataset as ds
 import mindspore.dataset.vision.py_transforms as F
-import mindspore.dataset.transforms.py_transforms as F2
 
 warnings.filterwarnings('ignore')
 ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -72,7 +71,7 @@ class MdFaceDataset():
         landmarks = self._trans_cor(path_label_info[4:14], x_length, y_length)
         eulers = np.array([e / 90. for e in list(map(float, path_label_info[1:4]))])
         labels = np.concatenate([eulers, landmarks], axis=0)
-        sample = image
+        sample = F.ToTensor()(image)
 
         return sample, labels
 
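For context on this hunk: py_transforms.ToTensor converts an HWC PIL image (or uint8 ndarray) into a CHW float32 ndarray scaled to [0, 1], so applying it directly in __getitem__ makes the separate map(ToTensor) stage in faceqa_dataset (removed in the next hunk) unnecessary. A minimal standalone sketch, not part of the commit; the 96x96 dummy image is illustrative only:

import numpy as np
from PIL import Image
import mindspore.dataset.vision.py_transforms as F

# Dummy crop standing in for the face image loaded in MdFaceDataset.__getitem__
image = Image.fromarray(np.zeros((96, 96, 3), dtype=np.uint8))

sample = F.ToTensor()(image)       # HWC uint8 -> CHW float32 scaled to [0, 1]
print(sample.shape, sample.dtype)  # (3, 96, 96) float32
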
@@ -107,14 +106,10 @@ class DistributedSampler():
 
 def faceqa_dataset(imlist, per_batch_size, local_rank, world_size):
     '''faceqa dataset'''
-    transform_img = F2.Compose([F.ToTensor()])
     dataset = MdFaceDataset(imlist)
     sampler = DistributedSampler(dataset, local_rank, world_size)
-    de_dataset = ds.GeneratorDataset(dataset, ["image", "label"], sampler=sampler, num_parallel_workers=8,
+    de_dataset = ds.GeneratorDataset(dataset, ["image", "label"], sampler=sampler, num_parallel_workers=16,
                                      python_multiprocessing=True)
-    de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=8,
-                                python_multiprocessing=True)
-    de_dataset = de_dataset.batch(per_batch_size, drop_remainder=True)
+    de_dataset = de_dataset.batch(per_batch_size, drop_remainder=True, num_parallel_workers=4)
 
     return de_dataset
-
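Taken together, the commit moves per-sample normalization into the dataset, raises GeneratorDataset parallelism, and replaces the map(ToTensor) stage with a parallelized batch step. Below is a minimal, self-contained sketch of the resulting pipeline shape; ToyFaceDataset, the 96x96 dummy images, the 64-sample length, the batch size and the worker counts here are illustrative stand-ins (the real code uses MdFaceDataset, a DistributedSampler, num_parallel_workers=16 and the per_batch_size argument, as shown in the hunks above):

import numpy as np
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision.py_transforms as F

class ToyFaceDataset:
    """Stand-in for MdFaceDataset: ToTensor is applied inside __getitem__."""
    def __init__(self, length=64):
        self.length = length

    def __getitem__(self, index):
        img = Image.fromarray(np.zeros((96, 96, 3), dtype=np.uint8))
        sample = F.ToTensor()(img)              # CHW float32 in [0, 1]
        label = np.zeros(13, dtype=np.float32)  # 3 euler angles + 10 landmark values
        return sample, label

    def __len__(self):
        return self.length

if __name__ == "__main__":
    dataset = ToyFaceDataset()
    de_dataset = ds.GeneratorDataset(dataset, ["image", "label"],
                                     num_parallel_workers=8,      # commit raises this to 16
                                     python_multiprocessing=True)
    de_dataset = de_dataset.batch(16, drop_remainder=True,
                                  num_parallel_workers=4)         # batching is now parallel too
    for image, label in de_dataset.create_tuple_iterator(output_numpy=True):
        print(image.shape, label.shape)                           # (16, 3, 96, 96) (16, 13)
        break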