Add FaceRecognition to model_zoo/research/cv/

zhanghuiyao 2020-12-14 10:27:47 +08:00
parent e10dbb4a7f
commit 972aefe561
23 changed files with 2771 additions and 0 deletions

@ -0,0 +1,244 @@
# Contents
- [Face Recognition Description](#face-recognition-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Running Example](#running-example)
- [Model Description](#model-description)
- [Performance](#performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Face Recognition Description](#contents)
This is a face recognition network based on ResNet, with support for training and evaluation on Ascend 910.
ResNet (residual neural network) was proposed by Kaiming He and three colleagues at Microsoft Research. By using residual units, they successfully trained a 152-layer network and won ILSVRC 2015 with a top-5 error rate of 3.57%, while using fewer parameters than VGGNet. Traditional convolutional or fully connected networks lose some information as data passes through the layers, and deep stacks also suffer from vanishing or exploding gradients, which can make very deep networks fail to train. ResNet mitigates this problem: a shortcut passes the input directly to the output, which preserves the information, so each unit only needs to learn the difference (the residual) between its input and its output. This simplifies the learning objective, speeds up training, and improves accuracy. Residual connections are also widely reused and have even been incorporated into Inception networks.
[Paper](https://arxiv.org/pdf/1512.03385.pdf): Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition"
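The idea that a residual unit only learns the difference between its input and output can be sketched in a few lines. This is a toy illustration in plain Python, not the network defined in `src/backbone/resnet.py`; the function names and the scalar "layer" are hypothetical stand-ins.

```python
def relu(values):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]


def toy_layer(values, weight, bias):
    """Hypothetical stand-in for conv + batch norm: scale and shift each element."""
    return [weight * v + bias for v in values]


def residual_block(x, weight, bias):
    """Compute y = relu(F(x) + x), where F is the residual branch."""
    residual = toy_layer(x, weight, bias)  # F(x): the part the block has to learn
    return relu([r + xi for r, xi in zip(residual, x)])  # shortcut adds the input back


x = [0.5, 1.0, 2.0]
# With a zero residual branch the block reduces to the identity (for non-negative x).
identity_out = residual_block(x, weight=0.0, bias=0.0)
# A small non-zero branch only perturbs the input slightly.
perturbed_out = residual_block(x, weight=0.1, bias=0.0)
```

Because the shortcut carries the input forward unchanged, a block whose branch outputs zero still passes information through, which is one intuition for why deep stacks of such blocks remain trainable.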
# [Model Architecture](#contents)
Face Recognition uses a ResNet backbone for feature extraction. More details are given in the paper: [Link](https://arxiv.org/pdf/1512.03385.pdf)
# [Dataset](#contents)
We use about 4.7 million face images as the training dataset and 1.1 million as the evaluation dataset in this example. You can also use your own datasets or open-source datasets (e.g. face_emore).
The directory structure is as follows:
```python
.
└─ dataset
    ├─ train dataset
    │   ├─ ID1
    │   │   ├─ ID1_0001.jpg
    │   │   ├─ ID1_0002.jpg
    │   │   ...
    │   ├─ ID2
    │   │   ...
    │   ├─ ID3
    │   │   ...
    │   ...
    ├─ test dataset
    │   ├─ ID1
    │   │   ├─ ID1_0001.jpg
    │   │   ├─ ID1_0002.jpg
    │   │   ...
    │   ├─ ID2
    │   │   ...
    │   ├─ ID3
    │   │   ...
    │   ...
```
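As a rough illustration of how a one-folder-per-identity layout maps to training labels, the sketch below walks such a tree and assigns one integer label per identity folder. It is not the loader in `src/custom_dataset.py`; `list_identity_images` is a hypothetical helper, and the real pipeline may index files differently.

```python
import os
import tempfile


def list_identity_images(split_dir):
    """Return (image_path, label) pairs, one integer label per identity folder."""
    samples = []
    identities = sorted(
        name for name in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, name))
    )
    for label, identity in enumerate(identities):
        identity_dir = os.path.join(split_dir, identity)
        for filename in sorted(os.listdir(identity_dir)):
            if filename.lower().endswith(".jpg"):
                samples.append((os.path.join(identity_dir, filename), label))
    return samples


# Build a tiny throwaway tree that mirrors the layout above, then index it.
with tempfile.TemporaryDirectory() as root:
    for identity in ("ID1", "ID2"):
        os.makedirs(os.path.join(root, identity))
        for i in (1, 2):
            open(os.path.join(root, identity, "%s_%04d.jpg" % (identity, i)), "w").close()
    demo_samples = list_identity_images(root)
    demo_labels = [label for _, label in demo_samples]
```

Sorting the identity folders keeps the label assignment stable between runs, which matters when a checkpoint is later reused with the same dataset.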
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare a hardware environment with an Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
The entire code structure is as follows:
```python
└─ face_recognition
├── README.md // descriptions about face_recognition
├── scripts
│ ├── run_distribute_train_base.sh // shell script for distributed training on Ascend
│ ├── run_distribute_train_beta.sh // shell script for distributed training on Ascend
│ ├── run_eval.sh // shell script for evaluation on Ascend
│ ├── run_export.sh // shell script for exporting air model
│ ├── run_standalone_train_base.sh // shell script for standalone training on Ascend
│ ├── run_standalone_train_beta.sh // shell script for standalone training on Ascend
├── src
│ ├── backbone
│ │ ├── head.py // head unit
│ │ ├── resnet.py // resnet architecture
│ ├── callback_factory.py // callback logging
│ ├── config.py // parameter configuration
│ ├── custom_dataset.py // custom dataset and sampler
│ ├── custom_net.py // custom cell define
│ ├── dataset_factory.py // creating dataset
│ ├── init_network.py // init network parameter
│ ├── my_logging.py // logging format setting
│ ├── loss_factory.py // loss calculation
│ ├── lrsche_factory.py // learning rate schedule
│ ├── me_init.py // network parameter init method
│ ├── metric_factory.py // metric fc layer
├─ train.py // training scripts
├─ eval.py // evaluation scripts
└─ export.py // export air model
```
## [Running Example](#contents)
### Train
- Standalone mode
    - base model
      ```bash
      cd ./scripts
      sh run_standalone_train_base.sh [USE_DEVICE_ID]
      ```
      for example:
      ```bash
      cd ./scripts
      sh run_standalone_train_base.sh 0
      ```
    - beta model
      ```bash
      cd ./scripts
      sh run_standalone_train_beta.sh [USE_DEVICE_ID]
      ```
      for example:
      ```bash
      cd ./scripts
      sh run_standalone_train_beta.sh 0
      ```
- Distributed mode (recommended)
    - base model
      ```bash
      cd ./scripts
      sh run_distribute_train_base.sh [RANK_TABLE]
      ```
      for example:
      ```bash
      cd ./scripts
      sh run_distribute_train_base.sh ./rank_table_8p.json
      ```
    - beta model
      ```bash
      cd ./scripts
      sh run_distribute_train_beta.sh [RANK_TABLE]
      ```
      for example:
      ```bash
      cd ./scripts
      sh run_distribute_train_beta.sh ./rank_table_8p.json
      ```
You will get loss values like the following in "./scripts/data_parallel_log_[DEVICE_ID]/outputs/logs/[TIME].log" or "./scripts/log_parallel_graph/face_recognition_[DEVICE_ID].log":
```python
epoch[0], iter[100], loss:(Tensor(shape=[], dtype=Float32, value= 50.2733), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 32768)), cur_lr:0.000660, mean_fps:743.09 imgs/sec
epoch[0], iter[200], loss:(Tensor(shape=[], dtype=Float32, value= 49.3693), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 32768)), cur_lr:0.001314, mean_fps:4426.42 imgs/sec
epoch[0], iter[300], loss:(Tensor(shape=[], dtype=Float32, value= 48.7081), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 16384)), cur_lr:0.001968, mean_fps:4428.09 imgs/sec
epoch[0], iter[400], loss:(Tensor(shape=[], dtype=Float32, value= 45.7791), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 16384)), cur_lr:0.002622, mean_fps:4428.17 imgs/sec
...
epoch[8], iter[27300], loss:(Tensor(shape=[], dtype=Float32, value= 2.13556), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4429.38 imgs/sec
epoch[8], iter[27400], loss:(Tensor(shape=[], dtype=Float32, value= 2.36922), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4429.88 imgs/sec
epoch[8], iter[27500], loss:(Tensor(shape=[], dtype=Float32, value= 2.08594), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4430.59 imgs/sec
epoch[8], iter[27600], loss:(Tensor(shape=[], dtype=Float32, value= 2.38706), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4430.37 imgs/sec
```
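To plot or monitor these values, the log lines can be parsed with a regular expression. The sketch below is a hypothetical helper, not part of the repository. It assumes the exact line layout shown above, and it assumes that the three tensors in the `loss` tuple are the loss value, an overflow flag, and the current loss scale; that reading fits the sample output but is not stated in the source.

```python
import re

# Pattern written against the sample log lines above (an assumption about the format).
LOG_PATTERN = re.compile(
    r"epoch\[(?P<epoch>\d+)\], iter\[(?P<iter>\d+)\], "
    r"loss:\(Tensor\(shape=\[\], dtype=Float32, value= (?P<loss>[-\d.eE+]+)\), "
    r"Tensor\(shape=\[\], dtype=Bool, value= (?P<overflow>True|False)\), "
    r"Tensor\(shape=\[\], dtype=Float32, value= (?P<scale>[-\d.eE+]+)\)\), "
    r"cur_lr:(?P<lr>[-\d.eE+]+), mean_fps:(?P<fps>[-\d.eE+]+) imgs/sec"
)


def parse_line(line):
    """Return a dict of numeric fields for one log line, or None if it does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    return {
        "epoch": int(match.group("epoch")),
        "iter": int(match.group("iter")),
        "loss": float(match.group("loss")),
        "overflow": match.group("overflow") == "True",
        "scale": float(match.group("scale")),
        "lr": float(match.group("lr")),
        "fps": float(match.group("fps")),
    }


sample = (
    "epoch[0], iter[100], loss:(Tensor(shape=[], dtype=Float32, value= 50.2733), "
    "Tensor(shape=[], dtype=Bool, value= False), "
    "Tensor(shape=[], dtype=Float32, value= 32768)), cur_lr:0.000660, mean_fps:743.09 imgs/sec"
)
parsed = parse_line(sample)
```

Lines that do not match, such as the `...` separator, yield `None` and can simply be skipped.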
### Evaluation
```bash
cd ./scripts
sh run_eval.sh [USE_DEVICE_ID]
```
You will get a result like the following in "./scripts/log_inference/outputs/models/logs/[TIME].log":
[test_dataset]: zj2jk=0.9495, jk2zj=0.9480, avg=0.9487
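The accuracy figures come from a top-1 comparison: `eval.py` L2-normalizes every embedding and counts a test pair as correct when the similarity of the pair exceeds the best similarity against the distractor set. The sketch below illustrates that rule in plain Python with made-up two-dimensional vectors. The helper names are hypothetical, and the real script uses NumPy matrix products over the full distractor set.

```python
import math


def l2_normalize(vector, epsilon=1e-12):
    """Scale a vector to unit length, guarding against division by zero."""
    norm = math.sqrt(sum(v * v for v in vector))
    return [v / max(norm, epsilon) for v in vector]


def dot(a, b):
    """Dot product; equals cosine similarity for unit-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def is_top1_match(query, target, distractors):
    """True if the query is more similar to its target than to every distractor."""
    pair_similarity = dot(query, target)
    best_distractor = max(dot(query, d) for d in distractors)
    return pair_similarity > best_distractor


# Illustrative values only; real embeddings have emb_size dimensions.
query = l2_normalize([1.0, 0.1])
target = l2_normalize([0.9, 0.2])
distractors = [l2_normalize([0.0, 1.0]), l2_normalize([-1.0, 0.3])]
matched = is_top1_match(query, target, distractors)
```

The reported accuracy is then the fraction of pairs for which such a comparison succeeds, computed once in each matching direction and averaged.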
### Convert model
If you want to infer the network on Ascend 310, you should convert the model to AIR:
```bash
cd ./scripts
sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
```
for example:
```bash
cd ./scripts
sh run_export.sh 16 0 ./0-1_1.ckpt
```
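The exported file is named after the checkpoint and the batch size. The helper below reproduces the string rule used in `export.py` so the output name can be predicted in advance; `air_file_name` itself is a hypothetical wrapper around that one expression.

```python
def air_file_name(ckpt_path, batch_size):
    """Apply the export.py naming rule: '<name>.ckpt' -> '<name>_<batch_size>b.air'."""
    return ckpt_path.replace('.ckpt', '_' + str(batch_size) + 'b.air')


example_path = air_file_name('./0-1_1.ckpt', 16)
```

Note that `run_export.sh` resolves the checkpoint argument to an absolute path before calling `export.py`, so the name is derived from that absolute path.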
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | Face Recognition |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | V1 |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
| Uploaded Date              | 09/30/2020 (month/day/year)                                 |
| MindSpore Version | 1.0.0 |
| Dataset | 4.7 million images |
| Training Parameters | epoch=100, batch_size=192, momentum=0.9 |
| Optimizer | Momentum |
| Loss Function | Cross Entropy |
| outputs | probability |
| Speed | 1pc: 300~400 ms/step; 8pcs: 40~50 ms/step |
| Total time | 1pc: NA hours; 8pcs: 10 hours |
| Checkpoint for Fine tuning | 584M (.ckpt file) |
### Evaluation Performance
| Parameters          | Face Recognition            |
| ------------------- | --------------------------- |
| Model Version | V1 |
| Resource | Ascend 910 |
| Uploaded Date | 09/30/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | 1.1 million images |
| batch_size | 512 |
| outputs | ACC |
| ACC | 0.9 |
| Model for inference | 584M (.ckpt file) |
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,333 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition eval."""
import os
import time
import math
from pprint import pformat
import numpy as np
import cv2
import mindspore.dataset.transforms.py_transforms as transforms
import mindspore.dataset.vision.py_transforms as vision
import mindspore.dataset as de
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_inference
from src.backbone.resnet import get_backbone
from src.my_logging import get_logger
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
class TxtDataset():
    '''Dataset that reads image paths and labels from text list files.'''
    def __init__(self, root_all, filenames):
        super(TxtDataset, self).__init__()
        self.imgs = []
        self.labels = []
        for root, filename in zip(root_all, filenames):
            # Each line starts with a relative image path; the full line is kept as the label key.
            with open(filename, "r") as fin:
                for line in fin:
                    self.imgs.append(os.path.join(root, line.strip().split(" ")[0]))
                    self.labels.append(line.strip())

    def __getitem__(self, index):
        try:
            img = cv2.cvtColor(cv2.imread(self.imgs[index]), cv2.COLOR_BGR2RGB)
        except Exception:
            # Print the offending path before re-raising to make bad files easy to find.
            print(self.imgs[index])
            raise
        return img, index

    def __len__(self):
        return len(self.imgs)

    def get_all_labels(self):
        return self.labels
class DistributedSampler():
    '''DistributedSampler'''
    def __init__(self, dataset):
        self.dataset = dataset
        self.num_replicas = 1
        self.rank = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        indices = indices[self.rank::self.num_replicas]
        return iter(indices)

    def __len__(self):
        return self.num_samples
def get_dataloader(img_predix_all, img_list_all, batch_size, img_transforms):
    dataset = TxtDataset(img_predix_all, img_list_all)
    sampler = DistributedSampler(dataset)
    dataset_column_names = ["image", "index"]
    ds = de.GeneratorDataset(dataset, column_names=dataset_column_names, sampler=sampler)
    ds = ds.map(input_columns=["image"], operations=img_transforms)
    ds = ds.batch(batch_size, num_parallel_workers=8, drop_remainder=False)
    ds = ds.repeat(1)
    return ds, len(dataset), dataset.get_all_labels()
def generate_test_pair(jk_list, zj_list):
    '''generate_test_pair'''
    file_paths = [jk_list, zj_list]
    jk_dict = {}
    zj_dict = {}
    jk_zj_dict_list = [jk_dict, zj_dict]
    for path, x_dict in zip(file_paths, jk_zj_dict_list):
        with open(path, 'r') as fr:
            for line in fr:
                label = line.strip().split(' ')[1]
                tmp = x_dict.get(label, [])
                tmp.append(line.strip())
                x_dict[label] = tmp
    zj2jk_pairs = []
    for key in jk_dict:
        jk_file_list = jk_dict[key]
        zj_file_list = zj_dict[key]
        for zj_file in zj_file_list:
            zj2jk_pairs.append([zj_file, jk_file_list])
    return zj2jk_pairs
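def check_minmax(data, min_value=0.99, max_value=1.01):
    '''Raise ValueError if any norm is NaN or outside [min_value, max_value].'''
    # No module-level `args` exists in this file, so log through the shared config object.
    logger = config_inference.logger
    min_data = data.min()
    max_data = data.max()
    if np.isnan(min_data) or np.isnan(max_data):
        logger.info('ERROR, nan happened, please check if used fp16 or other error')
        # Callers catch ValueError, so raise that instead of a bare Exception.
        raise ValueError('embedding norm is NaN')
    if min_data < min_value or max_data > max_value:
        logger.info('ERROR, min or max is out of range, range=[{}, {}], minmax=[{}, {}]'.format(
            min_value, max_value, min_data, max_data))
        raise ValueError('embedding norm is out of range')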
def get_model(args):
    '''get_model'''
    net = get_backbone(args)
    if args.fp16:
        net.add_flags_recursive(fp16=True)
    if args.weight.endswith('.ckpt'):
        param_dict = load_checkpoint(args.weight)
        param_dict_new = {}
        for key, value in param_dict.items():
            # Skip optimizer state and strip the 'network.' prefix added by the training wrapper.
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = value
            else:
                param_dict_new[key] = value
        load_param_into_net(net, param_dict_new)
        args.logger.info('INFO, ------------- load model success--------------')
    else:
        args.logger.info('ERROR, not support file:{}, please check weight in config.py'.format(args.weight))
        return 0
    net.set_train(False)
    return net
def topk(matrix, k, axis=1):
    '''topk'''
    if axis == 0:
        row_index = np.arange(matrix.shape[1 - axis])
        topk_index = np.argpartition(-matrix, k, axis=axis)[0:k, :]
        topk_data = matrix[topk_index, row_index]
        topk_index_sort = np.argsort(-topk_data, axis=axis)
        topk_data_sort = topk_data[topk_index_sort, row_index]
        topk_index_sort = topk_index[0:k, :][topk_index_sort, row_index]
    else:
        column_index = np.arange(matrix.shape[1 - axis])[:, None]
        topk_index = np.argpartition(-matrix, k, axis=axis)[:, 0:k]
        topk_data = matrix[column_index, topk_index]
        topk_index_sort = np.argsort(-topk_data, axis=axis)
        topk_data_sort = topk_data[column_index, topk_index_sort]
        topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
    return topk_data_sort, topk_index_sort
def cal_topk(idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot):
    '''cal_topk'''
    # No module-level `args` exists in this file, so log through the shared config object.
    logger = config_inference.logger
    logger.info('start idx:{} subprocess...'.format(idx))
    correct = np.array([0] * 2)
    tot = np.array([0])
    zj, jk_all = zj2jk_pairs[idx]
    zj_embedding = test_embedding_tot[zj]
    jk_all_embedding = np.concatenate([np.expand_dims(test_embedding_tot[jk], axis=0) for jk in jk_all], axis=0)
    logger.info('INFO, calculate top1 acc index:{}, zj_embedding shape:{}'.format(idx, zj_embedding.shape))
    logger.info('INFO, calculate top1 acc index:{}, jk_all_embedding shape:{}'.format(idx, jk_all_embedding.shape))
    test_time = time.time()
    # Similarity of each test embedding against every distractor embedding.
    mm = np.matmul(np.expand_dims(zj_embedding, axis=0), dis_embedding_tot)
    top100_jk2zj = np.squeeze(topk(mm, 100)[0], axis=0)
    top100_zj2jk = topk(np.matmul(jk_all_embedding, dis_embedding_tot), 100)[0]
    test_time_used = time.time() - test_time
    logger.info('INFO, calculate top1 acc index:{}, np.matmul().top(100) time used:{:.2f}s'.format(
        idx, test_time_used))
    tot[0] = len(jk_all)
    for i, jk in enumerate(jk_all):
        jk_embedding = test_embedding_tot[jk]
        similarity = np.dot(jk_embedding, zj_embedding)
        # A pair is correct when it beats the best distractor similarity.
        if similarity > top100_jk2zj[0]:
            correct[0] += 1
        if similarity > top100_zj2jk[i, 0]:
            correct[1] += 1
    return correct, tot
def l2normalize(features):
    epsilon = 1e-12
    l2norm = np.sum(np.abs(features) ** 2, axis=1, keepdims=True) ** (1./2)
    l2norm[np.logical_and(l2norm < 0, l2norm > -epsilon)] = -epsilon
    l2norm[np.logical_and(l2norm >= 0, l2norm < epsilon)] = epsilon
    return features/l2norm
def main(args):
    if not os.path.exists(args.test_dir):
        args.logger.info('ERROR, test_dir does not exist, please set test_dir in config.py.')
        return 0
    all_start_time = time.time()
    net = get_model(args)
    compile_time_used = time.time() - all_start_time
    args.logger.info('INFO, graph compile finished, time used:{:.2f}s, start calculate img embedding'.
                     format(compile_time_used))
    img_transforms = transforms.Compose([vision.ToTensor(), vision.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # for test images
    args.logger.info('INFO, start step1, calculate test img embedding, weight file = {}'.format(args.weight))
    step1_start_time = time.time()
    ds, img_tot, all_labels = get_dataloader(args.test_img_predix, args.test_img_list,
                                             args.test_batch_size, img_transforms)
    args.logger.info('INFO, dataset total test img:{}, total test batch:{}'.format(img_tot, ds.get_dataset_size()))
    test_embedding_tot_np = np.zeros((img_tot, args.emb_size))
    test_img_labels = all_labels
    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)
    for i, data in enumerate(data_loader):
        img, idxs = data["image"], data["index"]
        out = net(Tensor(img)).asnumpy().astype(np.float32)
        embeddings = l2normalize(out)
        for batch in range(embeddings.shape[0]):
            test_embedding_tot_np[idxs[batch]] = embeddings[batch]
    try:
        check_minmax(np.linalg.norm(test_embedding_tot_np, ord=2, axis=1))
    except ValueError:
        return 0
    test_embedding_tot = {}
    for idx, label in enumerate(test_img_labels):
        test_embedding_tot[label] = test_embedding_tot_np[idx]
    step2_start_time = time.time()
    step1_time_used = step2_start_time - step1_start_time
    args.logger.info('INFO, step1 finished, time used:{:.2f}s, start step2, calculate dis img embedding'.
                     format(step1_time_used))

    # for dis images
    ds_dis, img_tot, _ = get_dataloader(args.dis_img_predix, args.dis_img_list, args.dis_batch_size, img_transforms)
    dis_embedding_tot_np = np.zeros((img_tot, args.emb_size))
    total_batch = ds_dis.get_dataset_size()
    args.logger.info('INFO, dataloader total dis img:{}, total dis batch:{}'.format(img_tot, total_batch))
    start_time = time.time()
    img_per_gpu = int(math.ceil(1.0 * img_tot / args.world_size))
    delta_num = img_per_gpu * args.world_size - img_tot
    start_idx = img_per_gpu * args.local_rank - max(0, args.local_rank - (args.world_size - delta_num))
    data_loader = ds_dis.create_dict_iterator(output_numpy=True, num_epochs=1)
    for idx, data in enumerate(data_loader):
        img = data["image"]
        out = net(Tensor(img)).asnumpy().astype(np.float32)
        embeddings = l2normalize(out)
        dis_embedding_tot_np[start_idx:(start_idx + embeddings.shape[0])] = embeddings
        start_idx += embeddings.shape[0]
        if args.local_rank % 8 == 0 and idx % args.log_interval == 0 and idx > 0:
            speed = 1.0 * (args.dis_batch_size * args.log_interval * args.world_size) / (time.time() - start_time)
            time_left = (total_batch - idx - 1) * args.dis_batch_size * args.world_size / speed
            args.logger.info('INFO, processed [{}/{}], speed: {:.2f} img/s, left:{:.2f}s'.
                             format(idx, total_batch, speed, time_left))
            start_time = time.time()
    try:
        check_minmax(np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1))
    except ValueError:
        return 0
    step3_start_time = time.time()
    step2_time_used = step3_start_time - step2_start_time
    args.logger.info('INFO, step2 finished, time used:{:.2f}s, start step3, calculate top1 acc'.format(step2_time_used))

    # clear npu memory
    img = None
    net = None
    dis_embedding_tot_np = np.transpose(dis_embedding_tot_np, (1, 0))
    args.logger.info('INFO, calculate top1 acc dis_embedding_tot_np shape:{}'.format(dis_embedding_tot_np.shape))

    # find best match
    assert len(args.test_img_list) % 2 == 0
    task_num = int(len(args.test_img_list) / 2)
    correct = np.array([0] * (2 * task_num))
    tot = np.array([0] * task_num)
    for i in range(int(len(args.test_img_list) / 2)):
        jk_list = args.test_img_list[2 * i]
        zj_list = args.test_img_list[2 * i + 1]
        zj2jk_pairs = sorted(generate_test_pair(jk_list, zj_list))
        sampler = DistributedSampler(zj2jk_pairs)
        args.logger.info('INFO, calculate top1 acc sampler len:{}'.format(len(sampler)))
        for idx in sampler:
            out1, out2 = cal_topk(idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np)
            correct[2 * i] += out1[0]
            correct[2 * i + 1] += out1[1]
            tot[i] += out2[0]
    args.logger.info('local_rank={},tot={},correct={}'.format(args.local_rank, tot, correct))
    step3_time_used = time.time() - step3_start_time
    args.logger.info('INFO, step3 finished, time used:{:.2f}s'.format(step3_time_used))
    args.logger.info('weight:{}'.format(args.weight))
    for i in range(int(len(args.test_img_list) / 2)):
        test_set_name = 'test_dataset'
        zj2jk_acc = correct[2 * i] / tot[i]
        jk2zj_acc = correct[2 * i + 1] / tot[i]
        avg_acc = (zj2jk_acc + jk2zj_acc) / 2
        results = '[{}]: zj2jk={:.4f}, jk2zj={:.4f}, avg={:.4f}'.format(test_set_name, zj2jk_acc, jk2zj_acc, avg_acc)
        args.logger.info(results)
    args.logger.info('INFO, tot time used: {:.2f}s'.format(time.time() - all_start_time))
    return 0
if __name__ == '__main__':
    arg = config_inference
    arg.test_img_predix = [arg.test_dir, arg.test_dir]
    arg.test_img_list = [os.path.join(arg.test_dir, 'lists/jk_list.txt'),
                         os.path.join(arg.test_dir, 'lists/zj_list.txt')]
    arg.dis_img_predix = [arg.test_dir,]
    arg.dis_img_list = [os.path.join(arg.test_dir, 'lists/dis_list.txt'),]
    log_path = os.path.join(arg.ckpt_path, 'logs')
    arg.logger = get_logger(log_path, arg.local_rank)
    arg.logger.info('Config\n\n%s\n' % pformat(arg))
    main(arg)

@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ckpt to air."""
import os
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.backbone.resnet import get_backbone
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
def main(args):
    network = get_backbone(args)
    ckpt_path = args.pretrained
    if os.path.isfile(ckpt_path):
        param_dict = load_checkpoint(ckpt_path)
        param_dict_new = {}
        for key, values in param_dict.items():
            # Skip optimizer state and strip the 'network.' prefix added by the training wrapper.
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        print('-----------------------load model success-----------------------')
    else:
        print('-----------------------load model failed -----------------------')
    network.add_flags_recursive(fp16=True)
    network.set_train(False)
    # Dummy input that fixes the exported graph's input shape: (batch, 3, 112, 112).
    input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 112, 112)).astype(np.float32)
    tensor_input_data = Tensor(input_data)
    file_path = ckpt_path.replace('.ckpt', '_' + str(args.batch_size) + 'b.air')
    export(network, tensor_input_data, file_name=file_path, file_format='AIR')
    print('-----------------------export model success, save file:{}-----------------------'.format(file_path))


def parse_args():
    '''parse_args'''
    parser = argparse.ArgumentParser(description='Convert ckpt to air')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
    parser.add_argument('--pre_bn', type=int, default=0, help='1: bn-conv-bn-conv-bn, 0: conv-bn-conv-bn')
    parser.add_argument('--inference', type=int, default=1, help='use inference backbone')
    parser.add_argument('--use_se', type=int, default=0, help='use se block or not')
    parser.add_argument('--emb_size', type=int, default=256, help='embedding size of the network')
    parser.add_argument('--act_type', type=str, default='relu', help='activation layer type')
    parser.add_argument('--backbone', type=str, default='r100', help='backbone network')
    parser.add_argument('--head', type=str, default='0', help='head type, default is 0')
    parser.add_argument('--use_drop', type=int, default=0, help='whether use dropout in network')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    arg = parse_args()
    main(arg)

@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_distribute_train_base.sh [RANK_TABLE_FILE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
# Distribute config
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
EXECUTE_PATH=$(pwd)
echo "*******************EXECUTE_PATH= $EXECUTE_PATH"
if [ -d "${EXECUTE_PATH}/log_parallel_graph" ]; then
echo "[INFO] Delete old data_parallel log files"
rm -rf ${EXECUTE_PATH}/log_parallel_graph
fi
mkdir ${EXECUTE_PATH}/log_parallel_graph
for((i=0;i<=7;i++));
do
    rm -rf ${EXECUTE_PATH}/data_parallel_log_$i
    mkdir -p ${EXECUTE_PATH}/data_parallel_log_$i
    cd ${EXECUTE_PATH}/data_parallel_log_$i
    export RANK_ID=$i
    export DEVICE_ID=$i
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > ${EXECUTE_PATH}/log_parallel_graph/face_recognition_$i.log
    # Append (&>>) so the environment dump written above is not truncated.
    python ${EXECUTE_PATH}/../train.py \
        --train_stage=base \
        --is_distributed=1 &>> ${EXECUTE_PATH}/log_parallel_graph/face_recognition_$i.log &
done
echo "[INFO] Start training..."

@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_distribute_train_beta.sh [RANK_TABLE_FILE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
# Distribute config
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
EXECUTE_PATH=$(pwd)
echo "*******************EXECUTE_PATH= $EXECUTE_PATH"
if [ -d "${EXECUTE_PATH}/log_parallel_graph" ]; then
echo "[INFO] Delete old data_parallel log files"
rm -rf ${EXECUTE_PATH}/log_parallel_graph
fi
mkdir ${EXECUTE_PATH}/log_parallel_graph
for((i=0;i<=7;i++));
do
    rm -rf ${EXECUTE_PATH}/data_parallel_log_$i
    mkdir -p ${EXECUTE_PATH}/data_parallel_log_$i
    cd ${EXECUTE_PATH}/data_parallel_log_$i
    export RANK_ID=$i
    export DEVICE_ID=$i
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > ${EXECUTE_PATH}/log_parallel_graph/face_recognition_$i.log
    # Append (&>>) so the environment dump written above is not truncated.
    python ${EXECUTE_PATH}/../train.py \
        --train_stage=beta \
        --is_distributed=1 &>> ${EXECUTE_PATH}/log_parallel_graph/face_recognition_$i.log &
done
echo "[INFO] Start training..."

@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_eval.sh [USE_DEVICE_ID]"
exit 1
fi
dirname_path=$(dirname "$(pwd)")
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
export RANK_ID=0
USE_DEVICE_ID=$1
echo 'start device '$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
EXECUTE_PATH=$(pwd)
echo "*******************EXECUTE_PATH= $EXECUTE_PATH"
if [ -d "${EXECUTE_PATH}/log_inference" ]; then
echo "[INFO] Delete old log_inference log files"
rm -rf ${EXECUTE_PATH}/log_inference
fi
mkdir ${EXECUTE_PATH}/log_inference
cd ${EXECUTE_PATH}/log_inference
env > ${EXECUTE_PATH}/log_inference/face_recognition.log
# Append (&>>) so the environment dump written above is not truncated.
python ${EXECUTE_PATH}/../eval.py &>> ${EXECUTE_PATH}/log_inference/face_recognition.log &
echo "[INFO] Start inference..."

@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname "$(pwd)")
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
SCRIPT_NAME='export.py'
ulimit -c unlimited
BATCH_SIZE=$1
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
echo $BATCH_SIZE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
echo 'start converting'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--pretrained=$PRETRAINED_BACKBONE \
--batch_size=$BATCH_SIZE > convert.log 2>&1 &
echo 'running'

@ -0,0 +1,52 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_standalone_train_base.sh [USE_DEVICE_ID]"
exit 1
fi
dirname_path=$(dirname "$(pwd)")
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
export RANK_ID=0
USE_DEVICE_ID=$1
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
EXECUTE_PATH=$(pwd)
echo "*******************EXECUTE_PATH= $EXECUTE_PATH"
if [ -d "${EXECUTE_PATH}/log_standalone_graph" ]; then
echo "[INFO] Delete old data_standalone log files"
rm -rf ${EXECUTE_PATH}/log_standalone_graph
fi
mkdir ${EXECUTE_PATH}/log_standalone_graph
rm -rf ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID
mkdir -p ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID
cd ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID
echo "start training for rank $RANK_ID, device $USE_DEVICE_ID"
env > ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log
python ${EXECUTE_PATH}/../train.py \
--train_stage=base \
--is_distributed=0 >> ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log 2>&1 &
echo "[INFO] Start training..."

View File

@ -0,0 +1,52 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_standalone_train_beta.sh [USE_DEVICE_ID]"
exit 1
fi
dirname_path=$(dirname "$(pwd)")
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
export RANK_ID=0
USE_DEVICE_ID=$1
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
EXECUTE_PATH=$(pwd)
echo "*******************EXECUTE_PATH= $EXECUTE_PATH"
if [ -d "${EXECUTE_PATH}/log_standalone_graph" ]; then
echo "[INFO] Delete old data_standalone log files"
rm -rf ${EXECUTE_PATH}/log_standalone_graph
fi
mkdir ${EXECUTE_PATH}/log_standalone_graph
rm -rf ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID
mkdir -p ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID
cd ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID
echo "start training for rank $RANK_ID, device $USE_DEVICE_ID"
env > ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log
python ${EXECUTE_PATH}/../train.py \
--train_stage=beta \
--is_distributed=0 >> ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log 2>&1 &
echo "[INFO] Start training..."

View File

@ -0,0 +1,63 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition head."""
import mindspore.nn as nn
from mindspore.nn import Cell
from mindspore.ops import operations as P
from src.custom_net import Cut, bn_with_initialize, fc_with_initialize
__all__ = ['get_head']
class Head0(Cell):
'''Head0'''
def __init__(self, emb_size, args=None):
super(Head0, self).__init__()
# args defaults to None, so guard every attribute access on it
if args is not None and args.pre_bn == 1:
self.bn1 = bn_with_initialize(512, use_inference=args.inference)
else:
self.bn1 = Cut()
if args is not None:
if args.use_drop == 1:
self.drop = nn.Dropout(keep_prob=0.4)
else:
self.drop = Cut()
else:
self.drop = nn.Dropout(keep_prob=0.4)
self.fc1 = fc_with_initialize(512 * 7 * 7, emb_size)
if args is not None and args.inference == 1:
self.bn2 = Cut()
else:
self.bn2 = nn.BatchNorm1d(emb_size, affine=False, momentum=0.9).add_flags_recursive(fp32=True)
self.reshape = P.Reshape()
self.shape = P.Shape()
def construct(self, x):
x = self.bn1(x)
x = self.drop(x)
b, _, _, _ = self.shape(x)
shp = (b, -1)
x = self.reshape(x, shp)
x = self.fc1(x)
x = self.bn2(x)
return x
def get_head(args):
emb_size = args.emb_size
return Head0(emb_size, args)
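# A minimal sketch of where the `512 * 7 * 7` input size of `fc1` may come from.
# It assumes a 112x112 input image (not stated in this file) and the stride
# pattern of the FaceResNet backbone: a stride-1 3x3 stem followed by four
# stages whose first block downsamples with stride 2. `conv_out_size` and
# `head_input_features` are hypothetical helpers, not part of this commit.

```python
def conv_out_size(size, kernel, stride, padding):
    """Spatial output size of a convolution with explicit padding."""
    return (size + 2 * padding - kernel) // stride + 1


def head_input_features(image_size=112, channels=512, num_stages=4):
    """Flattened feature count fed to the head's fully connected layer."""
    # Stem: conv3x3 with stride 1 and padding 1 keeps the spatial size.
    size = conv_out_size(image_size, kernel=3, stride=1, padding=1)
    for _ in range(num_stages):
        # Only the stride-2 conv3x3 in the first block of a stage shrinks the map.
        size = conv_out_size(size, kernel=3, stride=2, padding=1)
    return channels * size * size


print(head_input_features())  # 25088, i.e. 512 * 7 * 7
```

# With a different input resolution the flattened size changes, so the
# hard-coded `512 * 7 * 7` would have to change with it.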

View File

@ -0,0 +1,264 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition backbone."""
import math
import numpy as np
import mindspore.nn as nn
from mindspore.nn import Cell
from mindspore.ops.operations import TensorAdd
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore.common import dtype as mstype
from src.backbone.head import get_head
from src import me_init
from src.custom_net import Cut, bn_with_initialize, fc_with_initialize, conv1x1, conv3x3
__all__ = ['get_backbone']
class Sigmoid(Cell):
def __init__(self):
super(Sigmoid, self).__init__()
self.sigmoid = P.Sigmoid()
def construct(self, x):
out = self.sigmoid(x)
return out
class SEBlock(Cell):
'''SEBlock'''
def __init__(self, channel, reduction=16, act_type='relu'):
super(SEBlock, self).__init__()
self.fc1 = fc_with_initialize(channel, channel // reduction)
self.act_layer = nn.PReLU(
channel // reduction) if act_type == 'prelu' else P.ReLU()
self.fc2 = fc_with_initialize(channel // reduction, channel)
self.sigmoid = Sigmoid().add_flags_recursive(fp32=True)
self.reshape = P.Reshape()
self.shape = P.Shape()
self.reduce_mean = P.ReduceMean(True)
self.cast = P.Cast()
def construct(self, x):
'''construct'''
b, c, _, _ = self.shape(x)
y = self.reduce_mean(x, (2, 3))
y = self.reshape(y, (b, c))
y = self.fc1(y)
y = self.act_layer(y)
y = self.fc2(y)
y = self.sigmoid(y)
y = self.cast(y, mstype.float16)
y = self.reshape(y, (b, c, 1, 1))
return x * y
class IRBlock(Cell):
'''IRBlock'''
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=1, pre_bn=1, use_inference=0,
act_type='relu'):
super(IRBlock, self).__init__()
if pre_bn == 1:
self.bn1 = bn_with_initialize(inplanes, use_inference=use_inference)
else:
self.bn1 = Cut()
self.conv1 = conv3x3(inplanes, planes, stride=1)
self.bn2 = bn_with_initialize(planes, use_inference=use_inference)
self.act_layer = nn.PReLU(
planes) if act_type == 'prelu' else P.ReLU()
self.conv2 = conv3x3(planes, planes, stride=stride)
self.bn3 = bn_with_initialize(planes, use_inference=use_inference)
if downsample is None:
self.downsample = Cut()
else:
self.downsample = downsample
self.use_se = use_se
if use_se == 1:
self.se = SEBlock(planes, act_type=act_type)
self.add = TensorAdd()
self.cast = P.Cast()
def construct(self, x):
'''construct'''
identity = x
out = self.bn1(x)
out = self.conv1(out)
out = self.bn2(out)
out = self.act_layer(out)
out = self.conv2(out)
out = self.bn3(out)
if self.use_se == 1:
out = self.se(out)
identity = self.downsample(x)
identity = self.cast(identity, mstype.float16)
out = self.cast(out, mstype.float16)
out = self.add(out, identity)
return out
class DownSample(Cell):
'''DownSample'''
def __init__(self, inplanes, planes, expansion, stride, use_inference=0):
super(DownSample, self).__init__()
self.conv1 = conv1x1(inplanes, planes * expansion,
stride=stride, pad_mode="valid")
self.bn1 = bn_with_initialize(planes * expansion, use_inference=use_inference)
def construct(self, x):
out = self.conv1(x)
out = self.bn1(out)
return out
class MakeLayer(Cell):
'''MakeLayer'''
def __init__(self, block, inplanes, planes, blocks, args, stride=1):
super(MakeLayer, self).__init__()
self.inplanes = inplanes
self.downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
self.downsample = DownSample(
self.inplanes, planes, block.expansion, stride, use_inference=args.inference)
self.layers = []
self.layers.append(block(self.inplanes, planes, stride, self.downsample, use_se=args.use_se, pre_bn=args.pre_bn,
use_inference=args.inference, act_type=args.act_type))
self.inplanes = planes
for _ in range(1, blocks):
self.layers.append(block(self.inplanes, planes, use_se=args.use_se, pre_bn=args.pre_bn,
use_inference=args.inference, act_type=args.act_type))
self.layers = nn.CellList(self.layers)
def construct(self, x):
for block in self.layers:
x = block(x)
return x
class FaceResNet(Cell):
'''FaceResNet'''
def __init__(self, block, layers, args):
super(FaceResNet, self).__init__()
self.act_type = args.act_type
self.inplanes = 64
self.use_se = args.use_se
self.conv1 = conv3x3(3, 64, stride=1)
self.bn1 = bn_with_initialize(64, use_inference=args.inference)
self.prelu = nn.PReLU(64) if self.act_type == 'prelu' else P.ReLU()
self.layer1 = MakeLayer(
block, planes=64, inplanes=self.inplanes, blocks=layers[0], stride=2, args=args)
self.inplanes = 64
self.layer2 = MakeLayer(
block, planes=128, inplanes=self.inplanes, blocks=layers[1], stride=2, args=args)
self.inplanes = 128
self.layer3 = MakeLayer(
block, planes=256, inplanes=self.inplanes, blocks=layers[2], stride=2, args=args)
self.inplanes = 256
self.layer4 = MakeLayer(
block, planes=512, inplanes=self.inplanes, blocks=layers[3], stride=2, args=args)
self.head = get_head(args)
np.random.seed(1)
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(initializer(me_init.ReidKaimingUniform(a=math.sqrt(5), mode='fan_out'),
cell.weight.shape))
if cell.bias is not None:
cell.bias.set_data(initializer('zeros', cell.bias.shape))
elif isinstance(cell, nn.Dense):
cell.weight.set_data(initializer(me_init.ReidKaimingNormal(a=math.sqrt(5), mode='fan_out'),
cell.weight.shape))
if cell.bias is not None:
cell.bias.set_data(initializer('zeros', cell.bias.shape))
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
# default gamma is 1 and beta is 0; if you change them, take care with the IRBlock gamma value
pass
for _, cell in self.cells_and_names():
if isinstance(cell, IRBlock):
# this relies on the attribute name bn3: do not rename it unless the last BN in IRBlock is renamed too
cell.bn3.gamma.set_data(initializer('zeros', cell.bn3.gamma.shape))
def construct(self, x):
'''construct'''
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.head(x)
return x
def _faceresnet(block, layers, args):
model = FaceResNet(block, layers, args)
return model
def get_faceresnet(num_layers, args):
'''get_faceresnet'''
if num_layers == 9:
units = [1, 1, 1, 1]
elif num_layers == 18:
units = [2, 2, 2, 2]
elif num_layers == 34:
units = [3, 4, 6, 3]
elif num_layers == 49:
units = [3, 4, 14, 3]
elif num_layers == 50:
units = [3, 4, 14, 3]
elif num_layers == 74:
units = [3, 6, 24, 3]
elif num_layers == 90:
units = [3, 8, 30, 3]
elif num_layers == 100:
units = [3, 13, 30, 3]
elif num_layers == 101:
units = [3, 4, 23, 3]
elif num_layers == 152:
units = [3, 8, 36, 3]
elif num_layers == 200:
units = [3, 24, 36, 3]
elif num_layers == 269:
units = [3, 30, 48, 8]
else:
raise ValueError(
"no experiments done on num_layers {}, you can do it yourself".format(num_layers))
return _faceresnet(IRBlock, units, args)
def get_backbone_faceres(args):
backbone_type = args.backbone
layer_num = int(backbone_type[1:])
return get_faceresnet(layer_num, args)
def get_backbone(args):
return get_backbone_faceres(args)
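# A small sketch of how a backbone name such as 'r100' maps to per-stage block
# counts, mirroring get_backbone_faceres and get_faceresnet above. The table
# holds only a subset of the supported depths, and `UNITS` and
# `units_for_backbone` are hypothetical names used for illustration.

```python
UNITS = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 14, 3],
    100: [3, 13, 30, 3],
}


def units_for_backbone(backbone):
    """Return the number of blocks per stage for a name like 'r100'."""
    # The first character is a type prefix; the rest is the nominal depth.
    num_layers = int(backbone[1:])
    if num_layers not in UNITS:
        raise ValueError("no block layout for num_layers {}".format(num_layers))
    return UNITS[num_layers]


print(units_for_backbone('r100'))  # [3, 13, 30, 3]
```

# A lookup table keeps the depth-to-layout mapping in one place, which can be
# easier to extend than a long if/elif chain.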

View File

@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition Callback."""
import time
from mindspore.train.callback import Callback
class ProgressMonitor(Callback):
'''ProgressMonitor'''
def __init__(self, reid_args):
super(ProgressMonitor, self).__init__()
self.epoch_start_time = 0
self.step_start_time = 0
self.globe_step_cnt = 0
self.local_step_cnt = 0
self.reid_args = reid_args
self._dataset_size = reid_args.steps_per_epoch
def begin(self, run_context):
self.run_context_ = run_context
if not self.reid_args.epoch_cnt:
self.reid_args.logger.info('start network train...')
def epoch_end(self, run_context):
cb_params = run_context.original_args()
if int(cb_params.cur_step_num / self._dataset_size) != self.reid_args.epoch_cnt:
self.reid_args.logger.info('epoch end, local passed')
self.reid_args.epoch_cnt += 1
def step_begin(self, run_context):
self.run_context_ = run_context
self.step_start_time = time.time()
def step_end(self, run_context):
cb_params = run_context.original_args()
time_used = time.time() - self.step_start_time
cur_lr = self.reid_args.lrs[cb_params.cur_step_num]
fps_mean = self.reid_args.per_batch_size * self.reid_args.log_interval * self.reid_args.world_size / time_used
self.reid_args.logger.info('epoch[{}], iter[{}], loss:{}, cur_lr:{:.6f}, mean_fps:{:.2f} imgs/sec'.format(
self.reid_args.epoch_cnt, cb_params.cur_step_num, cb_params.net_outputs, cur_lr, fps_mean))
self.step_start_time = time.time()
def end(self, run_context):
self.run_context_ = run_context
self.reid_args.logger.info('end network train...')
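# A sketch of the throughput figure logged in step_end. It assumes that
# `time_used` covers `log_interval` steps on each of `world_size` devices,
# which this file does not state explicitly. `mean_fps` is a hypothetical helper.

```python
def mean_fps(per_batch_size, log_interval, world_size, time_used):
    """Images processed per second across all devices."""
    images = per_batch_size * log_interval * world_size
    return images / time_used


# 192 images per step, 100 steps, 8 devices, measured over 60 seconds.
print('mean_fps:{:.2f} imgs/sec'.format(mean_fps(192, 100, 8, 60.0)))
```

# If the callback fires after every single step rather than every
# `log_interval` steps, the `log_interval` factor would overstate throughput.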

View File

@ -0,0 +1,148 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network config setting, will be used in train.py and eval.py."""
from easydict import EasyDict as edict
config_base = edict({
# dataset related
'data_dir': "your_dataset_path",
'num_classes': 1,
'per_batch_size': 192,
# network structure related
'backbone': 'r100',
'use_se': 1,
'emb_size': 512,
'act_type': 'relu',
'fp16': 1,
'pre_bn': 1,
'inference': 0,
'use_drop': 1,
'nc_16': 1,
# loss related
'margin_a': 1.0,
'margin_b': 0.2,
'margin_m': 0.3,
'margin_s': 64,
# optimizer related
'lr': 0.4,
'lr_scale': 1,
'lr_epochs': '8,14,18',
'weight_decay': 0.0002,
'momentum': 0.9,
'max_epoch': 20,
'pretrained': '',
'warmup_epochs': 2,
# distributed parameter
'is_distributed': 1,
'local_rank': 0,
'world_size': 1,
'model_parallel': 0,
# logging related
'log_interval': 100,
'ckpt_path': 'outputs',
'max_ckpts': -1,
'dynamic_init_loss_scale': 65536,
'ckpt_steps': 1000
})
config_beta = edict({
# dataset related
'data_dir': "your_dataset_path",
'num_classes': 1,
'per_batch_size': 192,
# network structure related
'backbone': 'r100',
'use_se': 0,
'emb_size': 256,
'act_type': 'relu',
'fp16': 1,
'pre_bn': 0,
'inference': 0,
'use_drop': 1,
'nc_16': 1,
# loss related
'margin_a': 1.0,
'margin_b': 0.2,
'margin_m': 0.3,
'margin_s': 64,
# optimizer related
'lr': 0.04,
'lr_scale': 1,
'lr_epochs': '8,14,18',
'weight_decay': 0.0002,
'momentum': 0.9,
'max_epoch': 20,
'pretrained': 'your_pretrained_model',
'warmup_epochs': 2,
# distributed parameter
'is_distributed': 1,
'local_rank': 0,
'world_size': 1,
'model_parallel': 0,
# logging related
'log_interval': 100,
'ckpt_path': 'outputs',
'max_ckpts': -1,
'dynamic_init_loss_scale': 65536,
'ckpt_steps': 1000
})
config_inference = edict({
# distributed parameter
'is_distributed': 0,
'local_rank': 0,
'world_size': 1,
# test weight
'weight': 'your_test_model',
'test_dir': 'your_dataset_path',
# model define
'backbone': 'r100',
'use_se': 0,
'emb_size': 256,
'act_type': 'relu',
'fp16': 1,
'pre_bn': 0,
'inference': 1,
'use_drop': 0,
# test and dis batch size
'test_batch_size': 128,
'dis_batch_size': 512,
# log
'log_interval': 100,
'ckpt_path': 'outputs/models',
# test and dis image list
'test_img_predix': '',
'test_img_list': '',
'dis_img_predix': '',
'dis_img_list': ''
})
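# The configs above set 'lr', 'lr_epochs', 'warmup_epochs' and 'max_epoch', but
# the schedule that consumes them is not shown in this part of the commit. The
# sketch below assumes a linear warmup followed by a x0.1 step decay at each
# milestone epoch. `step_lr_schedule` and its `gamma` argument are hypothetical.

```python
def step_lr_schedule(base_lr, lr_epochs, max_epoch, warmup_epochs,
                     steps_per_epoch, gamma=0.1):
    """Return one learning rate per training step."""
    milestones = [int(e) for e in lr_epochs.split(',')]
    warmup_steps = warmup_epochs * steps_per_epoch
    lrs = []
    for step in range(max_epoch * steps_per_epoch):
        if step < warmup_steps:
            # Linear ramp from base_lr / warmup_steps up to base_lr.
            lr = base_lr * (step + 1) / warmup_steps
        else:
            epoch = step // steps_per_epoch
            # Decay once for every milestone already reached.
            passed = sum(1 for m in milestones if epoch >= m)
            lr = base_lr * gamma ** passed
        lrs.append(lr)
    return lrs


lrs = step_lr_schedule(0.4, '8,14,18', max_epoch=20, warmup_epochs=2,
                       steps_per_epoch=10)
print(len(lrs))  # 200
```

# Under these assumptions, config_base ramps up to 0.4 over the first two
# epochs and drops by a factor of ten at epochs 8, 14 and 18.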

View File

@ -0,0 +1,254 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition dataset."""
import sys
import os
import math
import pickle
from collections import defaultdict
import numpy as np
from PIL import Image, ImageFile
from mindspore.communication.management import get_group_size, get_rank
ImageFile.LOAD_TRUNCATED_IMAGES = True
__all__ = ['DistributedCustomSampler', 'CustomDataset']
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
class DistributedCustomSampler:
'''DistributedCustomSampler'''
def __init__(self, dataset, num_replicas=None, rank=None, is_distributed=1, shuffle=True, k=2):
assert isinstance(dataset, CustomDataset), 'DistributedCustomSampler only supports CustomDataset'
if is_distributed:
if num_replicas is None:
num_replicas = get_group_size()
if rank is None:
rank = get_rank()
else:
if num_replicas is None:
num_replicas = 1
if rank is None:
rank = 0
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.ratio = 4.0
self.data_len = len(self.dataset.classes)
self.num_ids = int(math.ceil(self.data_len * 1.0 / self.num_replicas))
self.total_ids = self.num_ids * self.num_replicas
self.num_samples = math.ceil(len(self.dataset) * 1.0 / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.k = k
self.epoch_gen = 1
def _sample_(self, indices):
sampled = []
for indice in indices:
sampled_id = indice
sampled.extend(np.random.choice(self.dataset.id2range[sampled_id][:], self.k).tolist())
return sampled
def __iter__(self):
if self.shuffle:
# Note, the self.epoch parameter does not get updated in DE
self.epoch_gen = (self.epoch_gen + 1) & 0xffffffff
np.random.seed(self.epoch_gen)
indices = np.random.permutation(len(self.dataset.classes))
indices = indices.tolist()
else:
indices = list(range(len(self.dataset.classes)))
indices += indices[:(self.total_ids - len(indices))]
assert len(indices) == self.total_ids
indices = indices[self.rank*self.num_ids:(self.rank+1)*self.num_ids]
assert len(indices) == self.num_ids
sampled_idxs = self._sample_(indices)
return iter(sampled_idxs)
def __len__(self):
return self.num_ids * self.k
def set_epoch(self, epoch):
self.epoch = epoch
def merge_indices(self, list1, list2):
'''merge_indices'''
list_result = []
ct_1, ct_2 = 0, 0
for i in range(self.data_len):
if (i+1) % int(self.ratio+1) == 0:
list_result.append(list2[ct_2])
ct_2 += 1
else:
list_result.append(list1[ct_1])
ct_1 += 1
return list_result
def has_file_allowed_extension(filename, extensions):
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
extensions (tuple of strings): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def make_dataset(dir_1, class_to_idx, extensions=None, is_valid_file=None):
'''make_dataset'''
images = []
dir_1 = os.path.expanduser(dir_1)
if not (extensions is None) ^ (is_valid_file is None):
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def f(x):
return has_file_allowed_extension(x, extensions)
is_valid_file = f
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir_1, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = (path, class_to_idx[target])
images.append(item)
return images
class ImageFolderDataset:
'''ImageFolderDataset'''
def __init__(self, root, cache_path, is_distributed):
if not os.path.isfile(cache_path):
self.classes, self.classes_to_idx = self._find_classes(root)
self.samples = make_dataset(root, self.classes_to_idx, IMG_EXTENSIONS, None)
self.id2range = self._build_id2range()
cache = dict()
cache['classes'] = self.classes
cache['classes_to_idx'] = self.classes_to_idx
cache['samples'] = self.samples
cache['id2range'] = self.id2range
if is_distributed:
print("******* TODO: All workers will write cache... Need to only dump when rank == 0 ******")
if get_rank() == 0:
with open(cache_path, 'wb') as fw:
pickle.dump(cache, fw)
print('local dump cache:{}'.format(cache_path))
else:
with open(cache_path, 'wb') as fw:
pickle.dump(cache, fw)
print('local dump cache:{}'.format(cache_path))
else:
print('loading cache from %s'%cache_path)
with open(cache_path, 'rb') as fr:
cache = pickle.load(fr)
self.classes, self.classes_to_idx, self.samples, self.id2range = cache['classes'], \
cache['classes_to_idx'], \
cache['samples'], cache['id2range']
self.all_image_idxs = range(len(self.samples))
self.classes = list(self.id2range.keys())
def _find_classes(self, dir_1):
"""
Finds the class folders in a dataset.
Args:
dir_1 (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir_1), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
if sys.version_info >= (3, 5):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir_1) if d.is_dir()]
else:
classes = [d for d in os.listdir(dir_1) if os.path.isdir(os.path.join(dir_1, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def _build_id2range(self):
'''_build_id2range'''
id2range = defaultdict(list)
ret_range = defaultdict(list)
for idx, sample in enumerate(self.samples):
label = sample[1]
id2range[label].append((sample, idx))
for key in id2range:
id2range[key].sort(key=lambda x: int(os.path.basename(x[0][0]).split('.')[0]))
for item in id2range[key]:
ret_range[key].append(item[1])
return ret_range
def __getitem__(self, index):
return self.samples[index]
def __len__(self):
return len(self.samples)
def pil_loader(path):
"""
Loads an image from disk and converts it to RGB.
Args:
path: path to the image
Returns:
PIL.Image.Image: the loaded image in RGB mode
"""
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
class CustomDataset:
'''CustomDataset'''
def __init__(self, root, cache_path, is_distributed=1, transform=None, target_transform=None,
loader=pil_loader):
self.dataset = ImageFolderDataset(root, cache_path, is_distributed)
print('CustomDataset len(dataset):{}'.format(len(self.dataset)))
self.loader = loader
self.transform = transform
self.target_transform = target_transform
self.classes = self.dataset.classes
self.id2range = self.dataset.id2range
def __getitem__(self, index):
path, target = self.dataset[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.dataset)
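# A standard-library sketch of the sampling scheme in DistributedCustomSampler:
# identities are optionally shuffled, padded by wrapping so every replica gets
# the same count, sharded by rank, and then `k` image indices are drawn with
# replacement per identity. `shard_and_sample` is a hypothetical helper, and it
# uses `random` where the original uses NumPy.

```python
import random


def shard_and_sample(id2range, num_replicas, rank, k=2, seed=0, shuffle=True):
    """Return the image indices one replica would visit in an epoch."""
    rng = random.Random(seed)  # same seed on every rank gives the same order
    ids = sorted(id2range)
    if shuffle:
        rng.shuffle(ids)
    num_ids = -(-len(ids) // num_replicas)  # ceiling division
    total_ids = num_ids * num_replicas
    ids += ids[:total_ids - len(ids)]  # pad by wrapping around
    shard = ids[rank * num_ids:(rank + 1) * num_ids]
    sampled = []
    for identity in shard:
        # Draw with replacement, so identities with fewer than k images still work.
        sampled.extend(rng.choices(id2range[identity], k=k))
    return sampled


id2range = {0: [0, 1, 2], 1: [3, 4], 2: [5], 3: [6, 7], 4: [8]}
print(len(shard_and_sample(id2range, num_replicas=2, rank=0)))  # 6
```

# Sampling a fixed `k` per identity means the epoch length depends on the
# number of identities rather than the number of images.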

View File

@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom net layer"""
import mindspore.nn as nn
from mindspore.nn import Cell
from mindspore.ops import operations as P
class CustomMatMul(Cell):
def __init__(self, transpose_a=False, transpose_b=False):
super(CustomMatMul, self).__init__()
self.fc = P.MatMul(transpose_a=transpose_a, transpose_b=transpose_b)
def construct(self, x1, x2):
out = self.fc(x1, x2)
return out
class Cut(Cell):
def construct(self, x):
return x
def bn_with_initialize(out_channels, momentum=0.9, use_inference=0):
if use_inference == 1:
bn = nn.BatchNorm2d(out_channels, momentum=momentum, eps=1e-5)
else:
bn = nn.BatchNorm2d(out_channels, momentum=momentum, eps=1e-5).add_flags_recursive(fp32=True)
return bn
def fc_with_initialize(input_channels, out_channels):
return nn.Dense(input_channels, out_channels)
def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
pad_mode=pad_mode, group=groups, has_bias=False, dilation=dilation, padding=padding)
def conv1x1(in_channels, out_channels, pad_mode="pad", stride=1, padding=0):
"""1x1 convolution"""
return nn.Conv2d(in_channels, out_channels, pad_mode=pad_mode, kernel_size=1, stride=stride,
has_bias=False, padding=padding)

View File

@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition dataset."""
import os
import math
import numpy as np
import mindspore.dataset as de
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.transforms.py_transforms as F2
from src.custom_dataset import DistributedCustomSampler, CustomDataset
__all__ = ['get_de_dataset']
def get_de_dataset(args):
'''get_de_dataset'''
lbl_transforms = [F.ToType(np.int32)]
transform_label = F2.Compose(lbl_transforms)
drop_remainder = False
transforms = [F.ToPIL(),
F.RandomHorizontalFlip(),
F.ToTensor(),
F.Normalize(mean=[0.5], std=[0.5])]
transform = F2.Compose(transforms)
cache_path = os.path.join('cache', os.path.basename(args.data_dir), 'data_cache.pkl')
print(cache_path)
if not os.path.exists(os.path.dirname(cache_path)):
os.makedirs(os.path.dirname(cache_path))
dataset = CustomDataset(args.data_dir, cache_path, args.is_distributed)
args.logger.info("dataset len:{}".format(dataset.__len__()))
sampler = DistributedCustomSampler(dataset, num_replicas=args.world_size, rank=args.local_rank,
is_distributed=args.is_distributed)
de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
args.logger.info("after sampler de_dataset datasize :{}".format(de_dataset.get_dataset_size()))
de_dataset = de_dataset.map(input_columns="image", operations=transform)
de_dataset = de_dataset.map(input_columns="label", operations=transform_label)
de_dataset = de_dataset.project(columns=["image", "label"])
de_dataset = de_dataset.batch(args.per_batch_size, drop_remainder=drop_remainder)
num_iter_per_npu = math.ceil(len(dataset) * 1.0 / args.world_size / args.per_batch_size)
num_classes = len(dataset.classes)
return de_dataset, num_iter_per_npu, num_classes
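# Two small calculations from get_de_dataset, written as plain Python for
# illustration: the per-device step count and the effect of
# Normalize(mean=[0.5], std=[0.5]) on a pixel already scaled to [0, 1].
# `steps_per_device` and `normalize` are hypothetical helper names.

```python
import math


def steps_per_device(num_samples, world_size, per_batch_size):
    """Same formula as num_iter_per_npu above."""
    return math.ceil(num_samples * 1.0 / world_size / per_batch_size)


def normalize(pixel, mean=0.5, std=0.5):
    """Map a pixel in [0, 1] to [-1, 1]."""
    return (pixel - mean) / std


print(steps_per_device(1000000, world_size=8, per_batch_size=192))  # 652
print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
```

# The step count is derived from the number of images, while the sampler's
# length is derived from the number of identities, so the two can differ.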

View File

@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init network"""
import math
import numpy as np
import mindspore.nn as nn
from mindspore.common.initializer import initializer
from src.backbone.resnet import IRBlock
from src import metric_factory
from src import me_init
np.random.seed(1)
def init_net(args, network):
'''init_net'''
for name, cell in network.cells_and_names():
if isinstance(cell, nn.Conv2d):
find_flag = True
if cell.weight is not None:
cell.weight.set_data(initializer(me_init.ReidKaimingUniform(a=math.sqrt(5), mode='fan_out'),
cell.weight.shape))
if cell.bias is not None:
cell.bias.set_data(initializer('zeros', cell.bias.shape))
if find_flag:
find_info = 'PARAMETER FIND'
else:
find_info = 'PARAMETER UNFIND'
args.logger.info('---------------{}---------------'.format(find_info))
args.logger.info(f'{name} --> {cell.weight} {cell.bias}')
args.logger.info('---------------{}---------------'.format(find_info))
elif isinstance(cell, nn.Dense):
find_flag = True
if cell.weight is not None:
cell.weight.set_data(initializer(me_init.ReidKaimingNormal(a=math.sqrt(5), mode='fan_out'),
cell.weight.shape))
if cell.bias is not None:
cell.bias.set_data(initializer('zeros', cell.bias.shape))
if find_flag:
find_info = 'PARAMETER FIND'
else:
find_info = 'PARAMETER UNFIND'
args.logger.info('---------------{}---------------'.format(find_info))
args.logger.info(f'{name} --> {cell.weight} {cell.bias}')
args.logger.info('---------------{}---------------'.format(find_info))
elif isinstance(cell, nn.BatchNorm2d):
find_flag = True
# default gamma is 1 and beta is 0; if you change them, take care with the IRBlock gamma value
if find_flag:
find_info = 'PARAMETER FIND'
else:
find_info = 'PARAMETER UNFIND'
args.logger.info('---------------{}---------------'.format(find_info))
args.logger.info(f'{name} --> {cell.gamma} {cell.beta}')
args.logger.info('---------------{}---------------'.format(find_info))
elif isinstance(cell, nn.BatchNorm1d):
pass
elif isinstance(cell, metric_factory.CombineMarginFC):
find_flag = True
if cell.weight is not None:
cell.weight.set_data(initializer(me_init.ReidKaimingUniform(a=math.sqrt(5), mode='fan_out'),
cell.weight.shape))
if find_flag:
find_info = 'PARAMETER FIND'
else:
find_info = 'PARAMETER UNFIND'
args.logger.info('---------------{}---------------'.format(find_info))
args.logger.info(f'{name} --> {cell.weight}')
args.logger.info('---------------{}---------------'.format(find_info))
elif isinstance(cell, nn.PReLU):
find_flag = True
if find_flag:
find_info = 'PARAMETER FIND'
else:
find_info = 'PARAMETER UNFIND'
args.logger.info('---------------{}---------------'.format(find_info))
args.logger.info(f'{name} --> {cell.w}')
args.logger.info('---------------{}---------------'.format(find_info))
elif isinstance(cell, IRBlock):
find_flag = True
# This relies on the attribute name bn3: keep it in sync with the name of the last BatchNorm in IRBlock
cell.bn3.gamma.set_data(initializer('zeros', cell.bn3.gamma.shape))
if find_flag:
find_info = 'PARAMETER FIND'
else:
find_info = 'PARAMETER UNFIND'
args.logger.info('---------------{}---------------'.format(find_info))
args.logger.info(f'{name} --> {cell.bn3.gamma}')
args.logger.info('---------------{}---------------'.format(find_info))
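
init_net sets the gamma of each IRBlock's last batch norm (bn3) to zero. The sketch below illustrates the usual motivation for this, assuming a residual block of the form y = x + gamma * branch(x) + beta. The IRBlock internals are not shown in this section, so that form and the names `branch` and `residual_block` are hypothetical stand-ins, not the repository's code.

```python
def branch(x):
    # Hypothetical stand-in for the conv/BN stack inside a residual block.
    return [3.0 * v - 1.0 for v in x]


def residual_block(x, gamma, beta=0.0):
    # y = x + gamma * branch(x) + beta, applied element-wise.
    return [xi + gamma * bi + beta for xi, bi in zip(x, branch(x))]


x = [0.5, -1.0, 2.0]
identity_out = residual_block(x, gamma=0.0)  # gamma zeroed, as init_net does for bn3
default_out = residual_block(x, gamma=1.0)   # default BatchNorm gamma of 1
```

With gamma at zero the block starts out as an identity mapping, so the signal passes through unchanged until training moves gamma away from zero.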

@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition loss."""
from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
__all__ = ['get_loss']
eps = 1e-24
class CrossEntropy(_Loss):
'''CrossEntropy'''
def __init__(self, args):
super(CrossEntropy, self).__init__()
self.args_1 = args
self.exp = P.Exp()
self.sum = P.ReduceSum()
self.reshape = P.Reshape()
self.log = P.Log()
self.cast = P.Cast()
self.eps_const = Tensor(eps, dtype=mstype.float32)
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
def construct(self, logit, label):
'''construct'''
exp = self.exp(logit)
exp_sum = self.sum(exp, -1)
exp_sum = self.reshape(exp_sum, (F.shape(exp_sum)[0], 1))
softmax_result = exp / exp_sum
one_hot_label = self.onehot(
self.cast(label, mstype.int32), F.shape(logit)[1], self.on_value, self.off_value)
loss = self.sum((self.log(softmax_result + self.eps_const) *
self.cast(one_hot_label, mstype.float32) *
self.cast(F.scalar_to_array(-1), mstype.float32)), -1)
batch_size = F.shape(logit)[0]
batch_size_tensor = self.cast(
F.scalar_to_array(batch_size), mstype.float32)
loss = self.sum(loss, -1) / batch_size_tensor
return loss
class Criterions(Cell):
'''Criterions'''
def __init__(self, args):
super(Criterions, self).__init__()
self.criterion_ce = CrossEntropy(args)
self.total_loss = Tensor(0.0, dtype=mstype.float32)
def construct(self, margin_logit, label):
total_loss = self.total_loss
loss_ce = self.criterion_ce(margin_logit, label)
total_loss = total_loss + loss_ce
return total_loss
def get_loss(args):
return Criterions(args)
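
The CrossEntropy cell above computes a softmax, adds a small epsilon before the log, picks the target entry through a one-hot mask, and averages over the batch. A minimal pure-Python sketch of the same arithmetic follows; the function name `cross_entropy` and the list-of-lists input are assumptions for illustration only.

```python
import math

EPS = 1e-24  # same value as `eps` in the loss file


def cross_entropy(logits, labels, eps=EPS):
    """Batch mean of -log(softmax(logit)[label] + eps)."""
    total = 0.0
    for row, label in zip(logits, labels):
        exps = [math.exp(v) for v in row]
        prob = exps[label] / sum(exps)   # softmax probability of the target class
        total += -math.log(prob + eps)   # eps guards against log(0)
    return total / len(logits)


# Two samples with uniform logits over two classes: each target has probability 0.5.
loss = cross_entropy([[0.0, 0.0], [0.0, 0.0]], [0, 1])
```

Like the original construct, this sketch does not subtract the row maximum before exponentiating, so very large logits would overflow.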

@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition learning rate scheduler."""
from collections import Counter
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * (current_step + 1)
return learning_rate
def warmup_step_list(args, gamma=0.1):
'''warmup_step_list'''
base_lr = args.lr * args.lr_scale
warmup_init_lr = 0
total_steps = int(args.max_epoch * args.steps_per_epoch)
warmup_steps = int(args.warmup_epochs * args.steps_per_epoch)
milestones = args.lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * args.steps_per_epoch
milestones_steps.append(milestones_step)
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
lrs = []
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_learning_rate(
i, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma**milestones_steps_counter[i]
lrs.append(lr)
args.logger.info('lrs[:10]:{}, lrs[-10:]:{}, total_steps:{}, len(lrs):{}'.
format(lrs[:10], lrs[-10:], total_steps, len(lrs)))
return lrs
def list_to_gen(nlist):
for nlist_item in nlist:
yield nlist_item
def warmup_step(args, gamma=0.1):
lrs = warmup_step_list(args, gamma=gamma)
for lr in lrs:
yield lr
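
warmup_step_list builds one learning rate per step: a linear ramp from 0 to the base rate during warmup, then a multiplication by gamma at every milestone step. The sketch below reproduces that logic without the `args` object; the function name `warmup_step_schedule` and the small example values are assumptions chosen so the result is easy to check by hand.

```python
from collections import Counter


def warmup_step_schedule(base_lr, steps_per_epoch, max_epoch,
                         warmup_epochs, lr_epochs, gamma=0.1):
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    # Milestones are given in epochs and converted to step indices.
    milestones = Counter(epoch * steps_per_epoch for epoch in lr_epochs)
    lr = base_lr
    lrs = []
    for i in range(total_steps):
        if i < warmup_steps:
            # Linear warmup from 0; reaches base_lr on the last warmup step.
            lr = (base_lr / warmup_steps) * (i + 1)
        else:
            # Counter returns 0 for non-milestone steps, so gamma**0 == 1.
            lr = lr * gamma ** milestones[i]
        lrs.append(lr)
    return lrs


lrs = warmup_step_schedule(base_lr=0.4, steps_per_epoch=2, max_epoch=4,
                           warmup_epochs=1, lr_epochs=[2, 3])
```

Note that the milestone counter is only consulted after warmup, so a milestone that falls inside the warmup range has no effect.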

@ -0,0 +1,207 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init"""
import math
import numpy as np
from mindspore.common import initializer as init
from mindspore.common.initializer import _assignment
def calculate_gain(nonlinearity, param=None):
r"""Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function (`nn.functional` name)
param: optional parameter for the non-linear function
Examples:
>>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _calculate_correct_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
uniform distribution. The resulting tensor will have values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Also known as He initialization.
Args:
arr: an n-dimensional NumPy array; only its shape is used, and a new array of that shape is returned
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
"""
fan = _calculate_correct_fan(arr, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, arr.shape)
def kaiming_normal_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
normal distribution. The resulting tensor will have values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
Also known as He initialization.
Args:
arr: an n-dimensional NumPy array; only its shape is used, and a new array of that shape is returned
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
"""
fan = _calculate_correct_fan(arr, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
return np.random.normal(0, std, arr.shape)
def _calculate_fan_in_and_fan_out(arr):
'''_calculate_fan_in_and_fan_out'''
dimensions = len(arr.shape)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
num_input_fmaps = arr.shape[1]
num_output_fmaps = arr.shape[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = arr[0][0].size
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def xavier_uniform_(arr, gain=1.):
# type: (np.ndarray, float) -> np.ndarray
r"""Fills the input `Tensor` with values according to the method
described in `Understanding the difficulty of training deep feedforward
neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
distribution. The resulting tensor will have values sampled from
:math:`\mathcal{U}(-a, a)` where
.. math::
a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
Also known as Glorot initialization.
Args:
arr: an n-dimensional NumPy array; only its shape is used, and a new array of that shape is returned
gain: an optional scaling factor
"""
fan_in, fan_out = _calculate_fan_in_and_fan_out(arr)
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-a, a, arr.shape)
class ReidXavierUniform(init.Initializer):
def __init__(self, gain=1.):
super(ReidXavierUniform, self).__init__()
self.gain = gain
def _initialize(self, arr):
tmp = xavier_uniform_(arr, self.gain)
_assignment(arr, tmp)
class ReidKaimingUniform(init.Initializer):
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(ReidKaimingUniform, self).__init__()
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
def _initialize(self, arr):
tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
_assignment(arr, tmp)
class ReidKaimingNormal(init.Initializer):
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(ReidKaimingNormal, self).__init__()
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
def _initialize(self, arr):
tmp = kaiming_normal_(arr, self.a, self.mode, self.nonlinearity)
_assignment(arr, tmp)
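
A worked example of the fan and bound arithmetic used by kaiming_uniform_, written against a plain shape tuple instead of an array. The helper names `fan_in_and_fan_out` and `kaiming_uniform_bound` and the example weight shape are assumptions for illustration. The sketch covers only the leaky_relu gain, which is the default nonlinearity in the file above.

```python
import math


def fan_in_and_fan_out(shape):
    if len(shape) < 2:
        raise ValueError("fan in and fan out need at least 2 dimensions")
    receptive_field_size = 1
    for dim in shape[2:]:          # kernel height, width, ...
        receptive_field_size *= dim
    fan_in = shape[1] * receptive_field_size
    fan_out = shape[0] * receptive_field_size
    return fan_in, fan_out


def kaiming_uniform_bound(shape, a=0.0, mode='fan_in'):
    fan_in, fan_out = fan_in_and_fan_out(shape)
    fan = fan_in if mode == 'fan_in' else fan_out
    gain = math.sqrt(2.0 / (1.0 + a ** 2))  # leaky_relu gain with negative slope a
    std = gain / math.sqrt(fan)
    return math.sqrt(3.0) * std             # uniform bound with the same variance


conv_shape = (64, 3, 3, 3)  # (out_channels, in_channels, kernel_h, kernel_w)
fans = fan_in_and_fan_out(conv_shape)
bound = kaiming_uniform_bound(conv_shape, a=math.sqrt(5), mode='fan_out')
```

With a = sqrt(5), as passed by init_net, the gain is sqrt(1/3) and the bound simplifies to 1 / sqrt(fan). For this shape fan_out is 576, so weights would be drawn from roughly U(-1/24, 1/24).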

@ -0,0 +1,86 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metric"""
import numpy as np
from mindspore import Tensor, Parameter
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from mindspore.common import dtype as mstype
from src import me_init
np.random.seed(1)
__all__ = ['get_metric_fc', 'CombineMarginFC']
class CombineMarginFC(Cell):
'''CombineMarginFC'''
def __init__(self, args):
super(CombineMarginFC, self).__init__()
weight_shape = [args.num_classes, args.emb_size]
weight_init = initializer(me_init.ReidXavierUniform(), weight_shape)
self.weight = Parameter(weight_init, name='weight')
self.m = args.margin_m
self.s = args.margin_s
self.a = args.margin_a
self.b = args.margin_b
self.m_const = Tensor(self.m, dtype=mstype.float16)
self.a_const = Tensor(self.a, dtype=mstype.float16)
self.b_const = Tensor(self.b, dtype=mstype.float16)
self.s_const = Tensor(self.s, dtype=mstype.float16)
self.m_const_zero = Tensor(0, dtype=mstype.float16)
self.a_const_one = Tensor(1, dtype=mstype.float16)
self.normalize = P.L2Normalize(axis=1)
self.fc = P.MatMul(transpose_b=True)
self.onehot = P.OneHot()
self.transpose = P.Transpose()
self.acos = P.ACos()
self.cos = P.Cos()
self.cast = P.Cast()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
def construct(self, x, label):
'''construct'''
x = self.normalize(x)
w = self.normalize(self.weight)
cosine = self.fc(x, w)
cosine_shape = F.shape(cosine)
one_hot_float = self.onehot(
self.cast(label, mstype.int32), cosine_shape[1], self.on_value, self.off_value)
one_hot_float = self.cast(one_hot_float, mstype.float16)
if self.m == 0 and self.a == 1:
one_hot_float = one_hot_float * self.b_const
output = cosine - one_hot_float
output = output * self.s_const
else:
theta = self.acos(cosine)
theta = self.a_const * theta
theta = self.m_const + theta
body = self.cos(theta)
body = body - self.b_const
cos_mask = self.cast(F.scalar_to_array(
1.0), mstype.float16) - one_hot_float
output = body * one_hot_float + cosine * cos_mask
output = output * self.s_const
return output
def get_metric_fc(args):
return CombineMarginFC(args)
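
CombineMarginFC replaces the target-class cosine with cos(a * theta + m) - b and scales every logit by s. When m is 0 and a is 1 it takes a shortcut and only subtracts b. A scalar sketch of that per-entry rule follows; the function name `combine_margin_logit` and the margin values are illustrative assumptions, since the config file is not shown in this section.

```python
import math


def combine_margin_logit(cosine, is_target, m, a, b, s):
    """Scaled logit for one (sample, class) pair given their cosine similarity."""
    if not is_target:
        return s * cosine                  # non-target classes keep the plain cosine
    if m == 0 and a == 1:
        return s * (cosine - b)            # additive cosine margin shortcut
    theta = math.acos(cosine)
    return s * (math.cos(a * theta + m) - b)


# Angular margin only (m > 0), then cosine margin only (b > 0), same similarity.
other = combine_margin_logit(0.5, False, m=0.5, a=1.0, b=0.0, s=64.0)
arc_target = combine_margin_logit(0.5, True, m=0.5, a=1.0, b=0.0, s=64.0)
cos_target = combine_margin_logit(0.5, True, m=0.0, a=1.0, b=0.35, s=64.0)
```

In both settings the target logit ends up lower than it would be without a margin, which forces the network to separate classes by a larger angle to reach the same loss.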

@ -0,0 +1,86 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom logger."""
import sys
import os
import logging
from datetime import datetime
logger_name_1 = 'FaceRecognition'
class HFLogger(logging.Logger):
'''HFLogger'''
def __init__(self, logger_name, local_rank=0):
super(HFLogger, self).__init__(logger_name)
if local_rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, local_rank=0):
'''setup_logging_file'''
self.local_rank = local_rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(local_rank)
log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
fh.close()
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def get_logger(path, rank):
logger = HFLogger(logger_name_1, rank)
logger.setup_logging_file(path, rank)
return logger
class AverageMeter():
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', tb_writer=None):
self.name = name
self.fmt = fmt
self.reset()
self.tb_writer = tb_writer
self.cur_step = 1
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if self.tb_writer is not None:
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
self.cur_step += 1
def __str__(self):
return '{}'.format(self.avg)
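
AverageMeter keeps a running mean weighted by the count n passed to update. A trimmed sketch without the optional tb_writer hook is shown below; the class name `RunningAverage` is hypothetical and only mirrors the arithmetic of the class above.

```python
class RunningAverage:
    """Tracks the latest value and the count-weighted mean of all values."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val           # most recent value
        self.sum += val * n      # weight by the number of samples it represents
        self.count += n
        self.avg = self.sum / self.count


meter = RunningAverage()
meter.update(2.0, n=3)  # e.g. a mean loss of 2.0 over 3 samples
meter.update(4.0, n=1)  # then a loss of 4.0 over 1 sample
```

Here the mean is (2.0 * 3 + 4.0 * 1) / 4 = 2.5, not the unweighted 3.0.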

@ -0,0 +1,235 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition train."""
import os
import argparse
import mindspore
from mindspore.nn import Cell
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size, init, get_rank
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_base, config_beta
from src.my_logging import get_logger
from src.init_network import init_net
from src.dataset_factory import get_de_dataset
from src.backbone.resnet import get_backbone
from src.metric_factory import get_metric_fc
from src.loss_factory import get_loss
from src.lrsche_factory import warmup_step_list, list_to_gen
from src.callback_factory import ProgressMonitor
mindspore.common.seed.set_seed(1)
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
device_id=devid, reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
class DistributedHelper(Cell):
'''DistributedHelper'''
def __init__(self, backbone, margin_fc):
super(DistributedHelper, self).__init__()
self.backbone = backbone
self.margin_fc = margin_fc
if margin_fc is not None:
self.has_margin_fc = 1
else:
self.has_margin_fc = 0
def construct(self, x, label):
embeddings = self.backbone(x)
if self.has_margin_fc == 1:
return embeddings, self.margin_fc(embeddings, label)
return embeddings
class BuildTrainNetwork(Cell):
'''BuildTrainNetwork'''
def __init__(self, network, criterion, args_1):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
self.args = args_1
if int(args_1.model_parallel) == 0:
self.is_model_parallel = 0
else:
self.is_model_parallel = 1
def construct(self, input_data, label):
if self.is_model_parallel == 0:
_, output = self.network(input_data, label)
loss = self.criterion(output, label)
else:
_ = self.network(input_data, label)
loss = self.criterion(None, label)
return loss
def parse_args():
parser = argparse.ArgumentParser('MindSpore Face Recognition')
parser.add_argument('--train_stage', type=str, default='base', help='train stage, base or beta')
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
args_opt_1, _ = parser.parse_known_args()
return args_opt_1
if __name__ == "__main__":
args_opt = parse_args()
support_train_stage = ['base', 'beta']
if args_opt.train_stage.lower() not in support_train_stage:
# args (and args.logger) are not defined yet at this point, so report via print
print('supported train stages are:{}, while yours is:{}'.
format(support_train_stage, args_opt.train_stage))
raise ValueError('train stage not supported.')
args = config_base if args_opt.train_stage.lower() == 'base' else config_beta
args.is_distributed = args_opt.is_distributed
if args_opt.is_distributed:
init()
args.local_rank = get_rank()
args.world_size = get_group_size()
parallel_mode = ParallelMode.HYBRID_PARALLEL
else:
parallel_mode = ParallelMode.STAND_ALONE
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=args.world_size, gradients_mean=True)
if not os.path.exists(args.data_dir):
# args.logger is only created further down, so it cannot be used here
raise ValueError('ERROR, data_dir does not exist, please set data_dir in config.py')
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
log_path = os.path.join(args.ckpt_path, 'logs')
args.logger = get_logger(log_path, args.local_rank)
if args.local_rank % 8 == 0:
if not os.path.exists(args.ckpt_path):
os.makedirs(args.ckpt_path)
args.logger.info('args.world_size:{}'.format(args.world_size))
args.logger.info('args.local_rank:{}'.format(args.local_rank))
args.logger.info('args.lr:{}'.format(args.lr))
momentum = args.momentum
weight_decay = args.weight_decay
de_dataset, steps_per_epoch, num_classes = get_de_dataset(args)
args.logger.info('de_dataset:{}'.format(de_dataset.get_dataset_size()))
args.steps_per_epoch = steps_per_epoch
args.num_classes = num_classes
args.logger.info('loaded, nums: {}'.format(args.num_classes))
if args.nc_16 == 1:
if args.model_parallel == 0:
if args.num_classes % 16 == 0:
args.logger.info('data parallel already 16, nums: {}'.format(args.num_classes))
else:
args.num_classes = (args.num_classes // 16 + 1) * 16
else:
if args.num_classes % (args.world_size * 16) == 0:
args.logger.info('model parallel already 16, nums: {}'.format(args.num_classes))
else:
args.num_classes = (args.num_classes // (args.world_size * 16) + 1) * args.world_size * 16
args.logger.info('for D, loaded, class nums: {}'.format(args.num_classes))
args.logger.info('steps_per_epoch:{}'.format(args.steps_per_epoch))
args.logger.info('img_total_num:{}'.format(args.steps_per_epoch * args.per_batch_size))
args.logger.info('get_backbone----in----')
_backbone = get_backbone(args)
args.logger.info('get_backbone----out----')
args.logger.info('get_metric_fc----in----')
margin_fc_1 = get_metric_fc(args)
args.logger.info('get_metric_fc----out----')
args.logger.info('DistributedHelper----in----')
network_1 = DistributedHelper(_backbone, margin_fc_1)
args.logger.info('DistributedHelper----out----')
args.logger.info('network fp16----in----')
if args.fp16 == 1:
network_1.add_flags_recursive(fp16=True)
args.logger.info('network fp16----out----')
criterion_1 = get_loss(args)
if args.fp16 == 1 and args.model_parallel == 0:
criterion_1.add_flags_recursive(fp32=True)
if os.path.isfile(args.pretrained):
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {}
if args_opt.train_stage.lower() == 'base':
for key, value in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[key[8:]] = value
else:
for key, value in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
if 'layers.' in key and 'bn1' in key:
continue
elif 'se' in key:
continue
elif 'head' in key:
continue
elif 'margin_fc.weight' in key:
continue
else:
param_dict_new[key[8:]] = value
load_param_into_net(network_1, param_dict_new)
args.logger.info('load model {} success'.format(args.pretrained))
else:
init_net(args, network_1)
train_net = BuildTrainNetwork(network_1, criterion_1, args)
args.logger.info('args:{}'.format(args))
# warmup_step_list must be called after args.steps_per_epoch has been set
args.lrs = warmup_step_list(args, gamma=0.1)
lrs_gen = list_to_gen(args.lrs)
opt = Momentum(params=train_net.trainable_params(), learning_rate=lrs_gen, momentum=momentum,
weight_decay=weight_decay)
scale_manager = DynamicLossScaleManager(init_loss_scale=args.dynamic_init_loss_scale, scale_factor=2,
scale_window=2000)
model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=scale_manager)
save_checkpoint_steps = args.ckpt_steps
args.logger.info('save_checkpoint_steps:{}'.format(save_checkpoint_steps))
if args.max_ckpts == -1:
keep_checkpoint_max = int(args.steps_per_epoch * args.max_epoch / save_checkpoint_steps) + 5 # keep every checkpoint, plus 5 spare slots
else:
keep_checkpoint_max = args.max_ckpts
args.logger.info('keep_checkpoint_max:{}'.format(keep_checkpoint_max))
ckpt_config = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps, keep_checkpoint_max=keep_checkpoint_max)
max_epoch_train = args.max_epoch
args.logger.info('max_epoch_train:{}'.format(max_epoch_train))
ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.ckpt_path, prefix='{}'.format(args.local_rank))
args.epoch_cnt = 0
progress_cb = ProgressMonitor(args)
new_epoch_train = max_epoch_train * steps_per_epoch // args.log_interval
model.train(new_epoch_train, de_dataset, callbacks=[progress_cb, ckpt_cb], sink_size=args.log_interval)
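
Two pieces of arithmetic in the training script are easy to misread: rounding the class count up to a multiple of 16 (or of 16 * world_size under model parallelism), and converting epochs into the number of sink iterations passed to model.train. The sketch below isolates both; the function names `pad_num_classes` and `sink_epochs` and the example numbers are assumptions for illustration.

```python
def pad_num_classes(num_classes, world_size=1, model_parallel=False, multiple=16):
    # Under model parallelism the class dimension is split across devices,
    # so the padded count must be a multiple of 16 per device.
    unit = multiple * world_size if model_parallel else multiple
    if num_classes % unit == 0:
        return num_classes
    return (num_classes // unit + 1) * unit


def sink_epochs(max_epoch, steps_per_epoch, log_interval):
    # The script calls model.train with sink_size=log_interval, so it passes
    # total_steps // log_interval as the epoch argument.
    return max_epoch * steps_per_epoch // log_interval


data_parallel = pad_num_classes(1000)
aligned = pad_num_classes(1024)
model_par = pad_num_classes(1000, world_size=8, model_parallel=True)
epochs_arg = sink_epochs(max_epoch=10, steps_per_epoch=1000, log_interval=100)
```

Because of the integer division, any remainder of total steps that does not fill a whole log_interval is dropped from training.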