!9167 Add FaceQualityAssessment network to model_zoo/research/cv/

From: @zhanghuiyao
mindspore-ci-bot 2020-12-07 11:12:12 +08:00 committed by Gitee
commit 7269757230
14 changed files with 1698 additions and 0 deletions


@ -0,0 +1,237 @@
# Contents
- [Face Quality Assessment Description](#face-quality-assessment-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Running Example](#running-example)
- [Model Description](#model-description)
- [Performance](#performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Face Quality Assessment Description](#contents)
This is a Face Quality Assessment network based on Resnet12, with support for training and evaluation on Ascend 910.
ResNet (residual neural network) was proposed by Kaiming He and three colleagues at Microsoft Research. Using residual units, they successfully trained a 152-layer neural network and won ILSVRC 2015 with a top-5 error rate of 3.57%, while using fewer parameters than VGGNet. Traditional convolutional or fully connected networks lose some information as it passes through the layers, and they are prone to vanishing or exploding gradients, which makes very deep networks fail to train. ResNet mitigates this problem: shortcut connections pass the input directly to the output, which preserves the information, so each block only has to learn the difference (the residual) between its input and output. This simplifies the learning objective. The ResNet structure speeds up training considerably and also improves model accuracy. Residual connections are also widely reused and can be applied directly in other architectures such as Inception networks.
[Paper](https://arxiv.org/pdf/1512.03385.pdf): Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition"
# [Model Architecture](#contents)
Face Quality Assessment uses a modified-Resnet12 network for performing feature extraction.
# [Dataset](#contents)
This network can recognize the Euler angles of the human head and 5 key points of the human face.
This example uses about 122K face images as the training dataset and 2K as the evaluation dataset. You can also use your own datasets or open-source datasets (e.g. 300W-LP for training and AFLW2000 for evaluation).
- step 1: The training dataset should be saved in a txt file, which contains the following contents:
```python
[PATH_TO_IMAGE]/1.jpg [YAW] [PITCH] [ROLL] [LEFT_EYE_CENTER_X] [LEFT_EYE_CENTER_Y] [RIGHT_EYE_CENTER_X] [RIGHT_EYE_CENTER_Y] [NOSE_TIP_X] [NOSE_TIP_Y] [MOUTH_LEFT_CORNER_X] [MOUTH_LEFT_CORNER_Y] [MOUTH_RIGHT_CORNER_X] [MOUTH_RIGHT_CORNER_Y]
[PATH_TO_IMAGE]/2.jpg [YAW] [PITCH] [ROLL] [LEFT_EYE_CENTER_X] [LEFT_EYE_CENTER_Y] [RIGHT_EYE_CENTER_X] [RIGHT_EYE_CENTER_Y] [NOSE_TIP_X] [NOSE_TIP_Y] [MOUTH_LEFT_CORNER_X] [MOUTH_LEFT_CORNER_Y] [MOUTH_RIGHT_CORNER_X] [MOUTH_RIGHT_CORNER_Y]
[PATH_TO_IMAGE]/3.jpg [YAW] [PITCH] [ROLL] [LEFT_EYE_CENTER_X] [LEFT_EYE_CENTER_Y] [RIGHT_EYE_CENTER_X] [RIGHT_EYE_CENTER_Y] [NOSE_TIP_X] [NOSE_TIP_Y] [MOUTH_LEFT_CORNER_X] [MOUTH_LEFT_CORNER_Y] [MOUTH_RIGHT_CORNER_X] [MOUTH_RIGHT_CORNER_Y]
...
e.g. /home/train/1.jpg -33.073415 -9.533774 -9.285695 229.802368 257.432800 289.186188 262.831543 271.241638 301.224426 218.571747 322.097321 277.498291 328.260376
The label fields are separated by '\t'.
Set -1 when the keypoint is not visible.
```
- step 2: The directory structure of evaluating dataset is as follows:
```python
├─ dataset
  ├─ img1.jpg
  ├─ img1.txt
  ├─ img2.jpg
  ├─ img2.txt
  ├─ img3.jpg
  ├─ img3.txt
  ├─ ...
```
The txt file contains the following contents:
```python
[YAW] [PITCH] [ROLL] [LEFT_EYE_CENTER_X] [LEFT_EYE_CENTER_Y] [RIGHT_EYE_CENTER_X] [RIGHT_EYE_CENTER_Y] [NOSE_TIP_X] [NOSE_TIP_Y] [MOUTH_LEFT_CORNER_X] [MOUTH_LEFT_CORNER_Y] [MOUTH_RIGHT_CORNER_X] [MOUTH_RIGHT_CORNER_Y]
The label fields are separated by ' '.
Set -1 when the keypoint is not visible.
```
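As a rough illustration of the two label layouts above, the following sketch parses one tab-separated training line and one space-separated evaluation label. The names `FaceLabel`, `parse_train_label`, and `parse_eval_label` are hypothetical helpers introduced here for illustration; they are not part of this repository.

```python
from typing import List, NamedTuple, Tuple


class FaceLabel(NamedTuple):
    """One labeled face (hypothetical container, not part of the repository)."""
    eulers: Tuple[float, float, float]    # yaw, pitch, roll
    keypoints: List[Tuple[float, float]]  # eyes, nose tip, mouth corners; -1 = not visible


def _to_label(values: List[str]) -> FaceLabel:
    """Convert 13 numeric strings (3 angles + 5 x/y pairs) into a FaceLabel."""
    if len(values) != 13:
        raise ValueError("expected 13 label values, got %d" % len(values))
    nums = [float(v) for v in values]
    # Values 3..12 are the five (x, y) keypoint pairs, in the order listed above.
    keypoints = [(nums[i], nums[i + 1]) for i in range(3, 13, 2)]
    return FaceLabel((nums[0], nums[1], nums[2]), keypoints)


def parse_train_label(line: str) -> Tuple[str, FaceLabel]:
    """Parse one training line: image path followed by 13 tab-separated values."""
    fields = line.strip().split("\t")
    return fields[0], _to_label(fields[1:])


def parse_eval_label(line: str) -> FaceLabel:
    """Parse the first line of an evaluation txt file: 13 space-separated values."""
    return _to_label(line.strip().split(" "))
```

Running such a check over a label file before training is one way to catch lines with a wrong separator or a missing field early.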
# [Environment Requirements](#contents)
- Hardware(Ascend)
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
- For more information, please check the resources below:
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
The entire code structure is as follows:
```python
.
└─ Face Quality Assessment
  ├─ README.md
  ├─ scripts
    ├─ run_standalone_train.sh              # launch standalone training (1p) on Ascend
    ├─ run_distribute_train.sh              # launch distributed training (8p) on Ascend
    ├─ run_eval.sh                          # launch evaluation on Ascend
    └─ run_export.sh                        # launch export of the AIR model
  ├─ src
    ├─ config.py                            # parameter configuration
    ├─ dataset.py                           # dataset loading and preprocessing for training
    ├─ face_qa.py                           # network backbone
    ├─ log.py                               # log function
    ├─ loss_factory.py                      # loss function
    └─ lr_generator.py                      # generate learning rate
  ├─ train.py                               # training script
  ├─ eval.py                                # evaluation script
  └─ export.py                              # export AIR model
```
## [Running Example](#contents)
### Train
- Stand alone mode
```bash
cd ./scripts
sh run_standalone_train.sh [TRAIN_LABEL_FILE] [USE_DEVICE_ID]
```
or (fine-tune)
```bash
cd ./scripts
sh run_standalone_train.sh [TRAIN_LABEL_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
```
for example:
```bash
cd ./scripts
sh run_standalone_train.sh /home/train.txt 0 /home/a.ckpt
```
- Distribute mode (recommended)
```bash
cd ./scripts
sh run_distribute_train.sh [TRAIN_LABEL_FILE] [RANK_TABLE]
```
or (fine-tune)
```bash
cd ./scripts
sh run_distribute_train.sh [TRAIN_LABEL_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]
```
for example:
```bash
cd ./scripts
sh run_distribute_train.sh /home/train.txt ./rank_table_8p.json /home/a.ckpt
```
The loss value of each step is written to "./output/[TIME]/[TIME].log" or "./scripts/device0/train.log", as follows:
```python
epoch[0], iter[0], loss:39.206444, 5.31 imgs/sec
epoch[0], iter[10], loss:38.200620, 10423.44 imgs/sec
epoch[0], iter[20], loss:31.253260, 13555.87 imgs/sec
epoch[0], iter[30], loss:26.349678, 8762.34 imgs/sec
epoch[0], iter[40], loss:23.469613, 7848.85 imgs/sec
...
epoch[39], iter[19080], loss:1.881406, 7620.63 imgs/sec
epoch[39], iter[19090], loss:2.091236, 7601.15 imgs/sec
epoch[39], iter[19100], loss:2.140766, 8088.52 imgs/sec
epoch[39], iter[19110], loss:2.111101, 8791.05 imgs/sec
```
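To monitor or plot the loss, log lines in the format shown above can be parsed with a regular expression. This is a minimal sketch: `parse_log_line` is a hypothetical helper, and the pattern assumes exactly the line layout of the sample log.

```python
import re

# Matches lines such as: epoch[0], iter[10], loss:38.200620, 10423.44 imgs/sec
LOG_PATTERN = re.compile(
    r"epoch\[(\d+)\], iter\[(\d+)\], loss:([0-9.]+), ([0-9.]+) imgs/sec"
)


def parse_log_line(line):
    """Return the fields of one training log line as a dict, or None if it does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    return {
        "epoch": int(match.group(1)),
        "iter": int(match.group(2)),
        "loss": float(match.group(3)),
        "imgs_per_sec": float(match.group(4)),
    }
```

Lines that do not match (for example the `...` placeholder or unrelated framework messages) yield `None` and can simply be skipped.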
### Evaluation
```bash
cd ./scripts
sh run_eval.sh [EVAL_DIR] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
```
for example:
```bash
cd ./scripts
sh run_eval.sh /home/eval/ 0 /home/a.ckpt
```
The result is written to "./scripts/device0/eval.log" or to a txt file in the folder of [PRETRAINED_BACKBONE], as follows:
```python
5 keypoints average err:['4.069', '3.439', '4.001', '3.206', '3.413']
3 eulers average err:['21.667', '15.627', '16.770']
IPN of 5 keypoints:19.57019303768714
MAE of elur:18.021210976971098
```
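The two summary metrics can be read as follows, based on how `eval.py` in this commit computes them: the keypoint metric is the mean keypoint distance of an image divided by the distance between its two eye centers, averaged over images and multiplied by 100; the Euler metric is the mean absolute angle error, averaged over the three angles. The sketch below restates that computation with NumPy. The function names are hypothetical, and this is an illustration under those assumptions rather than the script's actual code.

```python
import numpy as np


def ipn_for_image(gt_kps, pred_kps):
    """Mean keypoint error normalized by the eye distance, or None if it is undefined.

    gt_kps / pred_kps: five (x, y) pairs; the first two are the eye centers.
    A ground-truth coordinate of -1 marks an invisible keypoint.
    """
    gt = np.asarray(gt_kps, dtype=np.float64)
    pred = np.asarray(pred_kps, dtype=np.float64)
    if (gt[:2] < 0).any():
        return None  # an eye is not visible, so there is no normalizer
    eye_dist = np.linalg.norm(gt[0] - gt[1])
    visible = gt[:, 0] >= 0
    if eye_dist == 0 or not visible.any():
        return None
    errors = np.linalg.norm(gt[visible] - pred[visible], axis=1)
    return float(errors.mean() / eye_dist)


def euler_mae(gt_eulers, pred_eulers):
    """Mean absolute error per angle over all images, then averaged over the 3 angles."""
    gt = np.asarray(gt_eulers, dtype=np.float64)
    pred = np.asarray(pred_eulers, dtype=np.float64)
    per_angle = np.abs(gt - pred).mean(axis=0)
    return float(per_angle.mean())
```

The reported keypoint figure would then correspond to `100 * mean(ipn_for_image(...))` over all images for which the value is not `None`.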
### Convert model
If you want to infer the network on Ascend 310, you should convert the model to AIR:
```bash
cd ./scripts
sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
```
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | Face Quality Assessment |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | V1 |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; Memory 755 GB |
| Uploaded Date | 09/30/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | 122K images |
| Training Parameters | epoch=40, batch_size=32, momentum=0.9, lr=0.02 |
| Optimizer | Momentum |
| Loss Function | MSELoss, Softmax Cross Entropy |
| Outputs | probability and point |
| Speed | 1pc: 200~240 ms/step; 8pcs: 35~40 ms/step |
| Total time | 1pc: 2.5 hours; 8pcs: 0.5 hours |
| Checkpoint for Fine tuning | 16M (.ckpt file) |
### Evaluation Performance
| Parameters | Face Quality Assessment |
| ------------------- | --------------------------- |
| Model Version | V1 |
| Resource | Ascend 910 |
| Uploaded Date | 09/30/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | 2K images |
| batch_size | 1 |
| Outputs | IPN, MAE |
| Accuracy(8pcs) | IPN of 5 keypoints:19.5 |
| | MAE of Euler angles:18.02 |
| Model for inference | 16M (.ckpt file) |
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,216 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Quality Assessment eval."""
import os
import warnings
import argparse
import numpy as np
from tqdm import tqdm
import cv2
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore import context
from src.face_qa import FaceQABackbone
warnings.filterwarnings('ignore')
# Fall back to device 0 when DEVICE_ID is not exported (the run scripts normally set it).
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
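def softmax(x):
    """Compute softmax values for each set of scores in x (one set per row)."""
    # Subtract the row maximum for numerical stability; keepdims keeps the result broadcastable.
    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
    return e_x / np.sum(e_x, axis=1, keepdims=True)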
def get_md_output(out):
    '''Decode the network output into keypoint scores, keypoint coordinates and Euler angles.'''
    out_eul = out[0].asnumpy().astype(np.float32)[0]
    heatmap = out[1].asnumpy().astype(np.float32)[0]
    eulers = out_eul * 90  # undo the /90 normalization applied to the training labels
    kps_score_sum = 0
    kp_scores = list()
    kp_coord_ori = list()
    for i, _ in enumerate(heatmap):
        map_1 = heatmap[i].reshape(1, 48*48)
        map_1 = softmax(map_1)
        kp_coor = map_1.argmax()
        max_response = map_1.max()
        kp_scores.append(max_response)
        kps_score_sum += min(max_response, 0.25)
        # Flat index -> (x, y) on the 48x48 heatmap, scaled by 2 to the 96x96 input.
        # Integer division keeps the row index independent of the column.
        kp_coor = int((kp_coor % 48) * 2), int((kp_coor // 48) * 2)
        kp_coord_ori.append(kp_coor)
    return kp_scores, kps_score_sum, kp_coord_ori, eulers, 1
def read_gt(txt_path, x_length, y_length):
    '''Read ground-truth Euler angles and keypoints, rescaling keypoints to 96x96.'''
    # Use a context manager so the label file is closed after reading.
    with open(txt_path) as label_file:
        fields = label_file.readline().strip().split(" ")
    eulers_txt = fields[:3]
    kp_list = [[-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]]
    box_cur = fields[3:]
    bndbox = []
    for index in range(len(box_cur) // 2):
        bndbox.append([box_cur[index * 2], box_cur[index * 2 + 1]])
    kp_id = -1
    for box in bndbox:
        kp_id = kp_id + 1
        x_coord = float(box[0])
        y_coord = float(box[1])
        if x_coord < 0 or y_coord < 0:
            continue  # invisible keypoint stays at [-1, -1]
        kp_list[kp_id][0] = int(x_coord / x_length * 96)
        kp_list[kp_id][1] = int(y_coord / y_length * 96)
    return eulers_txt, kp_list
def read_img(img_path):
    '''Load an image as a (1, 3, 96, 96) float32 Tensor in [0, 1]; also return the original BGR image.'''
    img_ori = cv2.imread(img_path)
    img = cv2.cvtColor(img_ori, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (96, 96))
    img = img.transpose(2, 0, 1)
    img = np.array([img]).astype(np.float32)/255.
    img = Tensor(img)
    return img, img_ori
blur_soft = nn.Softmax(0)
kps_soft = nn.Softmax(-1)
reshape = P.Reshape()
argmax = P.ArgMaxWithValue()
def test_trains(args):
    '''Evaluate a checkpoint on the images in args.eval_dir and report keypoint and Euler-angle errors.'''
    print('----eval----begin----')
    model_path = args.pretrained
    # Swap only the extension, so a path without '.ckpt' never maps onto the checkpoint itself.
    result_file = os.path.splitext(model_path)[0] + '.txt'
    if os.path.exists(result_file):
        os.remove(result_file)
    epoch_result = open(result_file, 'a')
    epoch_result.write(model_path + '\n')
    network = FaceQABackbone()
    ckpt_path = model_path
    if os.path.isfile(ckpt_path):
        param_dict = load_checkpoint(ckpt_path)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values  # strip the 'network.' prefix
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
    else:
        print('wrong model path')
        epoch_result.close()
        return 1
    path = args.eval_dir
    kp_error_all = [[], [], [], [], []]
    eulers_error_all = [[], [], []]
    kp_ipn = []
    file_list = os.listdir(path)
    for file_name in tqdm(file_list):
        if not file_name.endswith('.jpg'):
            continue
        img_path = os.path.join(path, file_name)
        # Swap only the extension; str.replace would also rewrite 'jpg' elsewhere in the path.
        txt_path = os.path.splitext(img_path)[0] + '.txt'
        if not os.path.exists(txt_path):
            continue  # images without a label file are skipped
        img, img_ori = read_img(img_path)
        x_length = img_ori.shape[1]
        y_length = img_ori.shape[0]
        eulers_gt, kp_list = read_gt(txt_path, x_length, y_length)
        out = network(img)
        _, _, kp_coord_ori, eulers_ori, _ = get_md_output(out)
        eulgt = list(eulers_gt)
        for euler_id, _ in enumerate(eulers_ori):
            eulori = eulers_ori[euler_id]
            eulers_error_all[euler_id].append(abs(eulori-float(eulgt[euler_id])))
        eye01 = kp_list[0]
        eye02 = kp_list[1]
        eye_dis = 1
        cur_flag = True
        if eye01[0] < 0 or eye01[1] < 0 or eye02[0] < 0 or eye02[1] < 0:
            cur_flag = False
        else:
            eye_dis = np.sqrt(np.square(abs(eye01[0]-eye02[0]))+np.square(abs(eye01[1]-eye02[1])))
        cur_error_list = []
        for i in range(5):
            kp_coord_gt = kp_list[i]
            kp_coord_model = kp_coord_ori[i]
            if kp_coord_gt[0] != -1:
                dis = np.sqrt(np.square(
                    kp_coord_gt[0] - kp_coord_model[0]) + np.square(kp_coord_gt[1] - kp_coord_model[1]))
                kp_error_all[i].append(dis)
                cur_error_list.append(dis)
        # Guard against coincident eye labels and images with no visible keypoints.
        if cur_flag and eye_dis > 0 and cur_error_list:
            kp_ipn.append(sum(cur_error_list)/len(cur_error_list)/eye_dis)
    # Avoid dividing by zero when the directory holds no usable labeled samples.
    if not kp_ipn or not all(kp_error_all) or not all(eulers_error_all):
        print('no usable labeled samples found in ' + path)
        epoch_result.close()
        return 1
    kp_ave_error = []
    for kps, _ in enumerate(kp_error_all):
        kp_ave_error.append("%.3f" % (sum(kp_error_all[kps])/len(kp_error_all[kps])))
    euler_ave_error = []
    elur_mae = []
    for eulers, _ in enumerate(eulers_error_all):
        euler_ave_error.append("%.3f" % (sum(eulers_error_all[eulers])/len(eulers_error_all[eulers])))
        elur_mae.append((sum(eulers_error_all[eulers])/len(eulers_error_all[eulers])))
    print(r'5 keypoints average err:'+str(kp_ave_error))
    print(r'3 eulers average err:'+str(euler_ave_error))
    print('IPN of 5 keypoints:'+str(sum(kp_ipn)/len(kp_ipn)*100))
    print('MAE of elur:'+str(sum(elur_mae)/len(elur_mae)))
    epoch_result.write(str(sum(kp_ipn)/len(kp_ipn)*100)+'\t'+str(sum(elur_mae)/len(elur_mae))+'\t'
                       + str(kp_ave_error)+'\t'+str(euler_ave_error)+'\n')
    epoch_result.close()
    print('----eval----end----')
    return 0
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Face Quality Assessment')
    parser.add_argument('--eval_dir', type=str, default='', help='eval image dir, e.g. /home/test')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    arg = parser.parse_args()
    test_trains(arg)


@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ckpt to air."""
import os
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.face_qa import FaceQABackbone
# Fall back to device 0 when DEVICE_ID is not exported (the run scripts normally set it).
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
def main(args):
    '''Load a checkpoint into FaceQABackbone and export it in AIR format.'''
    network = FaceQABackbone()
    ckpt_path = args.pretrained
    if os.path.isfile(ckpt_path):
        param_dict = load_checkpoint(ckpt_path)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values  # strip the 'network.' prefix
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        print('-----------------------load model success-----------------------')
    else:
        print('-----------------------load model failed -----------------------')
        return  # stop here instead of exporting randomly initialized weights
    # A random input of the right shape is enough to trace the graph for export.
    input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 96, 96)).astype(np.float32)
    tensor_input_data = Tensor(input_data)
    export(network, tensor_input_data, file_name=ckpt_path.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'),
           file_format='AIR')
    print('-----------------------export model success-----------------------')
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Convert ckpt to air')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    arg = parser.parse_args()
    main(arg)


@ -0,0 +1,88 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 -a $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [TRAIN_LABEL_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]"
echo " or: sh run_distribute_train.sh [TRAIN_LABEL_FILE] [RANK_TABLE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname $(pwd))
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
SCRIPT_NAME='train.py'
rm -rf ${current_exec_path}/device*
ulimit -c unlimited
TRAIN_LABEL_FILE=$(get_real_path $1)
RANK_TABLE=$(get_real_path $2)
PRETRAINED_BACKBONE=''
if [ ! -f $TRAIN_LABEL_FILE ]
then
echo "error: TRAIN_LABEL_FILE=$TRAIN_LABEL_FILE is not a file"
exit 1
fi
if [ $# == 3 ]
then
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_BACKBONE=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
fi
echo $TRAIN_LABEL_FILE
echo $RANK_TABLE
echo $PRETRAINED_BACKBONE
export RANK_TABLE_FILE=$RANK_TABLE
export RANK_SIZE=8
echo 'start training'
for((i=0;i<=$RANK_SIZE-1;i++));
do
echo 'start rank '$i
mkdir ${current_exec_path}/device$i
cd ${current_exec_path}/device$i
export RANK_ID=$i
dev=`expr $i + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--is_distributed=1 \
--train_label_file=$TRAIN_LABEL_FILE \
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
done
echo 'running'


@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_eval.sh [EVAL_DIR] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname $(pwd))
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
SCRIPT_NAME='eval.py'
ulimit -c unlimited
EVAL_DIR=$(get_real_path $1)
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_BACKBONE=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
echo $EVAL_DIR
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
echo 'start evaluating'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--eval_dir=$EVAL_DIR \
--pretrained=$PRETRAINED_BACKBONE > eval.log 2>&1 &
echo 'running'


@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname $(pwd))
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
SCRIPT_NAME='export.py'
ulimit -c unlimited
BATCH_SIZE=$1
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_BACKBONE=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
echo $BATCH_SIZE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
echo 'start converting'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--batch_size=$BATCH_SIZE \
--pretrained=$PRETRAINED_BACKBONE > convert.log 2>&1 &
echo 'running'


@ -0,0 +1,83 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 -a $# != 3 ]
then
echo "Usage: sh run_standalone_train.sh [TRAIN_LABEL_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
echo " or: sh run_standalone_train.sh [TRAIN_LABEL_FILE] [USE_DEVICE_ID]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname $(pwd))
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
SCRIPT_NAME='train.py'
ulimit -c unlimited
TRAIN_LABEL_FILE=$(get_real_path $1)
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=''
if [ ! -f $TRAIN_LABEL_FILE ]
then
echo "error: TRAIN_LABEL_FILE=$TRAIN_LABEL_FILE is not a file"
exit 1
fi
if [ $# == 3 ]
then
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_BACKBONE=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
fi
echo $TRAIN_LABEL_FILE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
echo 'start training'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--is_distributed=0 \
--train_label_file=$TRAIN_LABEL_FILE \
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
echo 'running'


@ -0,0 +1,76 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network config setting, will be used in train.py and eval.py"""
from easydict import EasyDict as edict
faceqa_1p_cfg = edict({
'task': 'face_qa',
# dataset related
'per_batch_size': 256,
# network structure related
'steps_per_epoch': 0,
'loss_scale': 1024,
# optimizer related
'lr': 0.02,
'lr_scale': 1,
'lr_epochs': '10, 20, 30',
'weight_decay': 0.0005,
'momentum': 0.9,
'max_epoch': 40,
'warmup_epochs': 0,
'pretrained': '',
'local_rank': 0,
'world_size': 1,
# logging related
'log_interval': 10,
'ckpt_path': '../../output',
'ckpt_interval': 500,
'device_id': 0,
})
faceqa_8p_cfg = edict({
'task': 'face_qa',
# dataset related
'per_batch_size': 32,
# network structure related
'steps_per_epoch': 0,
'loss_scale': 1024,
# optimizer related
'lr': 0.02,
'lr_scale': 1,
'lr_epochs': '10, 20, 30',
'weight_decay': 0.0005,
'momentum': 0.9,
'max_epoch': 40,
'warmup_epochs': 0,
'pretrained': '',
'local_rank': 0,
'world_size': 8,
# logging related
'log_interval': 10, # 10
'ckpt_path': '../../output',
'ckpt_interval': 500,
})


@ -0,0 +1,120 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Quality Assessment dataset."""
import math
import warnings
import numpy as np
from PIL import Image, ImageFile
import mindspore.dataset as ds
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.transforms.py_transforms as F2
warnings.filterwarnings('ignore')
ImageFile.LOAD_TRUNCATED_IMAGES = True
class MdFaceDataset():
    """Face Landmarks dataset."""
    def __init__(self, imlist,
                 img_shape=(96, 96),
                 heatmap_shape=(48, 48)):
        self.imlist = imlist
        self.img_shape = img_shape
        self.heatmap_shape = heatmap_shape
        print('Reading data...')
        with open(imlist) as fr:
            self.imgs_info = fr.readlines()
    def _trans_cor(self, landmark, x_length, y_length):
        '''_trans_cor'''
        landmark = list(map(float, landmark))
        landmark = np.array(landmark).reshape((5, 2))
        landmark_class_label = []
        for _, cor in enumerate(landmark):
            x, y = cor
            if x < 0:
                heatmap_label = -1
            else:
                x = float(x) / float(x_length) * 96.
                y = float(y) / float(y_length) * 96.
                x_out = int(x * 1.0 * self.heatmap_shape[1] / self.img_shape[1])
                y_out = int(y * 1.0 * self.heatmap_shape[0] / self.img_shape[0])
                heatmap_label = y_out * self.heatmap_shape[1] + x_out
                if heatmap_label >= self.heatmap_shape[0]*self.heatmap_shape[1] or heatmap_label < 0:
                    heatmap_label = -1
            landmark_class_label.append(heatmap_label)
        return landmark_class_label
    def __len__(self):
        return len(self.imgs_info)
    def __getitem__(self, idx):
        path_label_info = self.imgs_info[idx].strip().split('\t')
        impath = path_label_info[0]
        image = Image.open(impath).convert('RGB')
        x_length = image.size[0]
        y_length = image.size[1]
        image = image.resize((96, 96))
        landmarks = self._trans_cor(path_label_info[4:14], x_length, y_length)
        eulers = np.array([e / 90. for e in list(map(float, path_label_info[1:4]))])
        labels = np.concatenate([eulers, landmarks], axis=0)
        sample = image
        return sample, labels
class DistributedSampler():
    '''DistributedSampler'''
    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_length = len(self.dataset)
        self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
        self.total_size = self.num_samples * self.group_size
        self.shuffle = shuffle
        self.seed = seed
    def __iter__(self):
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xffffffff
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_length).tolist()
        else:
            # MdFaceDataset has no `classes` attribute; index over the dataset length instead.
            indices = list(range(self.dataset_length))
        # Pad so every rank receives the same number of samples, then take this rank's slice.
        indices += indices[:(self.total_size - len(indices))]
        indices = indices[self.rank::self.group_size]
        return iter(indices)
    def __len__(self):
        return self.num_samples
def faceqa_dataset(imlist, per_batch_size, local_rank, world_size):
    '''faceqa dataset'''
    transform_img = F2.Compose([F.ToTensor()])
    dataset = MdFaceDataset(imlist)
    sampler = DistributedSampler(dataset, local_rank, world_size)
    de_dataset = ds.GeneratorDataset(dataset, ["image", "label"], sampler=sampler, num_parallel_workers=8,
                                     python_multiprocessing=True)
    de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=8,
                                python_multiprocessing=True)
    de_dataset = de_dataset.batch(per_batch_size, drop_remainder=True)
    return de_dataset


@ -0,0 +1,237 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Quality Assessment backbone."""
import mindspore.nn as nn
from mindspore.ops.operations import TensorAdd
from mindspore.ops import operations as P
from mindspore.nn import Dense, Cell
class Cut(nn.Cell):
    def construct(self, x):
        return x
def bn_with_initialize(out_channels):
    bn = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5)
    return bn
def fc_with_initialize(input_channels, out_channels):
    return Dense(input_channels, out_channels)
def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     pad_mode=pad_mode, group=groups, has_bias=False, dilation=dilation, padding=padding)
def conv1x1(in_channels, out_channels, pad_mode="pad", stride=1, padding=0):
    """1x1 convolution"""
    return nn.Conv2d(in_channels, out_channels, pad_mode=pad_mode, kernel_size=1, stride=stride, has_bias=False,
                     padding=padding)
def conv4x4(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1):
    """4x4 convolution with padding"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride,
                     pad_mode=pad_mode, group=groups, has_bias=False, dilation=dilation, padding=padding)


class Block1(Cell):
    '''Block1'''
    def __init__(self):
        super(Block1, self).__init__()

        self.bk1_conv0 = conv3x3(64, 64, stride=1, padding=1)
        self.bk1_bn0 = bn_with_initialize(64)
        self.bk1_relu0 = P.ReLU()
        self.bk1_conv1 = conv3x3(64, 64, stride=1, padding=1)
        self.bk1_bn1 = bn_with_initialize(64)
        self.bk1_conv2 = conv1x1(64, 64, stride=1, padding=0)
        self.bk1_bn2 = bn_with_initialize(64)
        self.bk1_relu1 = P.ReLU()

        self.bk1_conv3 = conv3x3(64, 64, stride=1, padding=1)
        self.bk1_bn3 = bn_with_initialize(64)
        self.bk1_relu3 = P.ReLU()
        self.bk1_conv4 = conv3x3(64, 64, stride=1, padding=1)
        self.bk1_bn4 = bn_with_initialize(64)
        self.bk1_relu4 = P.ReLU()

        self.cast = P.Cast()
        self.add = TensorAdd()

    def construct(self, x):
        '''construct'''
        identity = x
        out = self.bk1_conv0(x)
        out = self.bk1_bn0(out)
        out = self.bk1_relu0(out)
        out = self.bk1_conv1(out)
        out = self.bk1_bn1(out)

        identity = self.bk1_conv2(identity)
        identity = self.bk1_bn2(identity)

        out = self.add(out, identity)
        out = self.bk1_relu1(out)

        identity = out
        out = self.bk1_conv3(out)
        out = self.bk1_bn3(out)
        out = self.bk1_relu3(out)
        out = self.bk1_conv4(out)
        out = self.bk1_bn4(out)

        out = self.add(out, identity)
        out = self.bk1_relu4(out)
        return out


class Block2(Cell):
    '''Block2'''
    def __init__(self):
        super(Block2, self).__init__()

        self.bk2_conv0 = conv3x3(64, 128, stride=2, padding=1)
        self.bk2_bn0 = bn_with_initialize(128)
        self.bk2_relu0 = P.ReLU()
        self.bk2_conv1 = conv3x3(128, 128, stride=1, padding=1)
        self.bk2_bn1 = bn_with_initialize(128)
        self.bk2_conv2 = conv1x1(64, 128, stride=2, padding=0)
        self.bk2_bn2 = bn_with_initialize(128)
        self.bk2_relu1 = P.ReLU()

        self.bk2_conv3 = conv3x3(128, 128, stride=1, padding=1)
        self.bk2_bn3 = bn_with_initialize(128)
        self.bk2_relu3 = P.ReLU()
        self.bk2_conv4 = conv3x3(128, 128, stride=1, padding=1)
        self.bk2_bn4 = bn_with_initialize(128)
        self.bk2_relu4 = P.ReLU()

        self.cast = P.Cast()
        self.add = TensorAdd()

    def construct(self, x):
        '''construct'''
        identity = x
        out = self.bk2_conv0(x)
        out = self.bk2_bn0(out)
        out = self.bk2_relu0(out)
        out = self.bk2_conv1(out)
        out = self.bk2_bn1(out)

        identity = self.bk2_conv2(identity)
        identity = self.bk2_bn2(identity)

        out = self.add(out, identity)
        out = self.bk2_relu1(out)

        identity = out
        out = self.bk2_conv3(out)
        out = self.bk2_bn3(out)
        out = self.bk2_relu3(out)
        out = self.bk2_conv4(out)
        out = self.bk2_bn4(out)

        out = self.add(out, identity)
        out = self.bk2_relu4(out)
        return out


class FaceQABackbone(Cell):
    '''FaceQABackbone'''
    def __init__(self):
        super(FaceQABackbone, self).__init__()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.cast = P.Cast()

        self.conv0 = conv3x3(3, 64, stride=2, padding=1)
        self.bn0 = bn_with_initialize(64)
        self.relu0 = P.ReLU()
        self.conv1 = conv3x3(64, 64, stride=2, padding=1)
        self.bn1 = bn_with_initialize(64)
        self.relu1 = P.ReLU()
        self.backbone = nn.SequentialCell([
            Block1(),
            Block2()
        ])

        # branch euler
        self.euler_conv = conv3x3(128, 128, stride=2, padding=1)
        self.euler_bn = bn_with_initialize(128)
        self.euler_relu = P.ReLU()
        self.euler_fc1 = fc_with_initialize(128*6*6, 256)
        self.euler_relu1 = P.ReLU()
        self.euler_fc2 = fc_with_initialize(256, 128)
        self.euler_relu2 = P.ReLU()
        self.euler_fc3 = fc_with_initialize(128, 3)

        # branch heatmap
        self.kps_deconv = nn.Conv2dTranspose(128, 5, 4, stride=2, pad_mode='pad', group=1, dilation=1, padding=1,
                                             has_bias=False)
        self.kps_up = nn.Conv2dTranspose(5, 5, 4, stride=2, pad_mode='pad', group=1, dilation=1, padding=1,
                                         has_bias=False)

    def construct(self, x):
        '''construct'''
        # backbone
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.backbone(x)

        # branch euler
        out1 = self.euler_conv(x)
        out1 = self.euler_bn(out1)
        out1 = self.euler_relu(out1)
        b, _, _, _ = self.shape(out1)
        out1 = self.reshape(out1, (b, -1))
        out1 = self.euler_fc1(out1)
        out1 = self.euler_relu1(out1)
        out1 = self.euler_fc2(out1)
        out1 = self.euler_relu2(out1)
        out1 = self.euler_fc3(out1)

        # branch kps
        out2 = self.kps_deconv(x)
        out2 = self.kps_up(out2)
        return out1, out2


class BuildTrainNetwork(nn.Cell):
    '''BuildTrainNetwork'''
    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, label):
        out_eul, out_kps = self.network(input_data)
        loss = self.criterion(out_eul, out_kps, label)
        return loss
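
The constants `128*6*6` in `euler_fc1` and `48 * 48` in the loss reshape only line up for one input resolution. The sketch below traces the spatial size through the layers above using the usual output-size formulas for explicitly padded convolutions and transposed convolutions. The 96x96 input is an assumption inferred from those constants, not something stated in this excerpt, and the helper names are illustrative.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution with explicit padding, dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel, stride, padding):
    """Spatial output size of a transposed convolution with explicit padding, dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel


def trace_shapes(input_size=96):
    """Return (backbone size, euler branch size, heatmap size) for a square input."""
    size = conv_out(input_size, 3, 2, 1)  # conv0, stride 2
    size = conv_out(size, 3, 2, 1)        # conv1, stride 2
    size = conv_out(size, 3, 1, 1)        # Block1 keeps the resolution
    size = conv_out(size, 3, 2, 1)        # Block2 halves it via bk2_conv0
    euler = conv_out(size, 3, 2, 1)       # euler_conv, stride 2
    heatmap = deconv_out(deconv_out(size, 4, 2, 1), 4, 2, 1)  # kps_deconv, then kps_up
    return size, euler, heatmap


if __name__ == "__main__":
    print(trace_shapes(96))
```

Under that assumption the backbone output is 12x12, the Euler branch flattens a 128x6x6 tensor, and the keypoint branch yields 5 heatmaps of 48x48.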

@ -0,0 +1,105 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom logger."""
import logging
import os
import sys
from datetime import datetime
logger_name_1 = 'face_qw_me'


class LOGGER(logging.Logger):
    '''LOGGER'''
    def __init__(self, logger_name):
        super(LOGGER, self).__init__(logger_name)
        console = logging.StreamHandler(sys.stdout)
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        console.setFormatter(formatter)
        self.addHandler(console)
        self.local_rank = 0

    def setup_logging_file(self, log_dir, local_rank=0):
        '''setup_logging_file'''
        self.local_rank = local_rank
        if self.local_rank == 0:
            if not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '.log'
            self.log_fn = os.path.join(log_dir, log_name)
            fh = logging.FileHandler(self.log_fn)
            fh.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            fh.setFormatter(formatter)
            self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*'*70 + '\n')*line_width
            important_msg += ('*'*line_width + '\n')*2
            important_msg += '*'*line_width + ' '*8 + msg + '\n'
            important_msg += ('*'*line_width + '\n')*2
            important_msg += ('*'*70 + '\n')*line_width
            self.info(important_msg, *args, **kwargs)


def get_logger(path, rank):
    logger = LOGGER(logger_name_1)
    logger.setup_logging_file(path, rank)
    return logger


class AverageMeter():
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f', tb_writer=None):
        self.name = name
        self.fmt = fmt
        self.reset()
        self.tb_writer = tb_writer
        self.cur_step = 1

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

        if self.tb_writer is not None:
            self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
        self.cur_step += 1

    def __str__(self):
        fmtstr = '{name}:{avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)

@ -0,0 +1,99 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Quality Assessment loss."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore import Tensor
eps = 1e-24


class CEWithIgnoreIndex3D(_Loss):
    '''CEWithIgnoreIndex3D'''
    def __init__(self):
        super(CEWithIgnoreIndex3D, self).__init__()

        self.exp = P.Exp()
        self.sum = P.ReduceSum()
        self.reshape = P.Reshape()
        self.log = P.Log()
        self.cast = P.Cast()
        self.eps_const = Tensor(eps, dtype=mstype.float32)
        self.ones = P.OnesLike()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.relu = P.ReLU()
        self.maximum = P.Maximum()
        self.resum = P.ReduceSum(keep_dims=False)

    def construct(self, logit, label):
        '''construct'''
        # mask is 1 where label >= 0 and 0 where label is negative
        mask = self.reshape(label, (F.shape(label)[0], F.shape(label)[1], 1))
        mask = self.cast(mask, mstype.float32)
        mask = mask + F.scalar_to_array(0.00001)
        mask = self.relu(mask) / (mask)
        logit = logit * mask

        # log-softmax over the last axis, with eps guarding log(0)
        exp = self.exp(logit)
        exp_sum = self.sum(exp, -1)
        exp_sum = self.reshape(exp_sum, (F.shape(exp_sum)[0], F.shape(exp_sum)[1], 1))
        softmax_result = self.log(exp / exp_sum + self.eps_const)

        one_hot_label = self.onehot(
            self.cast(label, mstype.int32), F.shape(logit)[2], self.on_value, self.off_value)
        loss = (softmax_result * self.cast(one_hot_label, mstype.float32) * self.cast(F.scalar_to_array(-1),
                                                                                       mstype.float32))

        loss = self.sum(loss, -1)
        loss = self.sum(loss, -1)
        loss = self.sum(loss, 0)

        return loss


class CriterionsFaceQA(nn.Cell):
    '''CriterionsFaceQA'''
    def __init__(self):
        super(CriterionsFaceQA, self).__init__()
        self.gatherv2 = P.GatherV2()
        self.squeeze = P.Squeeze(axis=1)
        self.shape = P.Shape()
        self.reshape = P.Reshape()

        self.euler_label_list = Tensor([0, 1, 2], dtype=mstype.int32)
        self.mse_loss = nn.MSELoss(reduction='sum')

        self.kp_label_list = Tensor([3, 4, 5, 6, 7], dtype=mstype.int32)
        self.kps_loss = CEWithIgnoreIndex3D()

    def construct(self, x1, x2, label):
        '''construct'''
        # euler
        euler_label = self.gatherv2(label, self.euler_label_list, 1)
        loss_euler = self.mse_loss(x1, euler_label)

        # key points
        b, _, _, _ = self.shape(x2)
        x2 = self.reshape(x2, (b, 5, 48 * 48))

        kps_label = self.gatherv2(label, self.kp_label_list, 1)
        loss_kps = self.kps_loss(x2, kps_label)

        loss_tot = (loss_kps + loss_euler) / b
        return loss_tot
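
The masking arithmetic in `CEWithIgnoreIndex3D` is easier to follow on plain numbers. The following pure-Python sketch mirrors it for a single sample: `relu(label + 1e-5) / (label + 1e-5)` is 1 for non-negative labels and 0 for negative ones, and a negative label never matches any one-hot position, so that keypoint contributes nothing to the sum. Reading negative labels as "ignore this keypoint" is an inference from the arithmetic, and the function name is hypothetical.

```python
import math


def masked_keypoint_ce(logits, labels, eps=1e-24):
    """Summed cross-entropy over keypoints for one sample.

    logits: list of K rows, each with C scores (flattened heatmap positions).
    labels: list of K integer positions; a negative value is treated as ignored.
    """
    total = 0.0
    for row, label in zip(logits, labels):
        shifted = float(label) + 1e-5
        mask = max(shifted, 0.0) / shifted      # 1.0 if label >= 0, else 0.0
        exps = [math.exp(score * mask) for score in row]
        denom = sum(exps)
        for position, value in enumerate(exps):
            if position == label:               # one-hot; never true for a negative label
                total -= math.log(value / denom + eps)
    return total


if __name__ == "__main__":
    # First keypoint is labelled at position 1, second keypoint is ignored.
    print(masked_keypoint_ce([[0.0, 0.0, 0.0, 0.0], [5.0, 1.0, 2.0, 3.0]], [1, -1]))
```

With uniform scores over four positions, the labelled keypoint contributes `log(4)` and the ignored one contributes zero.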

@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Quality Assessment learning rate scheduler."""
from collections import Counter


def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    learning_rate = float(init_lr) + lr_inc * current_step
    return learning_rate


def warmup_step(args, gamma=0.1):
    '''warmup_step'''
    base_lr = args.lr
    warmup_init_lr = 0
    total_steps = int(args.max_epoch * args.steps_per_epoch)
    warmup_steps = int(args.warmup_epochs * args.steps_per_epoch)
    milestones = args.lr_epochs
    milestones_steps = []
    for milestone in milestones:
        milestones_step = milestone*args.steps_per_epoch
        milestones_steps.append(milestones_step)

    lr = base_lr
    milestones_steps_counter = Counter(milestones_steps)
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_learning_rate(i, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = lr * gamma**milestones_steps_counter[i]
        yield lr
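
To see what schedule `warmup_step` produces, the sketch below re-implements the same loop as a list-returning function and runs it on a tiny configuration. All configuration values are hypothetical and chosen only to keep the output short; `warmup_step_list` is an illustrative name, not part of this commit.

```python
from collections import Counter
from types import SimpleNamespace


def warmup_step_list(args, gamma=0.1):
    """Materialise the warmup-then-step-decay schedule into a list."""
    total_steps = int(args.max_epoch * args.steps_per_epoch)
    warmup_steps = int(args.warmup_epochs * args.steps_per_epoch)
    milestones = Counter(m * args.steps_per_epoch for m in args.lr_epochs)
    lr = args.lr
    schedule = []
    for i in range(total_steps):
        if i < warmup_steps:
            # Linear ramp starting from 0, same arithmetic as the helper above.
            lr = (float(args.lr) / float(warmup_steps)) * i
        else:
            # Multiply by gamma once for every milestone that falls on this step.
            lr = lr * gamma ** milestones[i]
        schedule.append(lr)
    return schedule


if __name__ == "__main__":
    cfg = SimpleNamespace(lr=0.1, max_epoch=4, steps_per_epoch=2,
                          warmup_epochs=1, lr_epochs=[2, 3])  # hypothetical values
    print(warmup_step_list(cfg, gamma=0.5))
```

One detail this makes visible: the ramp ends at `lr * (warmup_steps - 1) / warmup_steps`, and the decay phase continues from that value, so the schedule never quite reaches the configured base rate. With many warmup steps the gap is negligible.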

@ -0,0 +1,187 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Quality Assessment train."""
import os
import time
import datetime
import argparse
import warnings
import numpy as np
import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Momentum
from mindspore.communication.management import get_group_size, init, get_rank
from src.loss import CriterionsFaceQA
from src.config import faceqa_1p_cfg, faceqa_8p_cfg
from src.face_qa import FaceQABackbone, BuildTrainNetwork
from src.lr_generator import warmup_step
from src.dataset import faceqa_dataset
from src.log import get_logger, AverageMeter
warnings.filterwarnings('ignore')
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
mindspore.common.seed.set_seed(1)


def main(args):
    if args.is_distributed == 0:
        cfg = faceqa_1p_cfg
    else:
        cfg = faceqa_8p_cfg

    cfg.data_lst = args.train_label_file
    cfg.pretrained = args.pretrained

    # Init distributed
    if args.is_distributed:
        init()
        cfg.local_rank = get_rank()
        cfg.world_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    # parallel_mode 'STAND_ALONE' does not support parameter_broadcast and mirror_mean
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.world_size,
                                      gradients_mean=True)

    mindspore.common.set_seed(1)

    # logger
    cfg.outputs_dir = os.path.join(cfg.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
    loss_meter = AverageMeter('loss')

    # Dataloader
    cfg.logger.info('start create dataloader')
    de_dataset = faceqa_dataset(imlist=cfg.data_lst, local_rank=cfg.local_rank, world_size=cfg.world_size,
                                per_batch_size=cfg.per_batch_size)
    cfg.steps_per_epoch = de_dataset.get_dataset_size()
    de_dataset = de_dataset.repeat(cfg.max_epoch)
    de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)

    # Show cfg
    cfg.logger.save_args(cfg)
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()

    network = FaceQABackbone()
    criterion = CriterionsFaceQA()

    # load pretrain model
    if os.path.isfile(cfg.pretrained):
        param_dict = load_checkpoint(cfg.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        cfg.logger.info('load model {} success'.format(cfg.pretrained))

    # optimizer and lr scheduler
    lr = warmup_step(cfg, gamma=0.9)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=cfg.momentum,
                   weight_decay=cfg.weight_decay,
                   loss_scale=cfg.loss_scale)

    # package training process, adjust lr + forward + backward + optimizer
    train_net = BuildTrainNetwork(network, criterion)
    train_net = TrainOneStepCell(train_net, opt, sens=cfg.loss_scale)

    # checkpoint save
    if cfg.local_rank == 0:
        ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
        train_config = CheckpointConfig(save_checkpoint_steps=cfg.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=cfg.outputs_dir, prefix='{}'.format(cfg.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    cfg.logger.important_info('====start train====')
    for i, (data, gt) in enumerate(de_dataloader):
        # clean grad + adjust lr + put data into device + forward + backward + optimizer, return loss
        data = data.astype(np.float32)
        gt = gt.astype(np.float32)
        data = Tensor(data)
        gt = Tensor(gt)

        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt
        if cfg.local_rank == 0:
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and cfg.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * (i - old_progress) * cfg.world_size / time_used
            cfg.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
            cfg.logger.info('=================================================')
            cfg.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            cfg.logger.info('=================================================')
            t_epoch = time.time()

    cfg.logger.important_info('====train end====')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Face Quality Assessment')
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--train_label_file', type=str, default='', help='image label list file, e.g. /home/label.txt')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')

    arg = parser.parse_args()

    main(arg)
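
The pretrained-checkpoint handling in `main` renames parameters before loading them: entries starting with `moments.` (optimizer state) are dropped, and a leading `network.` is stripped so the names match the bare backbone. A standalone sketch of that remapping on a plain dictionary follows. The function name and the sample keys are hypothetical, and strings stand in for the real parameter objects.

```python
def strip_checkpoint_prefix(param_dict, prefix='network.', skip_prefix='moments.'):
    """Drop optimizer state and remove the wrapper prefix from parameter names."""
    remapped = {}
    for key, value in param_dict.items():
        if key.startswith(skip_prefix):
            continue  # optimizer moments are not loaded into the backbone
        if key.startswith(prefix):
            remapped[key[len(prefix):]] = value
        else:
            remapped[key] = value
    return remapped


if __name__ == "__main__":
    sample = {
        'network.conv0.weight': 'w0',   # hypothetical key saved from a wrapped network
        'moments.conv0.weight': 'm0',   # hypothetical optimizer state
        'bn0.gamma': 'g0',              # already unprefixed
    }
    print(strip_checkpoint_prefix(sample))
```

Using `len(prefix)` instead of the literal `8` keeps the slice correct if the prefix ever changes.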