forked from mindspore-Ecosystem/mindspore
Add FaceRecognitionForTracking net to /model_zoo/research/cv
This commit is contained in:
parent
6cf308076d
commit
6b3663d123
|
@ -0,0 +1,228 @@
|
|||
# Contents
|
||||
|
||||
- [Face Recognition For Tracking Description](#face-recognition-for-tracking-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Running Example](#running-example)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [Face Recognition For Tracking Description](#contents)
|
||||
|
||||
This is a face recognition for tracking network based on Resnet, with support for training and evaluation on Ascend910.
|
||||
|
||||
ResNet (residual neural network) was proposed by Kaiming He and other four Chinese of Microsoft Research Institute. Through the use of ResNet unit, it successfully trained 152 layers of neural network, and won the championship in ilsvrc2015. The error rate on top 5 was 3.57%, and the parameter quantity was lower than vggnet, so the effect was very outstanding. Traditional convolution network or full connection network will have more or less information loss. At the same time, it will lead to the disappearance or explosion of gradient, which leads to the failure of deep network training. ResNet solves this problem to a certain extent. By passing the input information to the output, the integrity of the information is protected. The whole network only needs to learn the part of the difference between input and output, which simplifies the learning objectives and difficulties.The structure of ResNet can accelerate the training of neural network very quickly, and the accuracy of the model is also greatly improved. At the same time, ResNet is very popular, even can be directly used in the concept net network.
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1512.03385.pdf): Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition"
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
Face Recognition For Tracking uses a Resnet network for performing feature extraction.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
We use about 10K face images as training dataset and 2K as evaluating dataset in this example, and you can also use your own datasets or open source datasets (e.g. Labeled Faces in the Wild)
|
||||
The directory structure is as follows:
|
||||
|
||||
```python
|
||||
.
|
||||
└─ dataset
|
||||
├─ train dataset
|
||||
├─ ID1
|
||||
├─ ID1_0001.jpg
|
||||
├─ ID1_0002.jpg
|
||||
...
|
||||
├─ ID2
|
||||
...
|
||||
├─ ID3
|
||||
...
|
||||
...
|
||||
├─ test dataset
|
||||
├─ ID1
|
||||
├─ ID1_0001.jpg
|
||||
├─ ID1_0002.jpg
|
||||
...
|
||||
├─ ID2
|
||||
...
|
||||
├─ ID3
|
||||
...
|
||||
...
|
||||
```
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend)
|
||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||
- Framework
|
||||
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
The entire code structure is as following:
|
||||
|
||||
```python
|
||||
.
|
||||
└─ Face Recognition For Tracking
|
||||
├─ README.md
|
||||
├─ scripts
|
||||
├─ run_standalone_train.sh # launch standalone training(1p) in ascend
|
||||
├─ run_distribute_train.sh # launch distributed training(8p) in ascend
|
||||
├─ run_eval.sh # launch evaluating in ascend
|
||||
└─ run_export.sh # launch exporting air model
|
||||
├─ src
|
||||
├─ config.py # parameter configuration
|
||||
├─ dataset.py # dataset loading and preprocessing for training
|
||||
├─ reid.py # network backbone
|
||||
├─ reid_for_export.py # network backbone for export
|
||||
├─ log.py # log function
|
||||
├─ loss.py # loss function
|
||||
├─ lr_generator.py # generate learning rate
|
||||
└─ me_init.py # network initialization
|
||||
├─ train.py # training scripts
|
||||
├─ eval.py # evaluation scripts
|
||||
└─ export.py # export air model
|
||||
```
|
||||
|
||||
## [Running Example](#contents)
|
||||
|
||||
### Train
|
||||
|
||||
- Stand alone mode
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh [DATA_DIR] [USE_DEVICE_ID]
|
||||
```
|
||||
|
||||
or (fine-tune)
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh [DATA_DIR] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
|
||||
```
|
||||
|
||||
for example:
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh /home/train_dataset 0 /home/a.ckpt
|
||||
```
|
||||
|
||||
- Distribute mode (recommended)
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_distribute_train.sh [DATA_DIR] [RANK_TABLE]
|
||||
```
|
||||
|
||||
or (fine-tune)
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_distribute_train.sh [DATA_DIR] [RANK_TABLE] [PRETRAINED_BACKBONE]
|
||||
```
|
||||
|
||||
for example:
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_distribute_train.sh /home/train_dataset ./rank_table_8p.json /home/a.ckpt
|
||||
```
|
||||
|
||||
You will get the loss value of each step as following in "./output/[TIME]/[TIME].log" or "./scripts/device0/train.log":
|
||||
|
||||
```python
|
||||
epoch[0], iter[10], loss:43.314265, 8574.83 imgs/sec, lr=0.800000011920929
|
||||
epoch[0], iter[20], loss:45.121095, 8915.66 imgs/sec, lr=0.800000011920929
|
||||
epoch[0], iter[30], loss:42.342847, 9162.85 imgs/sec, lr=0.800000011920929
|
||||
epoch[0], iter[40], loss:39.456583, 9178.83 imgs/sec, lr=0.800000011920929
|
||||
|
||||
...
|
||||
epoch[179], iter[14900], loss:1.651353, 13001.25 imgs/sec, lr=0.02500000037252903
|
||||
epoch[179], iter[14910], loss:1.532123, 12669.85 imgs/sec, lr=0.02500000037252903
|
||||
epoch[179], iter[14920], loss:1.760322, 13457.81 imgs/sec, lr=0.02500000037252903
|
||||
epoch[179], iter[14930], loss:1.694281, 13417.38 imgs/sec, lr=0.02500000037252903
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_eval.sh [EVAL_DIR] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
|
||||
```
|
||||
|
||||
for example:
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_eval.sh /home/test_dataset 0 /home/a.ckpt
|
||||
```
|
||||
|
||||
You will get the result as following in "./scripts/device0/eval.log" or txt file in [PRETRAINED_BACKBONE]'s folder:
|
||||
|
||||
```python
|
||||
0.5: 0.9273788254649683@0.020893691253149882
|
||||
0.3: 0.8393850978779193@0.07438552515516506
|
||||
0.1: 0.6220871197028316@0.1523084478903911
|
||||
0.01: 0.2683641598437038@0.26217882879427634
|
||||
0.001: 0.11060269148211463@0.34509718987101223
|
||||
0.0001: 0.05381678898728808@0.4187797093636618
|
||||
1e-05: 0.035770748447963394@0.5053771466191392
|
||||
```
|
||||
|
||||
### Convert model
|
||||
|
||||
If you want to infer the network on Ascend 310, you should convert the model to AIR:
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Training Performance
|
||||
|
||||
| Parameters | Face Recognition For Tracking |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
|
||||
| uploaded Date | 09/30/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | 10K images |
|
||||
| Training Parameters | epoch=180, batch_size=16, momentum=0.9 |
|
||||
| Optimizer | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| outputs | probability |
|
||||
| Speed | 1pc: 8~10 ms/step; 8pcs: 9~11 ms/step |
|
||||
| Total time | 1pc: 1 hours; 8pcs: 0.1 hours |
|
||||
| Checkpoint for Fine tuning | 17M (.ckpt file) |
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters |Face Recognition For Tracking|
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 09/30/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | 2K images |
|
||||
| batch_size | 128 |
|
||||
| outputs | recall |
|
||||
| Recall(8pcs) | 0.62(FAR=0.1) |
|
||||
| Model for inference | 17M (.ckpt file) |
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,185 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face Recognition eval."""
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
import argparse
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import mindspore.dataset.vision.py_transforms as V
|
||||
import mindspore.dataset.transforms.py_transforms as T
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.reid import SphereNet
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
|
||||
|
||||
|
||||
def inclass_likehood(ims_info, types='cos'):
|
||||
'''Inclass likehood.'''
|
||||
obj_feas = {}
|
||||
likehoods = []
|
||||
for name, _, fea in ims_info:
|
||||
if re.split('_\\d\\d\\d\\d', name)[0] not in obj_feas:
|
||||
obj_feas[re.split('_\\d\\d\\d\\d', name)[0]] = []
|
||||
obj_feas[re.split('_\\d\\d\\d\\d', name)[0]].append(fea) # pylint: "_\d\d\d\d" -> "_\\d\\d\\d\\d"
|
||||
for _, feas in tqdm(obj_feas.items()):
|
||||
feas = np.array(feas)
|
||||
if types == 'cos':
|
||||
likehood_mat = np.dot(feas, np.transpose(feas)).tolist()
|
||||
for row in likehood_mat:
|
||||
likehoods += row
|
||||
else:
|
||||
for fea in feas.tolist():
|
||||
likehoods += np.sum(-(fea - feas) ** 2, axis=1).tolist()
|
||||
|
||||
likehoods = np.array(likehoods)
|
||||
return likehoods
|
||||
|
||||
|
||||
def btclass_likehood(ims_info, types='cos'):
|
||||
'''Btclass likehood.'''
|
||||
likehoods = []
|
||||
count = 0
|
||||
for name1, _, fea1 in tqdm(ims_info):
|
||||
count += 1
|
||||
# pylint: "_\d\d\d\d" -> "_\\d\\d\\d\\d"
|
||||
frame_id1, _ = re.split('_\\d\\d\\d\\d', name1)[0], name1.split('_')[-1]
|
||||
fea1 = np.array(fea1)
|
||||
for name2, _, fea2 in ims_info:
|
||||
# pylint: "_\d\d\d\d" -> "_\\d\\d\\d\\d"
|
||||
frame_id2, _ = re.split('_\\d\\d\\d\\d', name2)[0], name2.split('_')[-1]
|
||||
if frame_id1 == frame_id2:
|
||||
continue
|
||||
fea2 = np.array(fea2)
|
||||
if types == 'cos':
|
||||
likehoods.append(np.sum(fea1 * fea2))
|
||||
else:
|
||||
likehoods.append(np.sum(-(fea1 - fea2) ** 2))
|
||||
|
||||
likehoods = np.array(likehoods)
|
||||
return likehoods
|
||||
|
||||
|
||||
def tar_at_far(inlikehoods, btlikehoods):
|
||||
test_point = [0.5, 0.3, 0.1, 0.01, 0.001, 0.0001, 0.00001]
|
||||
tar_far = []
|
||||
for point in test_point:
|
||||
thre = btlikehoods[int(btlikehoods.size * point)]
|
||||
n_ta = np.sum(inlikehoods > thre)
|
||||
tar_far.append((point, float(n_ta) / inlikehoods.size, thre))
|
||||
|
||||
return tar_far
|
||||
|
||||
|
||||
def load_images(paths, batch_size=128):
|
||||
'''Load images.'''
|
||||
ll = []
|
||||
resize = V.Resize((96, 64))
|
||||
transform = T.Compose([
|
||||
V.ToTensor(),
|
||||
V.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
||||
for i, _ in enumerate(paths):
|
||||
im = Image.open(paths[i])
|
||||
im = resize(im)
|
||||
img = np.array(im)
|
||||
ts = transform(img)
|
||||
ll.append(ts[0])
|
||||
if len(ll) == batch_size:
|
||||
yield np.stack(ll, axis=0)
|
||||
ll.clear()
|
||||
if ll:
|
||||
yield np.stack(ll, axis=0)
|
||||
|
||||
|
||||
def main(args):
|
||||
model_path = args.pretrained
|
||||
result_file = model_path.replace('.ckpt', '.txt')
|
||||
if os.path.exists(result_file):
|
||||
os.remove(result_file)
|
||||
|
||||
with open(result_file, 'a+') as result_fw:
|
||||
result_fw.write(model_path + '\n')
|
||||
|
||||
network = SphereNet(num_layers=12, feature_dim=128, shape=(96, 64))
|
||||
if os.path.isfile(model_path):
|
||||
param_dict = load_checkpoint(model_path)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith('model.'):
|
||||
param_dict_new[key[6:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
print('-----------------------load model success-----------------------')
|
||||
else:
|
||||
print('-----------------------load model failed -----------------------')
|
||||
|
||||
network.add_flags_recursive(fp16=True)
|
||||
network.set_train(False)
|
||||
|
||||
root_path = args.eval_dir
|
||||
root_file_list = os.listdir(root_path)
|
||||
ims_info = []
|
||||
for sub_path in root_file_list:
|
||||
for im_path in os.listdir(os.path.join(root_path, sub_path)):
|
||||
ims_info.append((im_path.split('.')[0], os.path.join(root_path, sub_path, im_path)))
|
||||
|
||||
paths = [path for name, path in ims_info]
|
||||
names = [name for name, path in ims_info]
|
||||
print("exact feature...")
|
||||
|
||||
l_t = []
|
||||
for batch in load_images(paths):
|
||||
batch = batch.astype(np.float32)
|
||||
batch = Tensor(batch)
|
||||
fea = network(batch)
|
||||
l_t.append(fea.asnumpy().astype(np.float16))
|
||||
feas = np.concatenate(l_t, axis=0)
|
||||
ims_info = list(zip(names, paths, feas.tolist()))
|
||||
|
||||
print("exact inclass likehood...")
|
||||
inlikehoods = inclass_likehood(ims_info)
|
||||
inlikehoods[::-1].sort()
|
||||
|
||||
print("exact btclass likehood...")
|
||||
btlikehoods = btclass_likehood(ims_info)
|
||||
btlikehoods[::-1].sort()
|
||||
tar_far = tar_at_far(inlikehoods, btlikehoods)
|
||||
|
||||
for far, tar, thre in tar_far:
|
||||
print('---{}: {}@{}'.format(far, tar, thre))
|
||||
|
||||
for far, tar, thre in tar_far:
|
||||
result_fw.write('{}: {}@{} \n'.format(far, tar, thre))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='reid test')
|
||||
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
|
||||
parser.add_argument('--eval_dir', type=str, default='', help='eval image dir, e.g. /home/test')
|
||||
|
||||
arg = parser.parse_args()
|
||||
print(arg)
|
||||
|
||||
main(arg)
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert ckpt to air."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
|
||||
|
||||
from src.reid_for_export import SphereNet
|
||||
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
|
||||
|
||||
|
||||
def main(args):
|
||||
network = SphereNet(num_layers=12, feature_dim=128, shape=(96, 64))
|
||||
ckpt_path = args.pretrained
|
||||
if os.path.isfile(ckpt_path):
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith('model.'):
|
||||
param_dict_new[key[6:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
print('-----------------------load model success-----------------------')
|
||||
else:
|
||||
print('-----------------------load model failed -----------------------')
|
||||
|
||||
network.add_flags_recursive(fp16=True)
|
||||
network.set_train(False)
|
||||
|
||||
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 96, 64)).astype(np.float32)
|
||||
tensor_input_data = Tensor(input_data)
|
||||
|
||||
export(network, tensor_input_data, file_name=ckpt_path.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'),
|
||||
file_format='AIR')
|
||||
print('-----------------------export model success-----------------------')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description='Convert ckpt to air')
|
||||
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
|
||||
arg = parser.parse_args()
|
||||
|
||||
main(arg)
|
|
@ -0,0 +1,88 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 -a $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [DATA_DIR] [RANK_TABLE] [PRETRAINED_BACKBONE]"
|
||||
echo " or: sh run_distribute_train.sh [DATA_DIR] [RANK_TABLE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname $(pwd))
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
SCRIPT_NAME='train.py'
|
||||
|
||||
rm -rf ${current_exec_path}/device*
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
DATA_DIR=$(get_real_path $1)
|
||||
RANK_TABLE=$(get_real_path $2)
|
||||
PRETRAINED_BACKBONE=''
|
||||
|
||||
if [ ! -d $DATA_DIR ]
|
||||
then
|
||||
echo "error: DATA_DIR=$DATA_DIR is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo $DATA_DIR
|
||||
echo $RANK_TABLE
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
export RANK_TABLE_FILE=$RANK_TABLE
|
||||
export RANK_SIZE=8
|
||||
|
||||
echo 'start training'
|
||||
for((i=0;i<=$RANK_SIZE-1;i++));
|
||||
do
|
||||
echo 'start rank '$i
|
||||
mkdir ${current_exec_path}/device$i
|
||||
cd ${current_exec_path}/device$i
|
||||
export RANK_ID=$i
|
||||
dev=`expr $i + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--is_distributed=1 \
|
||||
--data_dir=$DATA_DIR \
|
||||
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
|
||||
done
|
||||
|
||||
echo 'running'
|
|
@ -0,0 +1,71 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [EVAL_DIR] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname $(pwd))
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
export RANK_SIZE=1
|
||||
|
||||
SCRIPT_NAME='eval.py'
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
EVAL_DIR=$(get_real_path $1)
|
||||
USE_DEVICE_ID=$2
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo $EVAL_DIR
|
||||
echo $USE_DEVICE_ID
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
echo 'start evaluating'
|
||||
export RANK_ID=0
|
||||
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||
echo 'start device '$USE_DEVICE_ID
|
||||
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||
cd ${current_exec_path}/device$USE_DEVICE_ID
|
||||
dev=`expr $USE_DEVICE_ID + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--eval_dir=$EVAL_DIR \
|
||||
--pretrained=$PRETRAINED_BACKBONE > eval.log 2>&1 &
|
||||
|
||||
echo 'running'
|
|
@ -0,0 +1,71 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname $(pwd))
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
export RANK_SIZE=1
|
||||
|
||||
SCRIPT_NAME='export.py'
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
BATCH_SIZE=$1
|
||||
USE_DEVICE_ID=$2
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo $BATCH_SIZE
|
||||
echo $USE_DEVICE_ID
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
echo 'start converting'
|
||||
export RANK_ID=0
|
||||
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||
echo 'start device '$USE_DEVICE_ID
|
||||
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||
cd ${current_exec_path}/device$USE_DEVICE_ID
|
||||
dev=`expr $USE_DEVICE_ID + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--batch_size=$BATCH_SIZE \
|
||||
--pretrained=$PRETRAINED_BACKBONE > convert.log 2>&1 &
|
||||
|
||||
echo 'running'
|
|
@ -0,0 +1,83 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 -a $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train.sh [DATA_DIR] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||
echo " or: sh run_standalone_train.sh [DATA_DIR] [USE_DEVICE_ID]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname $(pwd))
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
export RANK_SIZE=1
|
||||
|
||||
SCRIPT_NAME='train.py'
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
DATA_DIR=$(get_real_path $1)
|
||||
USE_DEVICE_ID=$2
|
||||
PRETRAINED_BACKBONE=''
|
||||
|
||||
if [ ! -d $DATA_DIR ]
|
||||
then
|
||||
echo "error: DATA_DIR=$DATA_DIR is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo $DATA_DIR
|
||||
echo $USE_DEVICE_ID
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
echo 'start training'
|
||||
export RANK_ID=0
|
||||
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||
echo 'start device '$USE_DEVICE_ID
|
||||
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||
cd ${current_exec_path}/device$USE_DEVICE_ID
|
||||
dev=`expr $USE_DEVICE_ID + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--is_distributed=0 \
|
||||
--data_dir=$DATA_DIR \
|
||||
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
|
||||
|
||||
echo 'running'
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Network config setting, will be used in train.py and eval.py"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
reid_1p_cfg = edict({
|
||||
'task': 'REID_1p',
|
||||
|
||||
# dataset related
|
||||
'per_batch_size': 128,
|
||||
|
||||
# network structure related
|
||||
'fp16': 1,
|
||||
'loss_scale': 2048.0,
|
||||
'input_size': (96, 64),
|
||||
'net_depth': 12,
|
||||
'embedding_size': 128,
|
||||
|
||||
# optimizer related
|
||||
'lr': 0.1,
|
||||
'lr_scale': 1,
|
||||
'lr_gamma': 1,
|
||||
'lr_epochs': '30,60,120,150',
|
||||
'epoch_size': 30,
|
||||
'warmup_epochs': 0,
|
||||
'steps_per_epoch': 0,
|
||||
'max_epoch': 180,
|
||||
'weight_decay': 0.0005,
|
||||
'momentum': 0.9,
|
||||
|
||||
# distributed parameter
|
||||
'is_distributed': 0,
|
||||
'local_rank': 0,
|
||||
'world_size': 1,
|
||||
|
||||
# logging related
|
||||
'log_interval': 10,
|
||||
'ckpt_path': '../../output',
|
||||
'ckpt_interval': 200,
|
||||
})
|
||||
|
||||
|
||||
reid_8p_cfg = edict({
|
||||
'task': 'REID_8p',
|
||||
|
||||
# dataset related
|
||||
'per_batch_size': 16,
|
||||
|
||||
# network structure related
|
||||
'fp16': 1,
|
||||
'loss_scale': 2048.0,
|
||||
'input_size': (96, 64),
|
||||
'net_depth': 12,
|
||||
'embedding_size': 128,
|
||||
|
||||
# optimizer related
|
||||
'lr': 0.8, # 0.8
|
||||
'lr_scale': 1,
|
||||
'lr_gamma': 0.5,
|
||||
'lr_epochs': '30,60,120,150',
|
||||
'epoch_size': 30,
|
||||
'warmup_epochs': 0,
|
||||
'steps_per_epoch': 0,
|
||||
'max_epoch': 180,
|
||||
'weight_decay': 0.0005,
|
||||
'momentum': 0.9,
|
||||
|
||||
# distributed parameter
|
||||
'is_distributed': 1,
|
||||
'local_rank': 0,
|
||||
'world_size': 8,
|
||||
|
||||
# logging related
|
||||
'log_interval': 10,
|
||||
'ckpt_path': '../../output',
|
||||
'ckpt_interval': 200,
|
||||
})
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face Recognition dataset."""
|
||||
import sys
|
||||
import warnings
|
||||
from PIL import ImageFile
|
||||
|
||||
from mindspore import dtype as mstype
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as VC
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
|
||||
sys.path.append('./')
|
||||
sys.path.append('../data/')
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
def get_de_dataset(args):
|
||||
'''Get de_dataset.'''
|
||||
transform_label = [C.TypeCast(mstype.int32)]
|
||||
transform_img = [VC.Decode(),
|
||||
VC.Resize((96, 64)),
|
||||
VC.RandomColorAdjust(brightness=0.3, contrast=0.3, saturation=0.3, hue=0),
|
||||
VC.RandomHorizontalFlip(),
|
||||
VC.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5)),
|
||||
VC.HWC2CHW()]
|
||||
|
||||
de_dataset = de.ImageFolderDataset(dataset_dir=args.data_dir, num_shards=args.world_size,
|
||||
shard_id=args.local_rank, shuffle=True)
|
||||
de_dataset = de_dataset.map(input_columns="image", operations=transform_img)
|
||||
de_dataset = de_dataset.map(input_columns="label", operations=transform_label)
|
||||
de_dataset = de_dataset.project(columns=["image", "label"])
|
||||
de_dataset = de_dataset.batch(args.per_batch_size, drop_remainder=True)
|
||||
|
||||
num_iter_per_gpu = de_dataset.get_dataset_size()
|
||||
de_dataset = de_dataset.repeat(args.max_epoch)
|
||||
num_classes = de_dataset.num_classes()
|
||||
|
||||
return de_dataset, num_iter_per_gpu, num_classes
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Custom logger."""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
logger_name_1 = 'REID'
|
||||
|
||||
|
||||
class LOGGER(logging.Logger):
|
||||
'''LOGGER'''
|
||||
def __init__(self, logger_name):
|
||||
super(LOGGER, self).__init__(logger_name)
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
self.local_rank = 0
|
||||
|
||||
def setup_logging_file(self, log_dir, local_rank=0):
|
||||
'''Setup logging file.'''
|
||||
self.local_rank = local_rank
|
||||
if self.local_rank == 0:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '.log'
|
||||
self.log_fn = os.path.join(log_dir, log_name)
|
||||
fh = logging.FileHandler(self.log_fn)
|
||||
fh.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
fh.setFormatter(formatter)
|
||||
self.addHandler(fh)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
|
||||
self._log(logging.INFO, msg, args, **kwargs)
|
||||
|
||||
def save_args(self, args):
|
||||
self.info('Args:')
|
||||
args_dict = vars(args)
|
||||
for key in args_dict.keys():
|
||||
self.info('--> %s: %s', key, args_dict[key])
|
||||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
self.info(important_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path, rank):
|
||||
logger = LOGGER(logger_name_1)
|
||||
logger.setup_logging_file(path, rank)
|
||||
return logger
|
||||
|
||||
|
||||
class AverageMeter():
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self, name, fmt=':f', tb_writer=None):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
self.tb_writer = tb_writer
|
||||
self.cur_step = 1
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
|
||||
self.cur_step += 1
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '{name}:{avg' + self.fmt + '}'
|
||||
return fmtstr.format(**self.__dict__)
|
|
@ -0,0 +1,149 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face Recognition loss."""
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
|
||||
eps = 1e-24
|
||||
|
||||
|
||||
class CrossEntropyNew(_Loss):
|
||||
'''CrossEntropyNew'''
|
||||
def __init__(self, smooth_factor=0., num_classes=1000):
|
||||
super(CrossEntropyNew, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
|
||||
def construct(self, logit, label):
|
||||
one_hot_label = self.onehot(self.cast(label, mstype.int32),
|
||||
F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(logit, one_hot_label)
|
||||
loss = self.mean(loss, 0)
|
||||
return loss
|
||||
|
||||
|
||||
class CrossEntropy(_Loss):
|
||||
'''CrossEntropy'''
|
||||
def __init__(self):
|
||||
super(CrossEntropy, self).__init__()
|
||||
|
||||
self.exp = P.Exp()
|
||||
self.sum = P.ReduceSum()
|
||||
self.reshape = P.Reshape()
|
||||
self.log = P.Log()
|
||||
self.cast = P.Cast()
|
||||
self.eps_const = Tensor(eps, dtype=mstype.float32)
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
|
||||
def construct(self, logit, label):
|
||||
'''Construct function.'''
|
||||
exp = self.exp(logit)
|
||||
exp_sum = self.sum(exp, -1)
|
||||
exp_sum = self.reshape(exp_sum, (F.shape(exp_sum)[0], 1))
|
||||
softmax_result = exp / exp_sum
|
||||
one_hot_label = self.onehot(
|
||||
self.cast(label, mstype.int32), F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss = self.sum((self.log(softmax_result + self.eps_const) * self.cast(
|
||||
one_hot_label, mstype.float32) * self.cast(F.scalar_to_array(-1), mstype.float32)), -1)
|
||||
batch_size = F.shape(logit)[0]
|
||||
batch_size_tensor = self.cast(
|
||||
F.scalar_to_array(batch_size), mstype.float32)
|
||||
loss = self.sum(loss, -1) / batch_size_tensor
|
||||
return loss
|
||||
|
||||
|
||||
class CrossEntropyWithIgnoreIndex(nn.Cell):
|
||||
'''CrossEntropyWithIgnoreIndex'''
|
||||
def __init__(self):
|
||||
super(CrossEntropyWithIgnoreIndex, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.greater = P.Greater()
|
||||
self.maximum = P.Maximum()
|
||||
self.fill = P.Fill()
|
||||
self.sum = P.ReduceSum(keep_dims=False)
|
||||
self.dtype = P.DType()
|
||||
self.relu = P.ReLU()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, x, label):
|
||||
mask = self.reshape(label, (F.shape(label)[0], 1))
|
||||
mask = self.cast(mask, mstype.float32)
|
||||
mask = mask + F.scalar_to_array(0.00001)
|
||||
mask = self.relu(mask) / (mask)
|
||||
x = x * mask
|
||||
one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(x)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(x, one_hot_label)
|
||||
loss = self.sum(loss, 0)
|
||||
return loss
|
||||
|
||||
|
||||
eps = 1e-24
|
||||
|
||||
|
||||
class CEWithIgnoreIndex3D(_Loss):
|
||||
'''CEWithIgnoreIndex3D'''
|
||||
def __init__(self):
|
||||
super(CEWithIgnoreIndex3D, self).__init__()
|
||||
|
||||
self.exp = P.Exp()
|
||||
self.sum = P.ReduceSum()
|
||||
self.reshape = P.Reshape()
|
||||
self.log = P.Log()
|
||||
self.cast = P.Cast()
|
||||
self.eps_const = Tensor(eps, dtype=mstype.float32)
|
||||
self.ones = P.OnesLike()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.relu = P.ReLU()
|
||||
self.resum = P.ReduceSum(keep_dims=False)
|
||||
|
||||
def construct(self, logit, label):
|
||||
'''Construct function.'''
|
||||
mask = self.reshape(label, (F.shape(label)[0], F.shape(label)[1], 1))
|
||||
mask = self.cast(mask, mstype.float32)
|
||||
mask = mask + F.scalar_to_array(0.00001)
|
||||
mask = self.relu(mask) / (mask)
|
||||
logit = logit * mask
|
||||
|
||||
exp = self.exp(logit)
|
||||
exp_sum = self.sum(exp, -1)
|
||||
exp_sum = self.reshape(exp_sum, (F.shape(exp_sum)[0], F.shape(exp_sum)[1], 1))
|
||||
softmax_result = self.log(exp / exp_sum + self.eps_const)
|
||||
one_hot_label = self.onehot(
|
||||
self.cast(label, mstype.int32), F.shape(logit)[2], self.on_value, self.off_value)
|
||||
loss = (softmax_result * self.cast(one_hot_label, mstype.float32) * self.cast(F.scalar_to_array(-1),
|
||||
mstype.float32))
|
||||
|
||||
loss = self.sum(loss, -1)
|
||||
loss = self.sum(loss, -1)
|
||||
loss = self.sum(loss, 0)
|
||||
loss = loss
|
||||
|
||||
return loss
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face Recognition learning rate scheduler."""
|
||||
from collections import Counter
|
||||
import numpy as np
|
||||
|
||||
|
||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
learning_rate = float(init_lr) + lr_inc * current_step
|
||||
return learning_rate
|
||||
|
||||
|
||||
def warmup_step(args, gamma=0.1):
|
||||
'''Warmup step.'''
|
||||
base_lr = args.lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(args.max_epoch * args.steps_per_epoch)
|
||||
warmup_steps = int(args.warmup_epochs * args.steps_per_epoch)
|
||||
milestones = args.lr_epochs
|
||||
milestones_steps = []
|
||||
for milestone in milestones:
|
||||
milestones_step = milestone * args.steps_per_epoch
|
||||
milestones_steps.append(milestones_step)
|
||||
|
||||
lr = base_lr
|
||||
milestones_steps_counter = Counter(milestones_steps)
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_learning_rate(i, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = lr * gamma ** milestones_steps_counter[i]
|
||||
yield lr
|
||||
|
||||
|
||||
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
lr = float(init_lr) + lr_inc * current_step
|
||||
return lr
|
||||
|
||||
|
||||
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
|
||||
'''Warmup step lr.'''
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
milestones = lr_epochs
|
||||
milestones_steps = []
|
||||
for milestone in milestones:
|
||||
milestones_step = milestone * steps_per_epoch
|
||||
milestones_steps.append(milestones_step)
|
||||
|
||||
lr_each_step = []
|
||||
lr = base_lr
|
||||
milestones_steps_counter = Counter(milestones_steps)
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = lr * gamma ** milestones_steps_counter[i]
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
|
||||
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
|
||||
|
||||
|
||||
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
|
||||
lr_epochs = []
|
||||
for i in range(1, max_epoch):
|
||||
if i % epoch_size == 0:
|
||||
lr_epochs.append(i)
|
||||
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
|
|
@ -0,0 +1,217 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""init"""
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore.common import initializer as init
|
||||
from mindspore.common.initializer import _assignment as assignment
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
r"""Return the recommended gain value for the given nonlinearity function.
|
||||
The values are as follows:
|
||||
|
||||
================= ====================================================
|
||||
nonlinearity gain
|
||||
================= ====================================================
|
||||
Linear / Identity :math:`1`
|
||||
Conv{1,2,3}D :math:`1`
|
||||
Sigmoid :math:`1`
|
||||
Tanh :math:`\frac{5}{3}`
|
||||
ReLU :math:`\sqrt{2}`
|
||||
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
|
||||
================= ====================================================
|
||||
|
||||
Args:
|
||||
nonlinearity: the non-linear function (`nn.functional` name)
|
||||
param: optional parameter for the non-linear function
|
||||
|
||||
Examples:
|
||||
>>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
|
||||
"""
|
||||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
a = 1
|
||||
elif nonlinearity == 'tanh':
|
||||
a = 5.0 / 3
|
||||
elif nonlinearity == 'relu':
|
||||
a = math.sqrt(2.0)
|
||||
elif nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
||||
# True/False are instances of int, hence check above
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
a = math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
else:
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
return a
|
||||
|
||||
def _calculate_correct_fan(array, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
|
||||
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
r"""Fills the input `Tensor` with values according to the method
|
||||
described in `Delving deep into rectifiers: Surpassing human-level
|
||||
performance on ImageNet classification` - He, K. et al. (2015), using a
|
||||
uniform distribution. The resulting tensor will have values sampled from
|
||||
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
|
||||
|
||||
.. math::
|
||||
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
|
||||
|
||||
Also known as He initialization.
|
||||
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
a: the negative slope of the rectifier used after this layer (only
|
||||
used with ``'leaky_relu'``)
|
||||
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
|
||||
preserves the magnitude of the variance of the weights in the
|
||||
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
|
||||
backwards pass.
|
||||
nonlinearity: the non-linear function (`nn.functional` name),
|
||||
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
|
||||
|
||||
Examples:
|
||||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
|
||||
"""
|
||||
fan = _calculate_correct_fan(arr, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||
return np.random.uniform(-bound, bound, arr.shape)
|
||||
|
||||
|
||||
def kaiming_normal_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
r"""Fills the input `Tensor` with values according to the method
|
||||
described in `Delving deep into rectifiers: Surpassing human-level
|
||||
performance on ImageNet classification` - He, K. et al. (2015), using a
|
||||
normal distribution. The resulting tensor will have values sampled from
|
||||
:math:`\mathcal{N}(0, \text{std}^2)` where
|
||||
|
||||
.. math::
|
||||
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
|
||||
|
||||
Also known as He initialization.
|
||||
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
a: the negative slope of the rectifier used after this layer (only
|
||||
used with ``'leaky_relu'``)
|
||||
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
|
||||
preserves the magnitude of the variance of the weights in the
|
||||
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
|
||||
backwards pass.
|
||||
nonlinearity: the non-linear function (`nn.functional` name),
|
||||
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
|
||||
|
||||
Examples:
|
||||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
|
||||
"""
|
||||
fan = _calculate_correct_fan(arr, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
return np.random.normal(0, std, arr.shape)
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(arr):
|
||||
'''Calculate fan_in and fan_out.'''
|
||||
dimensions = len(arr.shape)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
|
||||
|
||||
num_input_fmaps = arr.shape[1]
|
||||
num_output_fmaps = arr.shape[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
receptive_field_size = arr[0][0].size
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
def xavier_uniform_(arr, gain=1.):
|
||||
# type: (Tensor, float) -> Tensor
|
||||
r"""Fills the input `Tensor` with values according to the method
|
||||
described in `Understanding the difficulty of training deep feedforward
|
||||
neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
|
||||
distribution. The resulting tensor will have values sampled from
|
||||
:math:`\mathcal{U}(-a, a)` where
|
||||
|
||||
.. math::
|
||||
a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
|
||||
|
||||
Also known as Glorot initialization.
|
||||
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
gain: an optional scaling factor
|
||||
|
||||
Examples:
|
||||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
|
||||
"""
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(arr)
|
||||
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||
|
||||
return np.random.uniform(-a, a, arr.shape)
|
||||
|
||||
|
||||
class ReidXavierUniform(init.Initializer):
|
||||
def __init__(self, gain=1.):
|
||||
super(ReidXavierUniform, self).__init__()
|
||||
self.gain = gain
|
||||
|
||||
def _initialize(self, arr):
|
||||
tmp = xavier_uniform_(arr, self.gain)
|
||||
assignment(arr, tmp)
|
||||
|
||||
|
||||
class ReidKaimingUniform(init.Initializer):
|
||||
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
super(ReidKaimingUniform, self).__init__()
|
||||
self.a = a
|
||||
self.mode = mode
|
||||
self.nonlinearity = nonlinearity
|
||||
|
||||
def _initialize(self, arr):
|
||||
tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
|
||||
assignment(arr, tmp)
|
||||
|
||||
|
||||
class ReidKaimingNormal(init.Initializer):
|
||||
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
super(ReidKaimingNormal, self).__init__()
|
||||
self.a = a
|
||||
self.mode = mode
|
||||
self.nonlinearity = nonlinearity
|
||||
|
||||
def _initialize(self, arr):
|
||||
tmp = kaiming_normal_(arr, self.a, self.mode, self.nonlinearity)
|
||||
assignment(arr, tmp)
|
|
@ -0,0 +1,306 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face Recognition backbone."""
|
||||
import math
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops.operations import TensorAdd
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.nn import Dense, Cell
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore import Tensor, Parameter
|
||||
|
||||
from src import me_init
|
||||
|
||||
|
||||
class Cut(nn.Cell):
|
||||
|
||||
|
||||
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def bn_with_initialize(out_channels):
|
||||
bn = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5).add_flags_recursive(fp32=True)
|
||||
return bn
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
return Dense(input_channels, out_channels)
|
||||
|
||||
|
||||
def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1, bias=True):
|
||||
"""3x3 convolution with padding"""
|
||||
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
|
||||
pad_mode=pad_mode, group=groups, has_bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
|
||||
def conv1x1(in_channels, out_channels, pad_mode="pad", stride=1, padding=0, bias=True):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_channels, out_channels, pad_mode=pad_mode, kernel_size=1, stride=stride, has_bias=bias,
|
||||
padding=padding)
|
||||
|
||||
|
||||
def conv4x4(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1, bias=True):
|
||||
"""4x4 convolution with padding"""
|
||||
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride,
|
||||
pad_mode=pad_mode, group=groups, has_bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
|
||||
class BaseBlock(Cell):
|
||||
'''BaseBlock'''
|
||||
def __init__(self, channels):
|
||||
super(BaseBlock, self).__init__()
|
||||
self.conv1 = conv3x3(channels, channels, stride=1, padding=1, bias=False)
|
||||
self.bn1 = bn_with_initialize(channels)
|
||||
self.relu1 = P.ReLU()
|
||||
self.conv2 = conv3x3(channels, channels, stride=1, padding=1, bias=False)
|
||||
self.bn2 = bn_with_initialize(channels)
|
||||
self.relu2 = P.ReLU()
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.add = TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
identity = x
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu2(out)
|
||||
# hand cast
|
||||
identity = self.cast(identity, mstype.float16)
|
||||
out = self.cast(out, mstype.float16)
|
||||
|
||||
out = self.add(out, identity)
|
||||
return out
|
||||
|
||||
|
||||
class MakeLayer(Cell):
|
||||
'''MakeLayer'''
|
||||
def __init__(self, block, inplanes, planes, blocks, stride=2):
|
||||
super(MakeLayer, self).__init__()
|
||||
self.conv = conv3x3(inplanes, planes, stride=stride, padding=1, bias=True)
|
||||
self.bn = bn_with_initialize(planes)
|
||||
self.relu = P.ReLU()
|
||||
|
||||
self.layers = []
|
||||
|
||||
for _ in range(0, blocks):
|
||||
self.layers.append(block(planes))
|
||||
self.layers = nn.CellList(self.layers)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
for block in self.layers:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class SphereNet(Cell):
|
||||
'''SphereNet'''
|
||||
def __init__(self, num_layers=36, feature_dim=128, shape=(96, 64)):
|
||||
super(SphereNet, self).__init__()
|
||||
assert num_layers in [12, 20, 36, 64], 'SphereNet num_layers should be 12, 20 or 64'
|
||||
if num_layers == 12:
|
||||
layers = [1, 1, 1, 1]
|
||||
filter_list = [3, 16, 32, 64, 128]
|
||||
fc_size = 128 * 6 * 4
|
||||
elif num_layers == 20:
|
||||
layers = [1, 2, 4, 1]
|
||||
filter_list = [3, 64, 128, 256, 512]
|
||||
fc_size = 512 * 6 * 4
|
||||
elif num_layers == 36:
|
||||
layers = [2, 4, 4, 2]
|
||||
filter_list = [3, 32, 64, 128, 256]
|
||||
fc_size = 256 * 6 * 4
|
||||
elif num_layers == 64:
|
||||
layers = [3, 7, 16, 3]
|
||||
filter_list = [3, 64, 128, 256, 512]
|
||||
fc_size = 512 * 6 * 4
|
||||
else:
|
||||
raise ValueError('sphere' + str(num_layers) + " IS NOT SUPPORTED! (sphere20 or sphere64)")
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.arg_shape = shape
|
||||
block = BaseBlock
|
||||
|
||||
self.layer1 = MakeLayer(block, filter_list[0], filter_list[1], layers[0], stride=2)
|
||||
self.layer2 = MakeLayer(block, filter_list[1], filter_list[2], layers[1], stride=2)
|
||||
self.layer3 = MakeLayer(block, filter_list[2], filter_list[3], layers[2], stride=2)
|
||||
self.layer4 = MakeLayer(block, filter_list[3], filter_list[4], layers[3], stride=2)
|
||||
|
||||
self.fc = fc_with_initialize(fc_size, feature_dim)
|
||||
self.last_bn = nn.BatchNorm1d(feature_dim, momentum=0.9).add_flags_recursive(fp32=True)
|
||||
self.cast = P.Cast()
|
||||
self.l2norm = P.L2Normalize(axis=1)
|
||||
|
||||
for _, cell in self.cells_and_names():
|
||||
if isinstance(cell, (nn.Conv2d, nn.Dense)):
|
||||
if cell.bias is not None:
|
||||
cell.weight.set_data(initializer(me_init.ReidKaimingUniform(a=math.sqrt(5), mode='fan_out'),
|
||||
cell.weight.shape))
|
||||
cell.bias.set_data(initializer('zeros', cell.bias.shape))
|
||||
else:
|
||||
cell.weight.set_data(initializer(me_init.ReidXavierUniform(), cell.weight.shape))
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
b, _, _, _ = self.shape(x)
|
||||
x = self.reshape(x, (b, -1))
|
||||
x = self.fc(x)
|
||||
x = self.last_bn(x)
|
||||
x = self.cast(x, mstype.float16)
|
||||
x = self.l2norm(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CombineMarginFC(nn.Cell):
|
||||
'''CombineMarginFC'''
|
||||
def __init__(self, embbeding_size=128, classnum=270762, s=32, a=1.0, m=0.3, b=0.2):
|
||||
super(CombineMarginFC, self).__init__()
|
||||
weight_shape = [classnum, embbeding_size]
|
||||
weight_init = initializer(me_init.ReidXavierUniform(), weight_shape)
|
||||
self.weight = Parameter(weight_init, name='weight')
|
||||
self.m = m
|
||||
self.s = s
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.m_const = Tensor(self.m, dtype=mstype.float32)
|
||||
self.a_const = Tensor(self.a, dtype=mstype.float32)
|
||||
self.b_const = Tensor(self.b, dtype=mstype.float32)
|
||||
self.s_const = Tensor(self.s, dtype=mstype.float32)
|
||||
self.m_const_zero = Tensor(0.0, dtype=mstype.float32)
|
||||
self.a_const_one = Tensor(1.0, dtype=mstype.float32)
|
||||
self.normalize = P.L2Normalize(axis=1)
|
||||
self.fc = P.MatMul(transpose_b=True)
|
||||
self.onehot = P.OneHot()
|
||||
self.transpose = P.Transpose()
|
||||
self.acos = P.ACos()
|
||||
self.cos = P.Cos()
|
||||
self.cast = P.Cast()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
|
||||
def construct(self, x, label):
|
||||
'''Construct function.'''
|
||||
w = self.normalize(self.weight)
|
||||
cosine = self.fc(self.cast(x, mstype.float16), self.cast(w, mstype.float16))
|
||||
cosine = self.cast(cosine, mstype.float32)
|
||||
cosine_shape = F.shape(cosine)
|
||||
|
||||
one_hot_float = self.onehot(
|
||||
self.cast(label, mstype.int32), cosine_shape[1], self.on_value, self.off_value)
|
||||
theta = self.acos(cosine)
|
||||
theta = self.a_const * theta
|
||||
theta = self.m_const + theta
|
||||
body = self.cos(theta)
|
||||
body = body - self.b_const
|
||||
cos_mask = F.scalar_to_array(1.0) - one_hot_float
|
||||
output = body * one_hot_float + cosine * cos_mask
|
||||
output = output * self.s_const
|
||||
return output, cosine
|
||||
|
||||
|
||||
class CombineMarginFCFp16(nn.Cell):
|
||||
'''CombineMarginFCFp16'''
|
||||
def __init__(self, embbeding_size=128, classnum=270762, s=32, a=1.0, m=0.3, b=0.2):
|
||||
super(CombineMarginFCFp16, self).__init__()
|
||||
weight_shape = [classnum, embbeding_size]
|
||||
weight_init = initializer(me_init.ReidXavierUniform(), weight_shape)
|
||||
self.weight = Parameter(weight_init, name='weight')
|
||||
|
||||
self.m = m
|
||||
self.s = s
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.m_const = Tensor(self.m, dtype=mstype.float16)
|
||||
self.a_const = Tensor(self.a, dtype=mstype.float16)
|
||||
self.b_const = Tensor(self.b, dtype=mstype.float16)
|
||||
self.s_const = Tensor(self.s, dtype=mstype.float16)
|
||||
self.m_const_zero = Tensor(0, dtype=mstype.float16)
|
||||
self.a_const_one = Tensor(1, dtype=mstype.float16)
|
||||
self.normalize = P.L2Normalize(axis=1)
|
||||
self.fc = P.MatMul(transpose_b=True)
|
||||
|
||||
self.onehot = P.OneHot()
|
||||
self.transpose = P.Transpose()
|
||||
self.acos = P.ACos()
|
||||
self.cos = P.Cos()
|
||||
self.cast = P.Cast()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
|
||||
def construct(self, x, label):
|
||||
'''Construct function.'''
|
||||
w = self.normalize(self.weight)
|
||||
cosine = self.fc(x, w)
|
||||
cosine_shape = F.shape(cosine)
|
||||
one_hot_float = self.onehot(
|
||||
self.cast(label, mstype.int32), cosine_shape[1], self.on_value, self.off_value)
|
||||
one_hot_float = self.cast(one_hot_float, mstype.float16)
|
||||
theta = self.acos(cosine)
|
||||
theta = self.a_const * theta
|
||||
theta = self.m_const + theta
|
||||
body = self.cos(theta)
|
||||
body = body - self.b_const
|
||||
cos_mask = self.cast(F.scalar_to_array(1.0), mstype.float16) - one_hot_float
|
||||
output = body * one_hot_float + cosine * cos_mask
|
||||
output = output * self.s_const
|
||||
|
||||
return output, cosine
|
||||
|
||||
|
||||
class BuildTrainNetwork(Cell):
|
||||
def __init__(self, network, criterion):
|
||||
super(BuildTrainNetwork, self).__init__()
|
||||
self.network = network
|
||||
self.criterion = criterion
|
||||
|
||||
def construct(self, input_data, label):
|
||||
output = self.network(input_data)
|
||||
loss = self.criterion(output, label)
|
||||
return loss
|
||||
|
||||
|
||||
class BuildTrainNetworkWithHead(nn.Cell):
|
||||
'''Build TrainNetwork With Head.'''
|
||||
def __init__(self, model, head, criterion):
|
||||
super(BuildTrainNetworkWithHead, self).__init__()
|
||||
self.model = model
|
||||
self.head = head
|
||||
self.criterion = criterion
|
||||
|
||||
def construct(self, input_data, labels):
|
||||
embeddings = self.model(input_data)
|
||||
thetas, _ = self.head(embeddings, labels)
|
||||
loss = self.criterion(thetas, labels)
|
||||
|
||||
return loss
|
|
@ -0,0 +1,310 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face Recognition backbone."""
|
||||
import math
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops.operations import TensorAdd
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.nn import Dense, Cell
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore import Tensor, Parameter
|
||||
|
||||
from src import me_init
|
||||
|
||||
|
||||
class Cut(nn.Cell):
|
||||
|
||||
|
||||
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def bn_with_initialize(out_channels):
|
||||
bn = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5).add_flags_recursive(fp32=True)
|
||||
return bn
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
return Dense(input_channels, out_channels)
|
||||
|
||||
|
||||
def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1, bias=True):
|
||||
"""3x3 convolution with padding"""
|
||||
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
|
||||
pad_mode=pad_mode, group=groups, has_bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
|
||||
def conv1x1(in_channels, out_channels, pad_mode="pad", stride=1, padding=0, bias=True):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_channels, out_channels, pad_mode=pad_mode, kernel_size=1, stride=stride, has_bias=bias,
|
||||
padding=padding)
|
||||
|
||||
|
||||
def conv4x4(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1, bias=True):
|
||||
"""4x4 convolution with padding"""
|
||||
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride,
|
||||
pad_mode=pad_mode, group=groups, has_bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
|
||||
class BaseBlock(Cell):
|
||||
'''BaseBlock'''
|
||||
def __init__(self, channels):
|
||||
super(BaseBlock, self).__init__()
|
||||
|
||||
self.conv1 = conv3x3(channels, channels, stride=1, padding=1, bias=False)
|
||||
self.bn1 = bn_with_initialize(channels)
|
||||
self.relu1 = P.ReLU()
|
||||
self.conv2 = conv3x3(channels, channels, stride=1, padding=1, bias=False)
|
||||
self.bn2 = bn_with_initialize(channels)
|
||||
self.relu2 = P.ReLU()
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.add = TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
identity = x
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu2(out)
|
||||
|
||||
# hand cast
|
||||
identity = self.cast(identity, mstype.float16)
|
||||
out = self.cast(out, mstype.float16)
|
||||
|
||||
out = self.add(out, identity)
|
||||
return out
|
||||
|
||||
|
||||
class MakeLayer(Cell):
|
||||
'''MakeLayer'''
|
||||
def __init__(self, block, inplanes, planes, blocks, stride=2):
|
||||
super(MakeLayer, self).__init__()
|
||||
|
||||
self.conv = conv3x3(inplanes, planes, stride=stride, padding=1, bias=True)
|
||||
self.bn = bn_with_initialize(planes)
|
||||
self.relu = P.ReLU()
|
||||
|
||||
self.layers = []
|
||||
|
||||
for _ in range(0, blocks):
|
||||
self.layers.append(block(planes))
|
||||
self.layers = nn.CellList(self.layers)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
for block in self.layers:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
class SphereNet(Cell):
|
||||
'''SphereNet'''
|
||||
def __init__(self, num_layers=36, feature_dim=128, shape=(96, 64)):
|
||||
super(SphereNet, self).__init__()
|
||||
assert num_layers in [12, 20, 36, 64], 'SphereNet num_layers should be 12, 20 or 64'
|
||||
if num_layers == 12:
|
||||
layers = [1, 1, 1, 1]
|
||||
filter_list = [3, 16, 32, 64, 128]
|
||||
fc_size = 128 * 6 * 4
|
||||
elif num_layers == 20:
|
||||
layers = [1, 2, 4, 1]
|
||||
filter_list = [3, 64, 128, 256, 512]
|
||||
fc_size = 512 * 6 * 4
|
||||
elif num_layers == 36:
|
||||
layers = [2, 4, 4, 2]
|
||||
filter_list = [3, 32, 64, 128, 256]
|
||||
fc_size = 256 * 6 * 4
|
||||
elif num_layers == 64:
|
||||
layers = [3, 7, 16, 3]
|
||||
filter_list = [3, 64, 128, 256, 512]
|
||||
fc_size = 512 * 6 * 4
|
||||
else:
|
||||
raise ValueError('sphere' + str(num_layers) + " IS NOT SUPPORTED! (sphere20 or sphere64)")
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.arg_shape = shape
|
||||
block = BaseBlock
|
||||
|
||||
self.layer1 = MakeLayer(block, filter_list[0], filter_list[1], layers[0], stride=2)
|
||||
self.layer2 = MakeLayer(block, filter_list[1], filter_list[2], layers[1], stride=2)
|
||||
self.layer3 = MakeLayer(block, filter_list[2], filter_list[3], layers[2], stride=2)
|
||||
self.layer4 = MakeLayer(block, filter_list[3], filter_list[4], layers[3], stride=2)
|
||||
|
||||
self.fc = fc_with_initialize(fc_size, feature_dim)
|
||||
self.last_bn = nn.BatchNorm1d(feature_dim, momentum=0.9).add_flags_recursive(fp32=True)
|
||||
self.cast = P.Cast()
|
||||
self.l2norm = P.L2Normalize(axis=1)
|
||||
|
||||
for _, cell in self.cells_and_names():
|
||||
if isinstance(cell, (nn.Conv2d, nn.Dense)):
|
||||
if cell.bias is not None:
|
||||
cell.weight.set_data(initializer(me_init.ReidKaimingUniform(a=math.sqrt(5), mode='fan_out'),
|
||||
cell.weight.shape))
|
||||
cell.bias.set_data(initializer('zeros', cell.bias.shape))
|
||||
else:
|
||||
cell.weight.set_data(initializer(me_init.ReidXavierUniform(), cell.weight.shape))
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
b, _, _, _ = self.shape(x)
|
||||
x = self.reshape(x, (b, -1))
|
||||
x = self.fc(x)
|
||||
x = self.last_bn(x)
|
||||
x = self.cast(x, mstype.float16)
|
||||
x = self.l2norm(x)
|
||||
x = self.cast(x, mstype.float32)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CombineMarginFC(nn.Cell):
|
||||
'''CombineMarginFC'''
|
||||
def __init__(self, embbeding_size=128, classnum=270762, s=32, a=1.0, m=0.3, b=0.2):
|
||||
super(CombineMarginFC, self).__init__()
|
||||
weight_shape = [classnum, embbeding_size]
|
||||
weight_init = initializer(me_init.ReidXavierUniform(), weight_shape)
|
||||
self.weight = Parameter(weight_init, name='weight')
|
||||
self.m = m
|
||||
self.s = s
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.m_const = Tensor(self.m, dtype=mstype.float32)
|
||||
self.a_const = Tensor(self.a, dtype=mstype.float32)
|
||||
self.b_const = Tensor(self.b, dtype=mstype.float32)
|
||||
self.s_const = Tensor(self.s, dtype=mstype.float32)
|
||||
self.m_const_zero = Tensor(0.0, dtype=mstype.float32)
|
||||
self.a_const_one = Tensor(1.0, dtype=mstype.float32)
|
||||
self.normalize = P.L2Normalize(axis=1)
|
||||
self.fc = P.MatMul(transpose_b=True)
|
||||
self.onehot = P.OneHot()
|
||||
self.transpose = P.Transpose()
|
||||
self.acos = P.ACos()
|
||||
self.cos = P.Cos()
|
||||
self.cast = P.Cast()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
|
||||
def construct(self, x, label):
|
||||
'''Construct function.'''
|
||||
w = self.normalize(self.weight)
|
||||
cosine = self.fc(self.cast(x, mstype.float16), self.cast(w, mstype.float16))
|
||||
cosine = self.cast(cosine, mstype.float32)
|
||||
cosine_shape = F.shape(cosine)
|
||||
|
||||
one_hot_float = self.onehot(
|
||||
self.cast(label, mstype.int32), cosine_shape[1], self.on_value, self.off_value)
|
||||
theta = self.acos(cosine)
|
||||
theta = self.a_const * theta
|
||||
theta = self.m_const + theta
|
||||
body = self.cos(theta)
|
||||
body = body - self.b_const
|
||||
cos_mask = F.scalar_to_array(1.0) - one_hot_float
|
||||
output = body * one_hot_float + cosine * cos_mask
|
||||
output = output * self.s_const
|
||||
return output, cosine
|
||||
|
||||
|
||||
class CombineMarginFCFp16(nn.Cell):
|
||||
'''CombineMarginFCFp16'''
|
||||
def __init__(self, embbeding_size=128, classnum=270762, s=32, a=1.0, m=0.3, b=0.2):
|
||||
super(CombineMarginFCFp16, self).__init__()
|
||||
weight_shape = [classnum, embbeding_size]
|
||||
weight_init = initializer(me_init.ReidXavierUniform(), weight_shape)
|
||||
self.weight = Parameter(weight_init, name='weight')
|
||||
|
||||
self.m = m
|
||||
self.s = s
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.m_const = Tensor(self.m, dtype=mstype.float16)
|
||||
self.a_const = Tensor(self.a, dtype=mstype.float16)
|
||||
self.b_const = Tensor(self.b, dtype=mstype.float16)
|
||||
self.s_const = Tensor(self.s, dtype=mstype.float16)
|
||||
self.m_const_zero = Tensor(0, dtype=mstype.float16)
|
||||
self.a_const_one = Tensor(1, dtype=mstype.float16)
|
||||
self.normalize = P.L2Normalize(axis=1)
|
||||
self.fc = P.MatMul(transpose_b=True)
|
||||
|
||||
self.onehot = P.OneHot()
|
||||
self.transpose = P.Transpose()
|
||||
self.acos = P.ACos()
|
||||
self.cos = P.Cos()
|
||||
self.cast = P.Cast()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
|
||||
def construct(self, x, label):
|
||||
'''Construct function.'''
|
||||
w = self.normalize(self.weight)
|
||||
cosine = self.fc(x, w)
|
||||
cosine_shape = F.shape(cosine)
|
||||
|
||||
one_hot_float = self.onehot(
|
||||
self.cast(label, mstype.int32), cosine_shape[1], self.on_value, self.off_value)
|
||||
one_hot_float = self.cast(one_hot_float, mstype.float16)
|
||||
theta = self.acos(cosine)
|
||||
theta = self.a_const * theta
|
||||
theta = self.m_const + theta
|
||||
body = self.cos(theta)
|
||||
body = body - self.b_const
|
||||
cos_mask = self.cast(F.scalar_to_array(1.0), mstype.float16) - one_hot_float
|
||||
output = body * one_hot_float + cosine * cos_mask
|
||||
output = output * self.s_const
|
||||
|
||||
return output, cosine
|
||||
|
||||
|
||||
class BuildTrainNetwork(Cell):
|
||||
def __init__(self, network, criterion):
|
||||
super(BuildTrainNetwork, self).__init__()
|
||||
self.network = network
|
||||
self.criterion = criterion
|
||||
|
||||
def construct(self, input_data, label):
|
||||
output = self.network(input_data)
|
||||
loss = self.criterion(output, label)
|
||||
return loss
|
||||
|
||||
|
||||
class BuildTrainNetworkWithHead(nn.Cell):
|
||||
'''Build TrainNetwork With Head.'''
|
||||
def __init__(self, model, head, criterion):
|
||||
super(BuildTrainNetworkWithHead, self).__init__()
|
||||
self.model = model
|
||||
self.head = head
|
||||
self.criterion = criterion
|
||||
|
||||
def construct(self, input_data, labels):
|
||||
embeddings = self.model(input_data)
|
||||
thetas, _ = self.head(embeddings, labels)
|
||||
loss = self.criterion(thetas, labels)
|
||||
|
||||
return loss
|
|
@ -0,0 +1,197 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face Recognition train."""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import datetime
|
||||
import warnings
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
|
||||
from mindspore.nn.optim import SGD
|
||||
from mindspore.nn import TrainOneStepCell
|
||||
from mindspore.communication.management import get_group_size, init, get_rank
|
||||
|
||||
from src.dataset import get_de_dataset
|
||||
from src.config import reid_1p_cfg, reid_8p_cfg
|
||||
from src.lr_generator import step_lr
|
||||
from src.log import get_logger, AverageMeter
|
||||
from src.reid import SphereNet, CombineMarginFCFp16, BuildTrainNetworkWithHead
|
||||
from src.loss import CrossEntropy
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Cifar10 classification')
|
||||
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
|
||||
parser.add_argument('--data_dir', type=str, default='', help='image label list file, e.g. /home/label.txt')
|
||||
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.is_distributed == 0:
|
||||
cfg = reid_1p_cfg
|
||||
else:
|
||||
cfg = reid_8p_cfg
|
||||
cfg.pretrained = args.pretrained
|
||||
cfg.data_dir = args.data_dir
|
||||
|
||||
# Init distributed
|
||||
if args.is_distributed:
|
||||
init()
|
||||
cfg.local_rank = get_rank()
|
||||
cfg.world_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
else:
|
||||
parallel_mode = ParallelMode.STAND_ALONE
|
||||
|
||||
# parallel_mode 'STAND_ALONE' do not support parameter_broadcast and mirror_mean
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.world_size,
|
||||
gradients_mean=True)
|
||||
|
||||
mindspore.common.set_seed(1)
|
||||
|
||||
# logger
|
||||
cfg.outputs_dir = os.path.join(cfg.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
|
||||
loss_meter = AverageMeter('loss')
|
||||
|
||||
# Show cfg
|
||||
cfg.logger.save_args(cfg)
|
||||
|
||||
# dataloader
|
||||
cfg.logger.info('start create dataloader')
|
||||
de_dataset, steps_per_epoch, class_num = get_de_dataset(cfg)
|
||||
cfg.steps_per_epoch = steps_per_epoch
|
||||
cfg.logger.info('step per epoch: ' + str(cfg.steps_per_epoch))
|
||||
de_dataloader = de_dataset.create_tuple_iterator()
|
||||
|
||||
cfg.logger.info('class num original: ' + str(class_num))
|
||||
if class_num % 16 != 0:
|
||||
class_num = (class_num // 16 + 1) * 16
|
||||
cfg.class_num = class_num
|
||||
cfg.logger.info('change the class num to :' + str(cfg.class_num))
|
||||
cfg.logger.info('end create dataloader')
|
||||
|
||||
# backbone and loss
|
||||
cfg.logger.important_info('start create network')
|
||||
create_network_start = time.time()
|
||||
|
||||
network = SphereNet(num_layers=cfg.net_depth, feature_dim=cfg.embedding_size, shape=cfg.input_size)
|
||||
head = CombineMarginFCFp16(embbeding_size=cfg.embedding_size, classnum=cfg.class_num)
|
||||
criterion = CrossEntropy()
|
||||
|
||||
# load the pretrained model
|
||||
if os.path.isfile(cfg.pretrained):
|
||||
param_dict = load_checkpoint(cfg.pretrained)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith('network.'):
|
||||
param_dict_new[key[8:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
cfg.logger.info('load model {} success'.format(cfg.pretrained))
|
||||
|
||||
# mixed precision training
|
||||
network.add_flags_recursive(fp16=True)
|
||||
head.add_flags_recursive(fp16=True)
|
||||
criterion.add_flags_recursive(fp32=True)
|
||||
|
||||
train_net = BuildTrainNetworkWithHead(network, head, criterion)
|
||||
|
||||
# optimizer and lr scheduler
|
||||
lr = step_lr(lr=cfg.lr, epoch_size=cfg.epoch_size, steps_per_epoch=cfg.steps_per_epoch, max_epoch=cfg.max_epoch,
|
||||
gamma=cfg.lr_gamma)
|
||||
opt = SGD(params=train_net.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
|
||||
weight_decay=cfg.weight_decay, loss_scale=cfg.loss_scale)
|
||||
|
||||
# package training process, adjust lr + forward + backward + optimizer
|
||||
train_net = TrainOneStepCell(train_net, opt, sens=cfg.loss_scale)
|
||||
|
||||
# checkpoint save
|
||||
if cfg.local_rank == 0:
|
||||
ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
|
||||
train_config = CheckpointConfig(save_checkpoint_steps=cfg.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
|
||||
ckpt_cb = ModelCheckpoint(config=train_config, directory=cfg.outputs_dir, prefix='{}'.format(cfg.local_rank))
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = train_net
|
||||
cb_params.epoch_num = ckpt_max_num
|
||||
cb_params.cur_epoch_num = 1
|
||||
run_context = RunContext(cb_params)
|
||||
ckpt_cb.begin(run_context)
|
||||
|
||||
train_net.set_train()
|
||||
t_end = time.time()
|
||||
t_epoch = time.time()
|
||||
old_progress = -1
|
||||
|
||||
cfg.logger.important_info('====start train====')
|
||||
for i, total_data in enumerate(de_dataloader):
|
||||
data, gt = total_data
|
||||
data = Tensor(data)
|
||||
gt = Tensor(gt)
|
||||
|
||||
loss = train_net(data, gt)
|
||||
loss_meter.update(loss.asnumpy())
|
||||
|
||||
# ckpt
|
||||
if cfg.local_rank == 0:
|
||||
cb_params.cur_step_num = i + 1 # current step number
|
||||
cb_params.batch_num = i + 2
|
||||
ckpt_cb.step_end(run_context)
|
||||
|
||||
# logging loss, fps, ...
|
||||
if i == 0:
|
||||
time_for_graph_compile = time.time() - create_network_start
|
||||
cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(cfg.task, time_for_graph_compile))
|
||||
|
||||
if i % cfg.log_interval == 0 and cfg.local_rank == 0:
|
||||
time_used = time.time() - t_end
|
||||
epoch = int(i / cfg.steps_per_epoch)
|
||||
fps = cfg.per_batch_size * (i - old_progress) * cfg.world_size / time_used
|
||||
cfg.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr={}'.format(epoch, i, loss_meter, fps, lr[i]))
|
||||
t_end = time.time()
|
||||
loss_meter.reset()
|
||||
old_progress = i
|
||||
|
||||
if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
|
||||
epoch_time_used = time.time() - t_epoch
|
||||
epoch = int(i / cfg.steps_per_epoch)
|
||||
fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
|
||||
cfg.logger.info('=================================================')
|
||||
cfg.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
|
||||
cfg.logger.info('=================================================')
|
||||
t_epoch = time.time()
|
||||
|
||||
cfg.logger.important_info('====train end====')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue