!12712 masked face recognition
From: @ffeiding Reviewed-by: Signed-off-by:
This commit is contained in:
commit
609a518068
|
@ -0,0 +1,143 @@
|
|||
# Masked Face Recognition with Latent Part Detection
|
||||
|
||||
# Contents
|
||||
|
||||
- [Masked Face Recognition Description](#masked-face-recognition-description)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Training](#training)
|
||||
- [Evaluation](#evaluation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [Masked Face Recognition Description](#contents)
|
||||
|
||||
<p align="center">
|
||||
<img src="./img/overview.png">
|
||||
</p>
|
||||
|
||||
This is a **MindSpore** implementation of [Masked Face Recognition with Latent Part Detection (ACM MM20)](https://dl.acm.org/doi/10.1145/3394171.3413731) by *Feifei Ding, Peixi Peng, Yangru Huang, Mengyue Geng and Yonghong Tian*.
|
||||
|
||||
*Masked Face Recognition* aims to match masked faces with common faces and is important especially during the global outbreak of COVID-19. It is challenging to identify masked faces since most facial cues are occluded by mask.
|
||||
|
||||
*Latent Part Detection* (LPD) is a differentiable module that can locate the latent facial part which is robust to mask wearing, and the latent part is further used to extract discriminative features. The proposed LPD model is trained in an end-to-end manner and only utilizes the original and synthetic training data.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
## Training Dataset
|
||||
|
||||
We use [CASIA-WebFace Dataset](http://www.cbsr.ia.ac.cn/english/CASIA-WebFace-Database.html) as the training dataset. After downloading CASIA-WebFace, we first detect faces and facial landmarks using `MTCNN` and align faces to a canonical pose using similarity transformation. (see: [MTCNN - face detection & alignment](https://github.com/kpzhang93/MTCNN_face_detection_alignment)).
|
||||
|
||||
Collecting and labeling realistic masked facial data requires a great deal of human labor. To address this issue, we generate masked face images based on CASIA-WebFace. We generate 8 kinds of synthetic masked face images to augment training data based on 8 different styles of masks, such as surgical masks, N95 respirators and activated carbon masks. We mix the original face images with the synthetic masked images as the training data.
|
||||
|
||||
<p align="center">
|
||||
<img src="./img/generated_masked_faces.png" width="600px">
|
||||
</p>
|
||||
|
||||
## Evaluating Dataset
|
||||
|
||||
We use [PKU-Masked-Face Dataset](https://pkuml.org/resources/pku-masked-face-dataset.html) as the evaluating dataset. The dataset contains 10,301 face images of 1,018 identities. Each identity has masked and common face images with various orientations, lighting conditions and mask types. Most identities have 5 holistic face images and 5 masked face images with 5 different views: front, left, right, up and down.
|
||||
|
||||
The directory structure is as follows:
|
||||
|
||||
```python
|
||||
.
|
||||
└─ dataset
|
||||
├─ train dataset
|
||||
├─ ID1
|
||||
├─ ID1_0001.jpg
|
||||
├─ ID1_0002.jpg
|
||||
...
|
||||
├─ ID2
|
||||
...
|
||||
├─ ID3
|
||||
...
|
||||
...
|
||||
├─ test dataset
|
||||
├─ ID1
|
||||
├─ ID1_0001.jpg
|
||||
├─ ID1_0002.jpg
|
||||
...
|
||||
├─ ID2
|
||||
...
|
||||
├─ ID3
|
||||
...
|
||||
...
|
||||
```
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend)
|
||||
- Prepare hardware environment with Ascend processor. If you want to get Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
The entire code structure is as following:
|
||||
|
||||
```python
|
||||
└─ face_recognition
|
||||
├── README.md // descriptions about face_recognition
|
||||
├── scripts
|
||||
│ ├── run_train.sh // shell script for training on Ascend
|
||||
│ ├── run_eval.sh // shell script for evaluation on Ascend
|
||||
├── src
|
||||
│ ├── dataset
|
||||
│ │ ├── Dataset.py // loading evaluating dataset
|
||||
│ │ ├── MGDataset.py // loading training dataset
|
||||
│ ├── model
|
||||
│ │ ├── model.py // lpd model
|
||||
│ │ ├── stn.py // spatial transformer network module
|
||||
│ ├── utils
|
||||
│ │ ├── distance.py // calculate distance of two features
|
||||
│ │ ├── metric.py // calculate mAP and CMC scores
|
||||
├─ config.py // hyperparameter setting
|
||||
├─ train_dataset.py // training data format setting
|
||||
├─ test_dataset.py // evaluating data format setting
|
||||
├─ train.py // training scripts
|
||||
├─ test.py // evaluation scripts
|
||||
```
|
||||
|
||||
# [Training](#contents)
|
||||
|
||||
```bash
|
||||
sh scripts/run_train.sh [USE_DEVICE_ID]
|
||||
```
|
||||
|
||||
You will get the loss value of each epoch as following in "./scripts/data_parallel_log_[DEVICE_ID]/outputs/logs/[TIME].log" or "./scripts/log_parallel_graph/face_recognition_[DEVICE_ID].log":
|
||||
|
||||
```python
|
||||
epoch[0], iter[100], loss:(Tensor(shape=[], dtype=Float32, value= 50.2733), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 32768)), cur_lr:0.000660, mean_fps:743.09 imgs/sec
|
||||
epoch[0], iter[200], loss:(Tensor(shape=[], dtype=Float32, value= 49.3693), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 32768)), cur_lr:0.001314, mean_fps:4426.42 imgs/sec
|
||||
epoch[0], iter[300], loss:(Tensor(shape=[], dtype=Float32, value= 48.7081), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 16384)), cur_lr:0.001968, mean_fps:4428.09 imgs/sec
|
||||
epoch[0], iter[400], loss:(Tensor(shape=[], dtype=Float32, value= 45.7791), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 16384)), cur_lr:0.002622, mean_fps:4428.17 imgs/sec
|
||||
|
||||
...
|
||||
epoch[8], iter[27300], loss:(Tensor(shape=[], dtype=Float32, value= 2.13556), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4429.38 imgs/sec
|
||||
epoch[8], iter[27400], loss:(Tensor(shape=[], dtype=Float32, value= 2.36922), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4429.88 imgs/sec
|
||||
epoch[8], iter[27500], loss:(Tensor(shape=[], dtype=Float32, value= 2.08594), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4430.59 imgs/sec
|
||||
epoch[8], iter[27600], loss:(Tensor(shape=[], dtype=Float32, value= 2.38706), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4430.37 imgs/sec
|
||||
```
|
||||
|
||||
# [Evaluation](#contents)
|
||||
|
||||
```bash
|
||||
sh scripts/run_eval.sh [USE_DEVICE_ID]
|
||||
```
|
||||
|
||||
You will get the result as following in "./scripts/log_inference/outputs/models/logs/[TIME].log":
|
||||
[test_dataset]: zj2jk=0.9495, jk2zj=0.9480, avg=0.9487
|
||||
|
||||
| model | mAP | rank1 | rank5 | rank10|
|
||||
| ---------| ------| ----- | ----- | ----- |
|
||||
| Baseline | 27.09 | 70.17 | 87.95 | 91.80 |
|
||||
| MG | 36.55 | 94.12 | 98.01 | 98.66 |
|
||||
| LPD | 42.14 | 96.22 | 98.11 | 98.75 |
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
"class_num": 10572,
|
||||
"batch_size": 128,
|
||||
"learning_rate": 0.01,
|
||||
"lr_decay_epochs": [40, 80, 100],
|
||||
"lr_decay_factor": 0.1,
|
||||
"lr_warmup_epochs": 20,
|
||||
"p": 16,
|
||||
"k": 8,
|
||||
"loss_scale": 1024,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1e-4,
|
||||
"epoch_size": 120,
|
||||
"buffer_size": 10000,
|
||||
"image_height": 128,
|
||||
"image_width": 128,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_steps": 195,
|
||||
"keep_checkpoint_max": 2,
|
||||
"save_checkpoint_path": "./"
|
||||
})
|
|
@ -0,0 +1,202 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""data process"""
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from PIL import ImageFile
|
||||
import cv2
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
__all__ = ['DistributedPKSampler', 'Dataset']
|
||||
IMG_EXTENSIONS = ('.jpg', 'jpeg', '.png', '.ppm', '.bmp', 'pgm', '.tif', '.tiff', 'webp')
|
||||
|
||||
|
||||
class DistributedPKSampler:
|
||||
'''DistributedPKSampler'''
|
||||
def __init__(self, dataset, shuffle=True, p=5, k=2):
|
||||
assert isinstance(dataset, PKDataset), 'PK Sampler Only Supports PK Dataset!'
|
||||
self.p = p
|
||||
self.k = k
|
||||
self.dataset = dataset
|
||||
self.epoch = 0
|
||||
self.step_nums = int(math.ceil(len(self.dataset.classes)*1.0/p))
|
||||
self.total_ids = self.step_nums*p
|
||||
self.batch_size = p*k
|
||||
self.num_samples = self.total_ids * self.k
|
||||
self.shuffle = shuffle
|
||||
self.epoch_gen = 1
|
||||
|
||||
|
||||
def _sample_pk(self, indices):
|
||||
'''sample pk'''
|
||||
sampled_pk = []
|
||||
for indice in indices:
|
||||
sampled_id = indice
|
||||
replacement = False
|
||||
if len(self.dataset.id2range[sampled_id]) < self.k:
|
||||
replacement = True
|
||||
index_list = np.random.choice(self.dataset.id2range[sampled_id][0:], self.k, replace=replacement)
|
||||
sampled_pk.extend(index_list.tolist())
|
||||
|
||||
return sampled_pk
|
||||
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
self.epoch_gen = (self.epoch_gen + 1) & 0xffffffff
|
||||
np.random.seed(self.epoch_gen)
|
||||
indices = np.random.permutation(len(self.dataset.classes))
|
||||
indices = indices.tolist()
|
||||
else:
|
||||
indices = list(range(len(self.dataset.classes)))
|
||||
indices += indices[:(self.total_ids - len(indices))]
|
||||
assert len(indices) == self.total_ids
|
||||
|
||||
sampled_idxs = self._sample_pk(indices)
|
||||
|
||||
return iter(sampled_idxs)
|
||||
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
|
||||
def has_file_allowed_extension(filename, extensions):
|
||||
""" check if a file has an allowed extensio n.
|
||||
|
||||
Args:
|
||||
filename (string): path to a file
|
||||
extensions (tuple of strings): extensions allowed (lowercase)
|
||||
|
||||
Returns:
|
||||
bool: True if the file ends with one of the given extensions
|
||||
"""
|
||||
return filename.lower().endswith(extensions)
|
||||
|
||||
|
||||
def make_dataset(dir_name, class_to_idx, extensions=None, is_valid_file=None):
|
||||
'''make dataset'''
|
||||
images = []
|
||||
dir_name = os.path.expanduser(dir_name)
|
||||
if not (extensions is None) ^ (is_valid_file is None):
|
||||
raise ValueError("Extensions and is_valid_file should not be the same.")
|
||||
def is_valid(x):
|
||||
if extensions is not None:
|
||||
return has_file_allowed_extension(x, extensions)
|
||||
return is_valid_file(x)
|
||||
for target in sorted(class_to_idx.keys()):
|
||||
d = os.path.join(dir_name, target)
|
||||
if not os.path.isdir(d):
|
||||
continue
|
||||
for root, _, fnames in sorted(os.walk(d)):
|
||||
for fname in sorted(fnames):
|
||||
path = os.path.join(root, fname)
|
||||
if is_valid(path):
|
||||
item = (path, class_to_idx[target], 0.6)
|
||||
images.append(item)
|
||||
return images
|
||||
|
||||
|
||||
class ImageFolderPKDataset:
|
||||
'''ImageFolderPKDataset'''
|
||||
def __init__(self, root):
|
||||
self.classes, self.classes_to_idx = self._find_classes(root)
|
||||
self.samples = make_dataset(root, self.classes_to_idx, IMG_EXTENSIONS, None)
|
||||
self.id2range = self._build_id2range()
|
||||
self.all_image_idxs = range(len(self.samples))
|
||||
self.classes = list(self.id2range.keys())
|
||||
|
||||
def _find_classes(self, dir_name):
|
||||
"""
|
||||
Finds the class folders in a dataset
|
||||
|
||||
Args:
|
||||
dir (string): root directory path
|
||||
|
||||
Returns:
|
||||
tuple (class, class_to_idx): where classes are relative to dir, and class_to_idx is a directionaty
|
||||
|
||||
Ensures:
|
||||
No class is a subdirectory of others
|
||||
"""
|
||||
|
||||
if sys.version_info >= (3, 5):
|
||||
# Faster and available in Python 3.5 and above
|
||||
classes = [d.name for d in os.scandir(dir_name) if d.is_dir()]
|
||||
else:
|
||||
classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))]
|
||||
classes.sort()
|
||||
class_to_idx = {classes[i]: i for i in range(len(classes))}
|
||||
|
||||
return classes, class_to_idx
|
||||
|
||||
|
||||
def _build_id2range(self):
|
||||
'''map id to range'''
|
||||
id2range = defaultdict(list)
|
||||
ret_range = defaultdict(list)
|
||||
for idx, sample in enumerate(self.samples):
|
||||
label = sample[1]
|
||||
id2range[label].append((sample, idx))
|
||||
# print(id2range)
|
||||
for key in id2range:
|
||||
id2range[key].sort(key=lambda x: int(os.path.basename(x[0][0]).split(".")[0]))
|
||||
for item in id2range[key]:
|
||||
ret_range[key].append(item[1])
|
||||
return ret_range
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.samples[index]
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
|
||||
def pil_loader(path):
|
||||
'''pil loader'''
|
||||
img = cv2.imread(path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
class Dataset:
|
||||
'''Dataset'''
|
||||
def __init__(self, root, loader=pil_loader):
|
||||
self.dataset = ImageFolderPKDataset(root)
|
||||
print('Dataset len(dataset):{}'.format(len(self.dataset)))
|
||||
self.loader = loader
|
||||
self.classes = self.dataset.classes
|
||||
self.id2range = self.dataset.id2range
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
path, target1, target2 = self.dataset[index]
|
||||
sample = self.loader(path)
|
||||
return sample, target1, target2
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""MGDataset"""
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
import os.path as osp
|
||||
from collections import defaultdict
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import ImageFile
|
||||
import cv2
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
__all__ = ['DistributedPKSampler', 'MGDataset']
|
||||
IMG_EXTENSIONS = ('.jpg', 'jpeg', '.png', '.ppm', '.bmp', 'pgm', '.tif', '.tiff', 'webp')
|
||||
|
||||
|
||||
class DistributedPKSampler:
|
||||
'''DistributedPKSampler'''
|
||||
def __init__(self, dataset, shuffle=True, p=5, k=2):
|
||||
assert isinstance(dataset, MGDataset), 'PK Sampler Only Supports PK Dataset or MG Dataset!'
|
||||
self.p = p
|
||||
self.k = k
|
||||
self.dataset = dataset
|
||||
self.epoch = 0
|
||||
self.step_nums = int(math.ceil(len(self.dataset.classes)*1.0/p))
|
||||
self.total_ids = self.step_nums*p
|
||||
self.batch_size = p*k
|
||||
self.num_samples = self.total_ids * self.k
|
||||
self.shuffle = shuffle
|
||||
self.epoch_gen = 1
|
||||
|
||||
def _sample_pk(self, indices):
|
||||
'''sample pk'''
|
||||
sampled_pk = []
|
||||
for indice in indices:
|
||||
sampled_id = indice
|
||||
replacement = False
|
||||
if len(self.dataset.id2range[sampled_id]) < self.k:
|
||||
replacement = True
|
||||
index_list = np.random.choice(self.dataset.id2range[sampled_id][0:], self.k, replace=replacement)
|
||||
sampled_pk.extend(index_list.tolist())
|
||||
|
||||
return sampled_pk
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
self.epoch_gen = (self.epoch_gen + 1) & 0xffffffff
|
||||
np.random.seed(self.epoch_gen)
|
||||
indices = np.random.permutation(len(self.dataset.classes))
|
||||
indices = indices.tolist()
|
||||
else:
|
||||
indices = list(range(len(self.dataset.classes)))
|
||||
|
||||
indices += indices[:(self.total_ids - len(indices))]
|
||||
assert len(indices) == self.total_ids
|
||||
|
||||
sampled_idxs = self._sample_pk(indices)
|
||||
|
||||
return iter(sampled_idxs)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
|
||||
def has_file_allowed_extension(filename, extensions):
|
||||
""" check if a file has an allowed extensio n.
|
||||
|
||||
Args:
|
||||
filename (string): path to a file
|
||||
extensions (tuple of strings): extensions allowed (lowercase)
|
||||
|
||||
Returns:
|
||||
bool: True if the file ends with one of the given extensions
|
||||
"""
|
||||
return filename.lower().endswith(extensions)
|
||||
|
||||
|
||||
def make_dataset(dir_name, class_to_idx, extensions=None, is_valid_file=None):
|
||||
'''make dataset'''
|
||||
images = []
|
||||
masked_datasets = ["n95", "3m", "new", "mask_1", "mask_2", "mask_3", "mask_4", "mask_5"]
|
||||
dir_name = os.path.expanduser(dir_name)
|
||||
if not (extensions is None) ^ (is_valid_file is None):
|
||||
raise ValueError("Extensions and is_valid_file should not be the same")
|
||||
|
||||
def is_valid(x):
|
||||
if extensions is not None:
|
||||
return has_file_allowed_extension(x, extensions)
|
||||
return is_valid_file(x)
|
||||
|
||||
|
||||
for target in sorted(class_to_idx.keys()):
|
||||
d = os.path.join(dir_name, target)
|
||||
if not os.path.isdir(d):
|
||||
continue
|
||||
for root, _, fnames in sorted(os.walk(d)):
|
||||
for fname in sorted(fnames):
|
||||
path = os.path.join(root, fname)
|
||||
if is_valid(path):
|
||||
scale = float(osp.splitext(fname)[0].split('_')[1])
|
||||
item = (path, class_to_idx[target], scale)
|
||||
images.append(item)
|
||||
mask_root_path = root.replace("faces_webface_112x112_raw_image", random.choice(masked_datasets))
|
||||
mask_name = fname.split('_')[0]+".jpg"
|
||||
mask_path = osp.join(mask_root_path, mask_name)
|
||||
if os.path.isfile(mask_path) and is_valid(mask_path):
|
||||
item = (mask_path, class_to_idx[target], scale)
|
||||
images.append(item)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
class ImageFolderPKDataset:
|
||||
'''Image Folder PKDataset'''
|
||||
def __init__(self, root):
|
||||
self.classes, self.classes_to_idx = self._find_classes(root)
|
||||
self.samples = make_dataset(root, self.classes_to_idx, IMG_EXTENSIONS, None)
|
||||
self.id2range = self._build_id2range()
|
||||
self.all_image_idxs = range(len(self.samples))
|
||||
self.classes = list(self.id2range.keys())
|
||||
|
||||
def _find_classes(self, dir_name):
|
||||
"""
|
||||
Finds the class folders in a dataset
|
||||
|
||||
Args:
|
||||
dir (string): root directory path
|
||||
|
||||
Returns:
|
||||
tuple (class, class_to_idx): where classes are relative to dir, and class_to_idx is a directionaty
|
||||
|
||||
Ensures:
|
||||
No class is a subdirectory of others
|
||||
"""
|
||||
|
||||
if sys.version_info >= (3, 5):
|
||||
# Faster and available in Python 3.5 and above
|
||||
classes = [d.name for d in os.scandir(dir_name) if d.is_dir()]
|
||||
else:
|
||||
classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))]
|
||||
classes.sort()
|
||||
class_to_idx = {classes[i]: i for i in range(len(classes))}
|
||||
|
||||
return classes, class_to_idx
|
||||
|
||||
|
||||
def _build_id2range(self):
|
||||
'''id to range'''
|
||||
id2range = defaultdict(list)
|
||||
ret_range = defaultdict(list)
|
||||
for idx, sample in enumerate(self.samples):
|
||||
label = sample[1]
|
||||
id2range[label].append((sample, idx))
|
||||
for key in id2range:
|
||||
id2range[key].sort(key=lambda x: int(os.path.basename(x[0][0]).split(".")[0]))
|
||||
for item in id2range[key]:
|
||||
ret_range[key].append(item[1])
|
||||
|
||||
return ret_range
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.samples[index]
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
|
||||
def pil_loader(path):
|
||||
'''load pil'''
|
||||
img = cv2.imread(path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
class MGDataset:
|
||||
'''MGDataset'''
|
||||
def __init__(self, root, loader=pil_loader):
|
||||
self.dataset = ImageFolderPKDataset(root)
|
||||
print('MGDataset len(dataset):{}'.format(len(self.dataset)))
|
||||
self.loader = loader
|
||||
self.classes = self.dataset.classes
|
||||
self.id2range = self.dataset.id2range
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
path, target1, target2 = self.dataset[index]
|
||||
sample = self.loader(path)
|
||||
return sample, target1, target2
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
Binary file not shown.
After Width: | Height: | Size: 1.7 MiB |
Binary file not shown.
After Width: | Height: | Size: 337 KiB |
|
@ -0,0 +1,30 @@
|
|||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
|
||||
0.98500779 0.98617601 0.98676012 0.98753894]
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
|
||||
0.98500779 0.98617601 0.98676012 0.98753894]
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
660.75
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
|
||||
0.98500779 0.98617601 0.98676012 0.98753894]
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
|
||||
0.98500779 0.98617601 0.98676012 0.98753894]
|
||||
/home/dingfeifei/datasets/faces_webface_112x112_raw_image/0 0_0.5625.jpg
|
||||
MGDataset len(dataset):884896
|
||||
epoch: 1 step: 661, loss is 19.227043
|
||||
epoch: 2 step: 661, loss is 18.528654
|
||||
epoch: 3 step: 661, loss is 18.451244
|
|
@ -0,0 +1,445 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ResNet."""
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore
|
||||
from mindspore import ParameterTuple
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits, L1Loss
|
||||
from mindspore.nn import Momentum
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.initializer import HeNormal
|
||||
from mindspore.common.initializer import Normal
|
||||
from mindspore import Tensor
|
||||
from .stn import STN
|
||||
|
||||
|
||||
def _weight_variable(shape, factor=0.01):
|
||||
init_value = np.random.randn(*shape).astype(np.float32) * factor
|
||||
return Tensor(init_value)
|
||||
|
||||
|
||||
def _conv3x3(in_channel, out_channel, stride=1):
|
||||
n = 3*3*out_channel
|
||||
normal = Normal(math.sqrt(2. / n))
|
||||
return nn.Conv2d(in_channel, out_channel,
|
||||
kernel_size=3, stride=stride, padding=1, pad_mode='pad', weight_init=normal)
|
||||
|
||||
|
||||
def _conv1x1(in_channel, out_channel, stride=1):
|
||||
n = 1*1*out_channel
|
||||
normal = Normal(math.sqrt(2. / n))
|
||||
return nn.Conv2d(in_channel, out_channel,
|
||||
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=normal)
|
||||
|
||||
|
||||
def _conv7x7(in_channel, out_channel, stride=1):
|
||||
n = 7*7*out_channel
|
||||
normal = Normal(math.sqrt(2. / n))
|
||||
return nn.Conv2d(in_channel, out_channel,
|
||||
kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=normal)
|
||||
|
||||
|
||||
def _bn(channel):
|
||||
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1, use_batch_statistics=None)
|
||||
|
||||
def _bn1(channel):
|
||||
return nn.BatchNorm1d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1, use_batch_statistics=None)
|
||||
|
||||
def _bn1_kaiming(channel):
|
||||
return nn.BatchNorm1d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1, use_batch_statistics=None)
|
||||
|
||||
def _bn2_kaiming(channel):
|
||||
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1, use_batch_statistics=None)
|
||||
|
||||
def _bn_last(channel):
|
||||
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
|
||||
|
||||
def _fc(in_channel, out_channel):
|
||||
he_normal = HeNormal()
|
||||
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=he_normal, bias_init='zeros')
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
"""
|
||||
ResNet V1 residual block definition.
|
||||
|
||||
Args:
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer. Default: 1.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ResidualBlock(3, 256, stride=2)
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
channel,
|
||||
out_channel,
|
||||
stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.conv1 = _conv1x1(in_channel, channel, stride=1)
|
||||
self.bn1 = _bn(channel)
|
||||
|
||||
self.conv2 = _conv3x3(channel, channel, stride=stride)
|
||||
self.bn2 = _bn(channel)
|
||||
|
||||
self.conv3 = _conv1x1(channel, out_channel, stride=1)
|
||||
self.bn3 = _bn(out_channel)
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.down_sample = False
|
||||
|
||||
if stride != 1 or in_channel != out_channel:
|
||||
self.down_sample = True
|
||||
self.down_sample_layer = None
|
||||
|
||||
if self.down_sample:
|
||||
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
|
||||
_bn(out_channel)])
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
|
||||
def construct(self, x):
|
||||
'''construct'''
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.down_sample:
|
||||
identity = self.down_sample_layer(identity)
|
||||
|
||||
out = self.add(out, identity)
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class HardAttn(nn.Cell):
|
||||
'''LPD module'''
|
||||
def __init__(self, in_channels):
|
||||
super(HardAttn, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.fc1 = _fc(128*128, 32)
|
||||
self.bn1 = _bn1(32)
|
||||
self.fc2 = _fc(32, 4)
|
||||
self.bn2 = _bn1(4)
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
|
||||
|
||||
def construct(self, x):
|
||||
'''construct'''
|
||||
x = self.reduce_mean(x, 1)
|
||||
x_size = self.shape(x)
|
||||
x = self.reshape(x, (x_size[0], 128*128))
|
||||
x = self.fc1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.bn2(x)
|
||||
x = self.reshape(x, (x_size[0], 4))
|
||||
return x
|
||||
|
||||
|
||||
class ResNet(nn.Cell):
|
||||
"""
|
||||
ResNet architecture.
|
||||
Args:
|
||||
block (Cell): Block for network.
|
||||
layer_nums (list): Numbers of block in different layers.
|
||||
in_channels (list): Input channel in each layer.
|
||||
out_channels (list): Output channel in each layer.
|
||||
strides (list): Stride size in each layer.
|
||||
num_classes (int): The number of classes that the training images are belonging to.
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ResNet(ResidualBlock,
|
||||
>>> [3, 4, 6, 3],
|
||||
>>> [64, 256, 512, 1024],
|
||||
>>> [256, 512, 1024, 2048],
|
||||
>>> [1, 2, 2, 2],
|
||||
>>> 10)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
block,
|
||||
layer_nums,
|
||||
in_channels,
|
||||
channels,
|
||||
out_channels,
|
||||
strides,
|
||||
num_classes, is_train):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
|
||||
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
|
||||
|
||||
self.ha3 = HardAttn(2048)
|
||||
self.is_train = is_train
|
||||
self.conv1 = _conv7x7(3, 64, stride=2)
|
||||
self.bn1 = _bn(64)
|
||||
self.relu = nn.ReLU()
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
|
||||
self.layer1 = self._make_layer(block,
|
||||
layer_nums[0],
|
||||
in_channel=in_channels[0],
|
||||
channel=channels[0],
|
||||
out_channel=out_channels[0],
|
||||
stride=strides[0])
|
||||
self.layer2 = self._make_layer(block,
|
||||
layer_nums[1],
|
||||
in_channel=in_channels[1],
|
||||
channel=channels[1],
|
||||
out_channel=out_channels[1],
|
||||
stride=strides[1])
|
||||
self.layer3 = self._make_layer(block,
|
||||
layer_nums[2],
|
||||
in_channel=in_channels[2],
|
||||
channel=channels[2],
|
||||
out_channel=out_channels[2],
|
||||
stride=strides[2])
|
||||
self.layer4 = self._make_layer(block,
|
||||
layer_nums[3],
|
||||
in_channel=in_channels[3],
|
||||
channel=channels[3],
|
||||
out_channel=out_channels[3],
|
||||
stride=strides[3])
|
||||
|
||||
self.max = P.ReduceMax(keep_dims=True)
|
||||
self.flatten = nn.Flatten()
|
||||
self.global_bn = _bn2_kaiming(out_channels[3])
|
||||
self.partial_bn = _bn2_kaiming(out_channels[3])
|
||||
normal = Normal(0.001)
|
||||
self.global_fc = nn.Dense(out_channels[3], num_classes, has_bias=False, weight_init=normal, bias_init='zeros')
|
||||
self.partial_fc = nn.Dense(out_channels[3], num_classes, has_bias=False, weight_init=normal, bias_init='zeros')
|
||||
self.theta_0 = Tensor(np.zeros((128, 4)), mindspore.float32)
|
||||
self.theta_6 = Tensor(np.zeros((128, 4))+0.6, mindspore.float32)
|
||||
self.STN = STN(128, 128)
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.shape = P.Shape()
|
||||
self.tanh = P.Tanh()
|
||||
self.slice = P.Slice()
|
||||
self.split = P.Split(1, 4)
|
||||
|
||||
|
||||
def _make_layer(self, block, layer_num, in_channel, channel, out_channel, stride):
|
||||
"""
|
||||
Make stage network of ResNet.
|
||||
|
||||
Args:
|
||||
block (Cell): Resnet block.
|
||||
layer_num (int): Layer number.
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer.
|
||||
|
||||
Returns:
|
||||
SequentialCell, the output layer.
|
||||
|
||||
Examples:
|
||||
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
|
||||
"""
|
||||
layers = []
|
||||
resnet_block = block(in_channel, channel, out_channel, stride=stride)
|
||||
layers.append(resnet_block)
|
||||
|
||||
for _ in range(1, layer_num):
|
||||
resnet_block = block(out_channel, channel, out_channel, stride=1)
|
||||
layers.append(resnet_block)
|
||||
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
|
||||
def stn(self, x, stn_theta):
|
||||
'''stn'''
|
||||
x_size = self.shape(x)
|
||||
theta = self.tanh(stn_theta)
|
||||
theta1, theta5, theta6, theta3 = self.split(theta)
|
||||
theta_0 = self.slice(self.theta_0, (0, 0), (x_size[0], 4))
|
||||
theta2, theta4, _, _ = self.split(theta_0)
|
||||
theta = self.concat((theta1, theta2, theta3, theta4, theta5, theta6))
|
||||
flip_feature = self.STN(x, theta)
|
||||
return flip_feature, theta5
|
||||
|
||||
|
||||
def construct(self, x):
|
||||
'''construct'''
|
||||
stn_theta = self.ha3(x)
|
||||
x_p, theta = self.stn(x, stn_theta)
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
c1 = self.maxpool(x)
|
||||
c2 = self.layer1(c1)
|
||||
c3 = self.layer2(c2)
|
||||
c4 = self.layer3(c3)
|
||||
c5 = self.layer4(c4)
|
||||
|
||||
out = self.max(c5, (2, 3))
|
||||
out = self.global_bn(out)
|
||||
global_f = self.flatten(out)
|
||||
|
||||
x_p = self.conv1(x_p)
|
||||
x_p = self.bn1(x_p)
|
||||
x_p = self.relu(x_p)
|
||||
c1_p = self.maxpool(x_p)
|
||||
|
||||
c2_p = self.layer1(c1_p)
|
||||
c3_p = self.layer2(c2_p)
|
||||
c4_p = self.layer3(c3_p)
|
||||
c5_p = self.layer4(c4_p)
|
||||
|
||||
out_p = self.max(c5_p, (2, 3))
|
||||
out_p = self.partial_bn(out_p)
|
||||
partial_f = self.flatten(out_p)
|
||||
|
||||
global_out = self.global_fc(global_f)
|
||||
partial_out = self.partial_fc(partial_f)
|
||||
return global_f, partial_f, global_out, partial_out, theta
|
||||
|
||||
|
||||
class NetWithLossClass(nn.Cell):
|
||||
'''net with loss'''
|
||||
def __init__(self, network, is_train=True):
|
||||
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
||||
self.loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
self.l1_loss = L1Loss()
|
||||
self.network = network
|
||||
self.is_train = is_train
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
|
||||
def construct(self, x, label1, label2):
|
||||
'''construct'''
|
||||
global_f, partial_f, global_out, partial_out, theta = self.network(x)
|
||||
if not self.is_train:
|
||||
out = self.concat((global_f, partial_f))
|
||||
return out
|
||||
loss_global = self.loss(global_out, label1)
|
||||
loss_partial = self.loss(partial_out, label1)
|
||||
loss_theta = self.l1_loss(theta, label2)
|
||||
loss = loss_global + loss_partial + loss_theta
|
||||
return loss
|
||||
|
||||
|
||||
class TrainStepWrap(nn.Cell):
|
||||
'''train step wrap'''
|
||||
def __init__(self, network, lr, momentum, is_train=True):
|
||||
super(TrainStepWrap, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = Momentum(self.weights, lr, momentum)
|
||||
self.grad = C.GradOperation(get_by_list=True)
|
||||
self.is_train = is_train
|
||||
|
||||
|
||||
def construct(self, x, labels1, labels2):
|
||||
'''construct'''
|
||||
weights = self.weights
|
||||
loss = self.network(x, labels1, labels2)
|
||||
if not self.is_train:
|
||||
return loss
|
||||
grads = self.grad(self.network, weights)(x, labels1, labels2)
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class TestStepWrap(nn.Cell):
|
||||
"""
|
||||
Predict method
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(TestStepWrap, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.sigmoid = P.Sigmoid()
|
||||
|
||||
|
||||
def construct(self, x, labels):
|
||||
'''construct'''
|
||||
logits_global, _, _, _, = self.network(x)
|
||||
pred_probs = self.sigmoid(logits_global)
|
||||
|
||||
return logits_global, pred_probs, labels
|
||||
|
||||
|
||||
def resnet50(class_num=10, is_train=True):
|
||||
"""
|
||||
Get ResNet50 neural network.
|
||||
|
||||
Args:
|
||||
class_num (int): Class number.
|
||||
|
||||
Returns:
|
||||
Cell, cell instance of ResNet50 neural network.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet50(10)
|
||||
"""
|
||||
return ResNet(ResidualBlock,
|
||||
[3, 4, 6, 3],
|
||||
[64, 256, 512, 1024],
|
||||
[64, 128, 256, 512],
|
||||
[256, 512, 1024, 2048],
|
||||
[1, 2, 2, 1],
|
||||
class_num, is_train)
|
||||
|
||||
def resnet101(class_num=1001):
|
||||
"""
|
||||
Get ResNet101 neural network.
|
||||
|
||||
Args:
|
||||
class_num (int): Class number.
|
||||
|
||||
Returns:
|
||||
Cell, cell instance of ResNet101 neural network.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet101(1001)
|
||||
"""
|
||||
return ResNet(ResidualBlock,
|
||||
[3, 4, 23, 3],
|
||||
[64, 256, 512, 1024],
|
||||
[256, 512, 1024, 2048],
|
||||
[1, 2, 2, 2],
|
||||
class_num)
|
|
@ -0,0 +1,288 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""STN module"""
|
||||
import numpy as np
|
||||
import mindspore
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
import mindspore.nn as nn
|
||||
|
||||
class STN(nn.Cell):
|
||||
'''STN'''
|
||||
def __init__(self, H, W):
|
||||
super(STN, self).__init__()
|
||||
batch_size = 1
|
||||
x = np.linspace(-1.0, 1.0, H)
|
||||
y = np.linspace(-1.0, 1.0, W)
|
||||
x_t, y_t = np.meshgrid(x, y)
|
||||
x_t = Tensor(x_t, mindspore.float32)
|
||||
y_t = Tensor(y_t, mindspore.float32)
|
||||
expand_dims = P.ExpandDims()
|
||||
x_t = expand_dims(x_t, 0)
|
||||
y_t = expand_dims(y_t, 0)
|
||||
flatten = P.Flatten()
|
||||
x_t_flat = flatten(x_t)
|
||||
y_t_flat = flatten(y_t)
|
||||
oneslike = P.OnesLike()
|
||||
ones = oneslike(x_t_flat)
|
||||
concat = P.Concat()
|
||||
sampling_grid = concat((x_t_flat, y_t_flat, ones))
|
||||
self.sampling_grid = expand_dims(sampling_grid, 0)
|
||||
|
||||
batch_size = 128
|
||||
batch_idx = np.arange(batch_size)
|
||||
batch_idx = batch_idx.reshape((batch_size, 1, 1))
|
||||
self.batch_idx = Tensor(batch_idx, mindspore.float32)
|
||||
self.zero = Tensor(np.zeros([]), mindspore.float32)
|
||||
|
||||
|
||||
def get_pixel_value(self, img, x, y):
|
||||
"""
|
||||
Utility function to get pixel value for coordinate
|
||||
vectors x and y from a 4D tensor image.
|
||||
|
||||
Input
|
||||
-----
|
||||
- img: tensor of shape (B, H, W, C)
|
||||
- x: flattened tensor of shape (B*H*W,)
|
||||
- y: flattened tensor of shape (B*H*W,)
|
||||
|
||||
Returns
|
||||
-------
|
||||
- output: tensor of shape (B, H, W, C)
|
||||
"""
|
||||
shape = P.Shape()
|
||||
img_shape = shape(x)
|
||||
batch_size = img_shape[0]
|
||||
height = img_shape[1]
|
||||
width = img_shape[2]
|
||||
img[:, 0, :, :] = self.zero
|
||||
img[:, height-1, :, :] = self.zero
|
||||
img[:, :, 0, :] = self.zero
|
||||
img[:, :, width-1, :] = self.zero
|
||||
|
||||
tile = P.Tile()
|
||||
batch_idx = P.Slice()(self.batch_idx, (0, 0, 0), (batch_size, 1, 1))
|
||||
b = tile(batch_idx, (1, height, width))
|
||||
|
||||
expand_dims = P.ExpandDims()
|
||||
b = expand_dims(b, 3)
|
||||
x = expand_dims(x, 3)
|
||||
y = expand_dims(y, 3)
|
||||
|
||||
concat = P.Concat(3)
|
||||
indices = concat((b, y, x))
|
||||
cast = P.Cast()
|
||||
indices = cast(indices, mindspore.int32)
|
||||
gather_nd = P.GatherNd()
|
||||
|
||||
return cast(gather_nd(img, indices), mindspore.float32)
|
||||
|
||||
|
||||
def affine_grid_generator(self, height, width, theta):
|
||||
"""
|
||||
This function returns a sampling grid, which when
|
||||
used with the bilinear sampler on the input feature
|
||||
map, will create an output feature map that is an
|
||||
affine transformation [1] of the input feature map.
|
||||
|
||||
zero = Tensor(np.zeros([]), mindspore.float32)
|
||||
Input
|
||||
-----
|
||||
- height: desired height of grid/output. Used
|
||||
to downsample or upsample.
|
||||
|
||||
- width: desired width of grid/output. Used
|
||||
to downsample or upsample.
|
||||
|
||||
- theta: affine transform matrices of shape (num_batch, 2, 3).
|
||||
For each image in the batch, we have 6 theta parameters of
|
||||
the form (2x3) that define the affine transformation T.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- normalized grid (-1, 1) of shape (num_batch, 2, H, W).
|
||||
The 2nd dimension has 2 components: (x, y) which are the
|
||||
sampling points of the original image for each point in the
|
||||
target image.
|
||||
|
||||
Note
|
||||
----
|
||||
[1]: the affine transformation allows cropping, translation,
|
||||
and isotropic scaling.
|
||||
"""
|
||||
shape = P.Shape()
|
||||
num_batch = shape(theta)[0]
|
||||
|
||||
cast = P.Cast()
|
||||
theta = cast(theta, mindspore.float32)
|
||||
|
||||
# transform the sampling grid - batch multiply
|
||||
matmul = P.BatchMatMul()
|
||||
tile = P.Tile()
|
||||
sampling_grid = tile(self.sampling_grid, (num_batch, 1, 1))
|
||||
cast = P.Cast()
|
||||
sampling_grid = cast(sampling_grid, mindspore.float32)
|
||||
|
||||
batch_grids = matmul(theta, sampling_grid)
|
||||
# batch grid has shape (num_batch, 2, H*W)
|
||||
|
||||
# reshape to (num_batch, H, W, 2)
|
||||
reshape = P.Reshape()
|
||||
batch_grids = reshape(batch_grids, (num_batch, 2, height, width))
|
||||
return batch_grids
|
||||
|
||||
|
||||
def bilinear_sampler(self, img, x, y):
|
||||
"""
|
||||
Performs bilinear sampling of the input images according to the
|
||||
normalized coordinates provided by the sampling grid. Note that
|
||||
the sampling is done identically for each channel of the input.
|
||||
|
||||
To test if the function works properly, output image should be
|
||||
identical to input image when theta is initialized to identity
|
||||
transform.
|
||||
|
||||
Input
|
||||
-----
|
||||
- img: batch of images in (B, H, W, C) layout.
|
||||
- grid: x, y which is the output of affine_grid_generator.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- out: interpolated images according to grids. Same size as grid.
|
||||
"""
|
||||
shape = P.Shape()
|
||||
H = shape(img)[1]
|
||||
W = shape(img)[2]
|
||||
cast = P.Cast()
|
||||
max_y = cast(H - 1, mindspore.float32)
|
||||
max_x = cast(W - 1, mindspore.float32)
|
||||
zero = self.zero
|
||||
|
||||
# rescale x and y to [0, W-1/H-1]
|
||||
x = 0.5 * ((x + 1.0) * (max_x-1))
|
||||
y = 0.5 * ((y + 1.0) * (max_y-1))
|
||||
|
||||
# grab 4 nearest corner points for each (x_i, y_i)
|
||||
floor = P.Floor()
|
||||
x0 = floor(x)
|
||||
x1 = x0 + 1
|
||||
y0 = floor(y)
|
||||
y1 = y0 + 1
|
||||
|
||||
# clip to range [0, H-1/W-1] to not violate img boundaries
|
||||
x0 = C.clip_by_value(x0, zero, max_x)
|
||||
x1 = C.clip_by_value(x1, zero, max_x)
|
||||
y0 = C.clip_by_value(y0, zero, max_y)
|
||||
y1 = C.clip_by_value(y1, zero, max_y)
|
||||
|
||||
# get pixel value at corner coords
|
||||
Ia = self.get_pixel_value(img, x0, y0)
|
||||
Ib = self.get_pixel_value(img, x0, y1)
|
||||
Ic = self.get_pixel_value(img, x1, y0)
|
||||
Id = self.get_pixel_value(img, x1, y1)
|
||||
|
||||
# recast as float for delta calculation
|
||||
x0 = cast(x0, mindspore.float32)
|
||||
x1 = cast(x1, mindspore.float32)
|
||||
y0 = cast(y0, mindspore.float32)
|
||||
y1 = cast(y1, mindspore.float32)
|
||||
|
||||
# calculate deltas
|
||||
wa = (x1-x) * (y1-y)
|
||||
wb = (x1-x) * (y-y0)
|
||||
wc = (x-x0) * (y1-y)
|
||||
wd = (x-x0) * (y-y0)
|
||||
|
||||
# add dimension for addition
|
||||
expand_dims = P.ExpandDims()
|
||||
wa = expand_dims(wa, 3)
|
||||
wb = expand_dims(wb, 3)
|
||||
wc = expand_dims(wc, 3)
|
||||
wd = expand_dims(wd, 3)
|
||||
|
||||
# compute output
|
||||
add_n = P.AddN()
|
||||
out = add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def construct(self, input_fmap, theta, out_dims=None, **kwargs):
|
||||
"""
|
||||
Spatial Transformer Network layer implementation as described in [1].
|
||||
|
||||
The layer is composed of 3 elements:
|
||||
|
||||
- localization_net: takes the original image as input and outputs
|
||||
the parameters of the affine transformation that should be applied
|
||||
to the input image.
|
||||
|
||||
- affine_grid_generator: generates a grid of (x,y) coordinates that
|
||||
correspond to a set of points where the input should be sampled
|
||||
to produce the transformed output.
|
||||
|
||||
- bilinear_sampler: takes as input the original image and the grid
|
||||
and produces the output image using bilinear interpolation.
|
||||
|
||||
Input
|
||||
-----
|
||||
- input_fmap: output of the previous layer. Can be input if spatial
|
||||
transformer layer is at the beginning of architecture. Should be
|
||||
a tensor of shape (B, H, W, C).
|
||||
|
||||
- theta: affine transform tensor of shape (B, 6). Permits cropping,
|
||||
translation and isotropic scaling. Initialize to identity matrix.
|
||||
It is the output of the localization network.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- out_fmap: transformed input feature map. Tensor of size (B, C, H, W)-->(B, H, W, C).
|
||||
|
||||
Notes
|
||||
-----
|
||||
[1]: 'Spatial Transformer Networks', Jaderberg et. al,
|
||||
(https://arxiv.org/abs/1506.02025)
|
||||
"""
|
||||
|
||||
# grab input dimensions
|
||||
trans = P.Transpose()
|
||||
input_fmap = trans(input_fmap, (0, 2, 3, 1))
|
||||
shape = P.Shape()
|
||||
input_size = shape(input_fmap)
|
||||
B = input_size[0]
|
||||
H = input_size[1]
|
||||
W = input_size[2]
|
||||
reshape = P.Reshape()
|
||||
theta = reshape(theta, (B, 2, 3))
|
||||
|
||||
# generate grids of same size or upsample/downsample if specified
|
||||
if out_dims:
|
||||
out_H = out_dims[0]
|
||||
out_W = out_dims[1]
|
||||
batch_grids = self.affine_grid_generator(out_H, out_W, theta)
|
||||
else:
|
||||
batch_grids = self.affine_grid_generator(H, W, theta)
|
||||
|
||||
x_s, y_s = P.Split(1, 2)(batch_grids)
|
||||
squeeze = P.Squeeze()
|
||||
x_s = squeeze(x_s)
|
||||
y_s = squeeze(y_s)
|
||||
out_fmap = self.bilinear_sampler(input_fmap, x_s, y_s)
|
||||
out_fmap = trans(out_fmap, (0, 3, 1, 2))
|
||||
|
||||
return out_fmap
|
|
@ -0,0 +1,2 @@
|
|||
#!/bin/bash
|
||||
python3 ./test.py
|
|
@ -0,0 +1,2 @@
|
|||
#!/bin/bash
|
||||
python3 ./train.py
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_imagenet."""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
from test_dataset import create_dataset
|
||||
from config import config
|
||||
from mindspore import context
|
||||
from mindspore.nn.dynamic_lr import piecewise_constant_lr, warmup_lr
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from model.model import resnet50, TrainStepWrap, NetWithLossClass
|
||||
from utils.distance import compute_dist, compute_score
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
local_data_url = 'data'
|
||||
local_train_url = 'ckpt'
|
||||
|
||||
|
||||
class Logger():
|
||||
'''Log'''
|
||||
def __init__(self, logFile="log_max.txt"):
|
||||
self.terminal = sys.stdout
|
||||
self.log = open(logFile, 'a')
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
self.log.write(message)
|
||||
self.log.flush()
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
sys.stdout = Logger("log/log.txt")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
query_dataset = create_dataset(data_dir=os.path.join('/home/dingfeifei/datasets', \
|
||||
'test/query'), p=config.p, k=config.k)
|
||||
gallery_dataset = create_dataset(data_dir=os.path.join('/home/dingfeifei/datasets', \
|
||||
'test/gallery'), p=config.p, k=config.k)
|
||||
|
||||
epoch_size = config.epoch_size
|
||||
net = resnet50(class_num=config.class_num, is_train=False)
|
||||
loss_net = NetWithLossClass(net, is_train=False)
|
||||
|
||||
base_lr = config.learning_rate
|
||||
warm_up_epochs = config.lr_warmup_epochs
|
||||
lr_decay_epochs = config.lr_decay_epochs
|
||||
lr_decay_factor = config.lr_decay_factor
|
||||
step_size = math.ceil(config.class_num / config.p)
|
||||
lr_decay_steps = []
|
||||
lr_decay = []
|
||||
for i, v in enumerate(lr_decay_epochs):
|
||||
lr_decay_steps.append(v * step_size)
|
||||
lr_decay.append(base_lr * lr_decay_factor ** i)
|
||||
lr_1 = warmup_lr(base_lr, step_size*warm_up_epochs, step_size, warm_up_epochs)
|
||||
lr_2 = piecewise_constant_lr(lr_decay_steps, lr_decay)
|
||||
lr = lr_1 + lr_2
|
||||
|
||||
train_net = TrainStepWrap(loss_net, lr, config.momentum, is_train=False)
|
||||
|
||||
load_checkpoint("checkpoints/40.ckpt", net=train_net)
|
||||
|
||||
q_feats, q_labels, g_feats, g_labels = [], [], [], []
|
||||
for data, gt_classes, theta in query_dataset:
|
||||
output = train_net(data, gt_classes, theta)
|
||||
output = output.asnumpy()
|
||||
label = gt_classes.asnumpy()
|
||||
q_feats.append(output)
|
||||
q_labels.append(label)
|
||||
q_feats = np.vstack(q_feats)
|
||||
q_labels = np.hstack(q_labels)
|
||||
|
||||
for data, gt_classes, theta in gallery_dataset:
|
||||
output = train_net(data, gt_classes, theta)
|
||||
output = output.asnumpy()
|
||||
label = gt_classes.asnumpy()
|
||||
g_feats.append(output)
|
||||
g_labels.append(label)
|
||||
g_feats = np.vstack(g_feats)
|
||||
g_labels = np.hstack(g_labels)
|
||||
|
||||
q_g_dist = compute_dist(q_feats, g_feats, dis_type='cosine')
|
||||
mAP, cmc_scores = compute_score(q_g_dist, q_labels, g_labels)
|
||||
|
||||
print(mAP, cmc_scores)
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
create train or eval dataset.
|
||||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.vision.c_transforms as CV
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from config import config
|
||||
from dataset.Dataset import Dataset
|
||||
|
||||
|
||||
def create_dataset(data_dir, p=16, k=8):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
p(int): randomly choose p classes from all classes.
|
||||
k(int): randomly choose k images from each of the chosen p classes.
|
||||
p * k is the batchsize.
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
dataset = Dataset(data_dir)
|
||||
de_dataset = de.GeneratorDataset(dataset, ["image", "label1", "label2"])
|
||||
|
||||
resize_height = config.image_height
|
||||
resize_width = config.image_width
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
|
||||
resize_op = CV.Resize((resize_height, resize_width))
|
||||
rescale_op = CV.Rescale(rescale, shift)
|
||||
normalize_op = CV.Normalize([0.486, 0.459, 0.408], [0.229, 0.224, 0.225])
|
||||
|
||||
change_swap_op = CV.HWC2CHW()
|
||||
|
||||
trans = []
|
||||
|
||||
trans += [resize_op, rescale_op, normalize_op, change_swap_op]
|
||||
|
||||
type_cast_op_label1 = C.TypeCast(mstype.int32)
|
||||
type_cast_op_label2 = C.TypeCast(mstype.float32)
|
||||
|
||||
de_dataset = de_dataset.map(input_columns="label1", operations=type_cast_op_label1)
|
||||
de_dataset = de_dataset.map(input_columns="label2", operations=type_cast_op_label2)
|
||||
de_dataset = de_dataset.map(input_columns="image", operations=trans)
|
||||
de_dataset = de_dataset.batch(p*k, drop_remainder=False)
|
||||
|
||||
return de_dataset
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_imagenet."""
|
||||
import sys
|
||||
import argparse
|
||||
import random
|
||||
import pickle
|
||||
import numpy as np
|
||||
from train_dataset import create_dataset
|
||||
from config import config
|
||||
from mindspore import context
|
||||
from mindspore.nn.dynamic_lr import piecewise_constant_lr, warmup_lr
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor # TimeMonitor
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from model.model import resnet50, NetWithLossClass, TrainStepWrap, TestStepWrap
|
||||
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
local_data_url = 'data'
|
||||
local_train_url = 'ckpt'
|
||||
|
||||
class Logger():
|
||||
'''Logger'''
|
||||
def __init__(self, logFile="log_max.txt"):
|
||||
self.terminal = sys.stdout
|
||||
self.log = open(logFile, 'a')
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
self.log.write(message)
|
||||
self.log.flush()
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
sys.stdout = Logger("log/log.txt")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
epoch_size = config.epoch_size
|
||||
net = resnet50(class_num=config.class_num, is_train=True)
|
||||
loss_net = NetWithLossClass(net)
|
||||
|
||||
dataset = create_dataset("/home/dingfeifei/datasets/faces_webface_112x112_raw_image", \
|
||||
p=config.p, k=config.k)
|
||||
|
||||
step_size = dataset.get_dataset_size()
|
||||
base_lr = config.learning_rate
|
||||
warm_up_epochs = config.lr_warmup_epochs
|
||||
lr_decay_epochs = config.lr_decay_epochs
|
||||
lr_decay_factor = config.lr_decay_factor
|
||||
lr_decay_steps = []
|
||||
lr_decay = []
|
||||
for i, v in enumerate(lr_decay_epochs):
|
||||
lr_decay_steps.append(v * step_size)
|
||||
lr_decay.append(base_lr * lr_decay_factor ** i)
|
||||
lr_1 = warmup_lr(base_lr, step_size*warm_up_epochs, step_size, warm_up_epochs)
|
||||
lr_2 = piecewise_constant_lr(lr_decay_steps, lr_decay)
|
||||
lr = lr_1 + lr_2
|
||||
|
||||
train_net = TrainStepWrap(loss_net, lr, config.momentum)
|
||||
test_net = TestStepWrap(net)
|
||||
|
||||
f = open("checkpoints/pretrained_resnet50.pkl", "rb")
|
||||
param_dict = pickle.load(f)
|
||||
load_param_into_net(net=train_net, parameter_dict=param_dict)
|
||||
|
||||
model = Model(train_net, eval_network=test_net, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
# time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossMonitor()
|
||||
#cb = [time_cb, loss_cb]
|
||||
cb = [loss_cb]
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, \
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="resnet", directory='checkpoints/', \
|
||||
config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
model.train(epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
create train or eval dataset.
|
||||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.vision.c_transforms as CV
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from config import config
|
||||
from dataset.MGDataset import DistributedPKSampler, MGDataset
|
||||
|
||||
|
||||
def create_dataset(data_dir, p=16, k=8):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
p(int): randomly choose p classes from all classes.
|
||||
k(int): randomly choose k images from each of the chosen p classes.
|
||||
p * k is the batchsize.
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
dataset = MGDataset(data_dir)
|
||||
sampler = DistributedPKSampler(dataset, p=p, k=k)
|
||||
de_dataset = de.GeneratorDataset(dataset, ["image", "label1", "label2"], sampler=sampler)
|
||||
|
||||
resize_height = config.image_height
|
||||
resize_width = config.image_width
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
|
||||
resize_op = CV.Resize((resize_height, resize_width))
|
||||
rescale_op = CV.Rescale(rescale, shift)
|
||||
normalize_op = CV.Normalize([0.486, 0.459, 0.408], [0.229, 0.224, 0.225])
|
||||
|
||||
change_swap_op = CV.HWC2CHW()
|
||||
|
||||
trans = []
|
||||
|
||||
trans += [resize_op, rescale_op, normalize_op, change_swap_op]
|
||||
|
||||
type_cast_op_label1 = C.TypeCast(mstype.int32)
|
||||
type_cast_op_label2 = C.TypeCast(mstype.float32)
|
||||
|
||||
de_dataset = de_dataset.map(input_columns="label1", operations=type_cast_op_label1)
|
||||
de_dataset = de_dataset.map(input_columns="label2", operations=type_cast_op_label2)
|
||||
de_dataset = de_dataset.map(input_columns="image", operations=trans)
|
||||
de_dataset = de_dataset.batch(p*k, drop_remainder=True)
|
||||
|
||||
return de_dataset
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Numpy version of euclidean distance, etc."""
|
||||
import numpy as np
|
||||
from utils.metric import cmc, mean_ap
|
||||
|
||||
|
||||
def normalize(nparray, order=2, axis=0):
|
||||
"""Normalize a N-D numpy array along the specified axis."""
|
||||
norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
|
||||
return nparray / (norm + np.finfo(np.float32).eps)
|
||||
|
||||
|
||||
def compute_dist(array1, array2, dis_type='euclidean'):
|
||||
"""Compute the euclidean or cosine distance of all pairs.
|
||||
Args:
|
||||
array1: numpy array with shape [m1, n]
|
||||
array2: numpy array with shape [m2, n]
|
||||
type:
|
||||
one of ['cosine', 'euclidean']
|
||||
Returns:
|
||||
numpy array with shape [m1, m2]
|
||||
"""
|
||||
assert dis_type in ['cosine', 'euclidean']
|
||||
if dis_type == 'cosine':
|
||||
array1 = normalize(array1, axis=1)
|
||||
array2 = normalize(array2, axis=1)
|
||||
dist = np.matmul(array1, array2.T)
|
||||
return -1*dist
|
||||
|
||||
# shape [m1, 1]
|
||||
square1 = np.sum(np.square(array1), axis=1)[..., np.newaxis]
|
||||
# shape [1, m2]
|
||||
square2 = np.sum(np.square(array2), axis=1)[np.newaxis, ...]
|
||||
squared_dist = - 2 * np.matmul(array1, array2.T) + square1 + square2
|
||||
squared_dist[squared_dist < 0] = 0
|
||||
dist = np.sqrt(squared_dist)
|
||||
return dist
|
||||
|
||||
|
||||
def compute_score(dist_mat, query_ids, gallery_ids):
|
||||
mAP = mean_ap(distmat=dist_mat, query_ids=query_ids, gallery_ids=gallery_ids)
|
||||
cmc_scores, _ = cmc(distmat=dist_mat, query_ids=query_ids, gallery_ids=gallery_ids, topk=10)
|
||||
return mAP, cmc_scores
|
|
@ -0,0 +1,194 @@
|
|||
"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid)
|
||||
reid/evaluation_metrics/ranking.py. Modifications:
|
||||
1) Only accepts numpy data input, no torch is involved.
|
||||
1) Here results of each query can be returned.
|
||||
2) In the single-gallery-shot evaluation case, the time of repeats is changed
|
||||
from 10 to 100.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import average_precision_score
|
||||
|
||||
|
||||
def _unique_sample(ids_dict, num):
|
||||
mask = np.zeros(num, dtype=np.bool)
|
||||
for _, indices in ids_dict.items():
|
||||
i = np.random.choice(indices)
|
||||
mask[i] = True
|
||||
return mask
|
||||
|
||||
|
||||
def cmc(
|
||||
distmat,
|
||||
query_ids=None,
|
||||
gallery_ids=None,
|
||||
query_cams=None,
|
||||
gallery_cams=None,
|
||||
topk=100,
|
||||
separate_camera_set=False,
|
||||
single_gallery_shot=False,
|
||||
first_match_break=False,
|
||||
average=True):
|
||||
"""
|
||||
Args:
|
||||
distmat: numpy array with shape [num_query, num_gallery], the
|
||||
pairwise distance between query and gallery samples
|
||||
query_ids: numpy array with shape [num_query]
|
||||
gallery_ids: numpy array with shape [num_gallery]
|
||||
query_cams: numpy array with shape [num_query]
|
||||
gallery_cams: numpy array with shape [num_gallery]
|
||||
average: whether to average the results across queries
|
||||
Returns:
|
||||
If `average` is `False`:
|
||||
ret: numpy array with shape [num_query, topk]
|
||||
is_valid_query: numpy array with shape [num_query], containing 0's and
|
||||
1's, whether each query is valid or not
|
||||
If `average` is `True`:
|
||||
numpy array with shape [topk]
|
||||
"""
|
||||
# Ensure numpy array
|
||||
assert isinstance(distmat, np.ndarray)
|
||||
assert isinstance(query_ids, np.ndarray)
|
||||
assert isinstance(gallery_ids, np.ndarray)
|
||||
# assert isinstance(query_cams, np.ndarray)
|
||||
# assert isinstance(gallery_cams, np.ndarray)
|
||||
# separate_camera_set=False
|
||||
first_match_break = True
|
||||
m, _ = distmat.shape
|
||||
# Sort and find correct matches
|
||||
indices = np.argsort(distmat, axis=1)
|
||||
#print(indices)
|
||||
matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
|
||||
# Compute CMC for each query
|
||||
ret = np.zeros([m, topk])
|
||||
is_valid_query = np.zeros(m)
|
||||
num_valid_queries = 0
|
||||
|
||||
for i in range(m):
|
||||
valid = (gallery_ids[indices[i]] != query_ids[i]) | (gallery_ids[indices[i]] == query_ids[i])
|
||||
|
||||
if separate_camera_set:
|
||||
# Filter out samples from same camera
|
||||
valid = (gallery_cams[indices[i]] != query_cams[i])
|
||||
|
||||
if not np.any(matches[i, valid]): continue
|
||||
|
||||
is_valid_query[i] = 1
|
||||
|
||||
if single_gallery_shot:
|
||||
repeat = 100
|
||||
gids = gallery_ids[indices[i][valid]]
|
||||
inds = np.where(valid)[0]
|
||||
ids_dict = defaultdict(list)
|
||||
for j, x in zip(inds, gids):
|
||||
ids_dict[x].append(j)
|
||||
else:
|
||||
repeat = 1
|
||||
|
||||
for _ in range(repeat):
|
||||
if single_gallery_shot:
|
||||
# Randomly choose one instance for each id
|
||||
sampled = (valid & _unique_sample(ids_dict, len(valid)))
|
||||
index = np.nonzero(matches[i, sampled])[0]
|
||||
else:
|
||||
index = np.nonzero(matches[i, valid])[0]
|
||||
|
||||
delta = 1. / (len(index) * repeat)
|
||||
for j, k in enumerate(index):
|
||||
if k - j >= topk: break
|
||||
if first_match_break:
|
||||
ret[i, k - j] += 1
|
||||
break
|
||||
ret[i, k - j] += delta
|
||||
num_valid_queries += 1
|
||||
|
||||
if num_valid_queries == 0:
|
||||
raise RuntimeError("No valid query")
|
||||
ret = ret.cumsum(axis=1)
|
||||
|
||||
if average:
|
||||
return np.sum(ret, axis=0) / num_valid_queries, indices
|
||||
|
||||
return ret, is_valid_query, indices
|
||||
|
||||
|
||||
def mean_ap(
|
||||
distmat,
|
||||
query_ids=None,
|
||||
gallery_ids=None,
|
||||
query_cams=None,
|
||||
gallery_cams=None,
|
||||
average=True):
|
||||
"""
|
||||
Args:
|
||||
distmat: numpy array with shape [num_query, num_gallery], the
|
||||
pairwise distance between query and gallery samples
|
||||
query_ids: numpy array with shape [num_query]
|
||||
gallery_ids: numpy array with shape [num_gallery]
|
||||
query_cams: numpy array with shape [num_query]
|
||||
gallery_cams: numpy array with shape [num_gallery]
|
||||
average: whether to average the results across queries
|
||||
Returns:
|
||||
If `average` is `False`:
|
||||
ret: numpy array with shape [num_query]
|
||||
is_valid_query: numpy array with shape [num_query], containing 0's and
|
||||
1's, whether each query is valid or not
|
||||
If `average` is `True`:
|
||||
a scalar
|
||||
"""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# The behavior of method `sklearn.average_precision` has changed since version
|
||||
# 0.19.
|
||||
# Version 0.18.1 has same results as Matlab evaluation code by Zhun Zhong
|
||||
# (https://github.com/zhunzhong07/person-re-ranking/
|
||||
# blob/master/evaluation/utils/evaluation.m) and by Liang Zheng
|
||||
# (http://www.liangzheng.org/Project/project_reid.html).
|
||||
# My current awkward solution is sticking to this older version.
|
||||
# if cur_version != required_version:
|
||||
# print('User Warning: Version {} is required for package scikit-learn, '
|
||||
# 'your current version is {}. '
|
||||
# 'As a result, the mAP score may not be totally correct. '
|
||||
# 'You can try `pip uninstall scikit-learn` '
|
||||
# 'and then `pip install scikit-learn=={}`'.format(
|
||||
# required_version, cur_version, required_version))
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Ensure numpy array
|
||||
assert isinstance(distmat, np.ndarray)
|
||||
assert isinstance(query_ids, np.ndarray)
|
||||
assert isinstance(gallery_ids, np.ndarray)
|
||||
# assert isinstance(query_cams, np.ndarray)
|
||||
# assert isinstance(gallery_cams, np.ndarray)
|
||||
|
||||
m, _ = distmat.shape
|
||||
|
||||
# Sort and find correct matches
|
||||
indices = np.argsort(distmat, axis=1)
|
||||
# print("indices:", indices)
|
||||
matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
|
||||
# Compute AP for each query
|
||||
aps = np.zeros(m)
|
||||
is_valid_query = np.zeros(m)
|
||||
for i in range(m):
|
||||
# Filter out the same id and same camera
|
||||
# valid = ((gallery_ids[indices[i]] != query_ids[i]) |
|
||||
# (gallery_cams[indices[i]] != query_cams[i]))
|
||||
valid = (gallery_ids[indices[i]] != query_ids[i]) | (gallery_ids[indices[i]] == query_ids[i])
|
||||
# valid = indices[i] != i
|
||||
# valid = (gallery_cams[indices[i]] != query_cams[i])
|
||||
y_true = matches[i, valid]
|
||||
y_score = -distmat[i][indices[i]][valid]
|
||||
|
||||
# y_true=y_true[0:100]
|
||||
# y_score=y_score[0:100]
|
||||
if not np.any(y_true): continue
|
||||
is_valid_query[i] = 1
|
||||
aps[i] = average_precision_score(y_true, y_score)
|
||||
# if not aps:
|
||||
# raise RuntimeError("No valid query")
|
||||
if average:
|
||||
return float(np.sum(aps)) / np.sum(is_valid_query)
|
||||
return aps, is_valid_query
|
Loading…
Reference in New Issue