arcface
This commit is contained in:
parent
44dd7d994b
commit
b9559dee59
|
@ -0,0 +1,228 @@
|
|||
目录
|
||||
|
||||
- [目录](#目录)
|
||||
- [Arcface概述](#Arcface概述)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本和样例代码](#脚本和样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [分布式训练](#分布式训练)
|
||||
- [评估过程](#评估过程)
|
||||
- [评估](#评估)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [推理性能](#推理性能)
|
||||
- [使用方法](#使用方法)
|
||||
- [推理](#推理)
|
||||
- [迁移学习](#迁移学习)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#ModelZoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# Arcface概述
|
||||
|
||||
使用深度卷积神经网络进行大规模人脸识别的特征学习中的主要挑战之一是设计适当的损失函数以增强判别能力。继SoftmaxLoss、Center Loss、A-Softmax Loss、Cosine Margin Loss之后,Arcface在人脸识别中具有更加良好的表现。Arcface是传统softmax的改进, 将类之间的距离映射到超球面的间距,论文给出了对此的清晰几何解释。同时,基于10多个人脸识别基准的实验评估,证明了Arcface优于现有的技术,并且可以轻松实现。
|
||||
|
||||
[论文](https://arxiv.org/pdf/1801.07698v3.pdf): Deng J , Guo J , Zafeiriou S . ArcFace: Additive Angular Margin Loss for Deep Face Recognition[J]. 2018.
|
||||
|
||||
# 数据集
|
||||
|
||||
使用的训练数据集:[MS1MV2](https://github.com/deepinsight/insightface/wiki/Dataset-Zoo)
|
||||
|
||||
验证数据集:lfw,cfp-fp,agedb,cplfw,calfw,[IJB-B,IJB-C](https://pan.baidu.com/s/1oer0p4_mcOrs4cfdeWfbFg)
|
||||
|
||||
训练集:5,822,653张图片,85742个类
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件:昇腾处理器(Ascend)
|
||||
- 使用Ascend处理器来搭建硬件环境。
|
||||
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
```python
|
||||
# 分布式训练运行示例
|
||||
sh scripts/run_distribute_train.sh rank_size /path/dataset
|
||||
|
||||
# 单机训练运行示例
|
||||
sh scripts/run_standalone_train.sh /path/dataset
|
||||
|
||||
# 运行评估示例
|
||||
sh scripts/run_eval.sh /path/evalset /path/ckpt
|
||||
```
|
||||
|
||||
## 脚本说明
|
||||
|
||||
## 脚本和样例代码
|
||||
|
||||
```path
|
||||
└── Arcface
|
||||
├── README.md // Arcface相关描述
|
||||
├── scripts
|
||||
├── run_distribute_train.sh // 用于分布式训练的shell脚本
|
||||
├── run_standalone_train.sh // 用于单机训练的shell脚本
|
||||
├── run_eval_ijbc.sh // 用于IJBC数据集评估的shell脚本
|
||||
└── run_eval.sh // 用于评估的shell脚本
|
||||
├──src
|
||||
├── export.py
|
||||
├── loss.py //损失函数
|
||||
├── dataset.py // 创建数据集
|
||||
├── iresnet.py // ResNet架构
|
||||
├──val.py // 测试脚本
|
||||
├──train.py // 训练脚本
|
||||
├──requirements.txt
|
||||
|
||||
|
||||
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
```python
|
||||
train.py和val.py中主要参数如下:
|
||||
|
||||
-- modelarts:是否使用modelarts平台训练。可选值为True、False。默认为False。
|
||||
-- device_id:用于训练或评估数据集的设备ID。当使用train.sh进行分布式训练时,忽略此参数。
|
||||
-- device_num:使用tra进行分布式训练时使用的设备数。
|
||||
-- train_url:checkpoint的输出路径。
|
||||
-- data_url:训练集路径。
|
||||
-- ckpt_url:checkpoint路径。
|
||||
-- eval_url:验证集路径。
|
||||
|
||||
```
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 分布式训练
|
||||
|
||||
```shell
|
||||
sh scripts/run_distribute_train.sh rank_size /path/dataset
|
||||
```
|
||||
|
||||
上述shell脚本将在后台运行分布训练。可以通过`device[X]/train.log`文件查看结果。
|
||||
采用以下方式达到损失值:
|
||||
|
||||
```log
|
||||
epoch: 2 step: 11372, loss is 12.807039
|
||||
epoch time: 1104549.619 ms, per step time: 97.129 ms
|
||||
epoch: 3 step: 11372, loss is 9.13787
|
||||
...
|
||||
epoch: 21 step: 11372, loss is 1.5028578
|
||||
epoch time: 1104673.362 ms, per step time: 97.140 ms
|
||||
epoch: 22 step: 11372, loss is 0.8846929
|
||||
epoch time: 1104929.793 ms, per step time: 97.162 ms
|
||||
```
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 评估
|
||||
|
||||
- 在Ascend环境运行时评估lfw、cfp_fp、agedb_30、calfw、cplfw数据集
|
||||
|
||||
在运行以下命令之前,请检查用于评估的检查点路径。请将检查点路径设置为绝对全路径,例如“username/arcface/arcface-11372-1.ckpt”。
|
||||
|
||||
```bash
|
||||
sh scripts/run_eval.sh /path/evalset /path/ckpt
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过eval.log文件查看结果。测试数据集的准确性如下:
|
||||
|
||||
```bash
|
||||
[lfw]Accuracy-Flip: 0.99817+-0.00273
|
||||
[cfp_fp]Accuracy-Flip: 0.98000+-0.00586
|
||||
[agedb_30]Accuracy-Flip: 0.98100+-0.00642
|
||||
[calfw]Accuracy-Flip: 0.96150+-0.01099
|
||||
[cplfw]Accuracy-Flip: 0.92583+-0.01367
|
||||
```
|
||||
|
||||
- 在Ascend环境运行时评估IJB-B、IJB-C数据集
|
||||
|
||||
在运行以下命令之前,请检查用于评估的检查点路径。请将检查点路径设置为绝对全路径,例如“username/arcface/arcface-11372-1.ckpt”。
|
||||
|
||||
同时,情确保传入的评估数据集路径为“IJB_release/IJBB/”或“IJB_release/IJBC/”。
|
||||
|
||||
```bash
|
||||
sh scripts/run_eval_ijbc.sh /path/evalset /path/ckpt
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过eval.log文件查看结果。测试数据集的准确性如下:
|
||||
|
||||
```bash
|
||||
+-----------+-------+-------+--------+-------+-------+-------+
|
||||
| Methods | 1e-06 | 1e-05 | 0.0001 | 0.001 | 0.01 | 0.1 |
|
||||
+-----------+-------+-------+--------+-------+-------+-------+
|
||||
| ijbb-IJBB | 40.01 | 87.91 | 94.36 | 96.48 | 97.72 | 98.70 |
|
||||
+-----------+-------+-------+--------+-------+-------+-------+
|
||||
|
||||
+-----------+-------+-------+--------+-------+-------+-------+
|
||||
| Methods | 1e-06 | 1e-05 | 0.0001 | 0.001 | 0.01 | 0.1 |
|
||||
+-----------+-------+-------+--------+-------+-------+-------+
|
||||
| ijbc-IJBC | 82.08 | 93.37 | 95.87 | 97.40 | 98.40 | 99.05 |
|
||||
+-----------+-------+-------+--------+-------+-------+-------+
|
||||
```
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 评估性能
|
||||
|
||||
| 参数 | Arcface |
|
||||
| ------------- | ------------------------------------------------------------ |
|
||||
| 模型版本 | arcface |
|
||||
| 资源 | Ascend 910; CPU: 2.60GHz,192内核;内存,755G |
|
||||
| 上传日期 | 2021-05-30 |
|
||||
| MindSpore版本 | 1.2.0-c77-python3.7-aarch64 |
|
||||
| 数据集 | MS1MV2 |
|
||||
| 训练参数 | lr=0.08; gamma=0.1 |
|
||||
| 优化器 | SGD |
|
||||
| 损失函数 | Arcface |
|
||||
| 输出 | 概率 |
|
||||
| 损失 | 0.6 |
|
||||
| 速度 | 1卡:108毫秒/步;8卡:97毫秒/步 |
|
||||
| 总时间 | 1卡:65小时;8卡:8.5小时 |
|
||||
| 参数(M) | 85.2 |
|
||||
| 微调检查点 | 1249M (.ckpt file) |
|
||||
| 脚本 | [脚本路径](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/arcface) |
|
||||
|
||||
### 推理性能
|
||||
|
||||
| 参数 | Arcface |
|
||||
| ------------- | ------------------------ |
|
||||
| 模型版本 | arcface |
|
||||
| 资源 | Ascend 910 |
|
||||
| 上传日期 | 2021/05/30 |
|
||||
| MindSpore版本 | 1.2.0-c77-python3.7-aarch64 |
|
||||
| 数据集 | IJBC、IJBB、lfw、cfp_fp、agedb_30、calfw、cplfw |
|
||||
| 输出 | 概率 |
|
||||
| 准确性 | lfw:0.998 cfp_fp:0.98 agedb_30:0.981 calfw:0.961 cplfw:0.926 IJB-B:0.943 IJB-C:0.958 |
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 推理
|
||||
|
||||
如果您需要使用已训练模型在GPU、Ascend 910、Ascend 310等多个硬件平台上进行推理,可参考[此处](https://www.mindspore.cn/tutorial/inference/zh-CN/r1.2/index.html)。
|
||||
|
||||
### 迁移学习
|
||||
|
||||
待补充
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
网络的初始参数均为随即初始化。
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,524 @@
|
|||
# coding: utf-8
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
evaluation of IJBB or IJBC
|
||||
'''
|
||||
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import timeit
|
||||
import argparse
|
||||
import warnings
|
||||
import sys
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
import sklearn
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
|
||||
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
|
||||
from prettytable import PrettyTable
|
||||
|
||||
import cv2
|
||||
|
||||
from skimage import transform as trans
|
||||
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import Tensor, context
|
||||
from mindspore import dtype as mstype
|
||||
import mindspore.ops as ops
|
||||
import mindspore.nn as nn
|
||||
from src.iresnet import iresnet100
|
||||
|
||||
matplotlib.use('Agg')
|
||||
|
||||
sys.path.insert(0, "../")
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
parser = argparse.ArgumentParser(description='do ijb test')
|
||||
# general
|
||||
parser.add_argument('--model-prefix', default='', help='path to load model.')
|
||||
parser.add_argument('--image-path', default='', type=str, help='')
|
||||
parser.add_argument('--result-dir', default='.', type=str, help='')
|
||||
parser.add_argument('--batch-size', default=128, type=int, help='')
|
||||
parser.add_argument('--network', default='iresnet50', type=str, help='')
|
||||
parser.add_argument('--job', default='insightface', type=str, help='job name')
|
||||
parser.add_argument('--target', default='IJBC', type=str,
|
||||
help='target, set to IJBC or IJBB')
|
||||
args = parser.parse_args()
|
||||
|
||||
target = args.target
|
||||
model_path = args.model_prefix
|
||||
image_path = args.image_path
|
||||
result_dir = args.result_dir
|
||||
gpu_id = None
|
||||
use_norm_score = True # if Ture, TestMode(N1)
|
||||
use_detector_score = True # if Ture, TestMode(D1)
|
||||
use_flip_test = True # if Ture, TestMode(F1)
|
||||
job = args.job
|
||||
|
||||
|
||||
class Embedding(nn.Cell):
|
||||
'''Embedding
|
||||
'''
|
||||
def __init__(self, prefix, data_shape, batch_size=1):
|
||||
super().__init__()
|
||||
image_size = (112, 112)
|
||||
self.image_size = image_size
|
||||
resnet = iresnet100()
|
||||
param_dict = load_checkpoint(args.model_prefix)
|
||||
load_param_into_net(resnet, param_dict)
|
||||
self.model = resnet
|
||||
src = np.array([
|
||||
[30.2946, 51.6963],
|
||||
[65.5318, 51.5014],
|
||||
[48.0252, 71.7366],
|
||||
[33.5493, 92.3655],
|
||||
[62.7299, 92.2041]], dtype=np.float32)
|
||||
src[:, 0] += 8.0
|
||||
self.src = src
|
||||
self.batch_size = batch_size
|
||||
self.data_shape = data_shape
|
||||
self.reshape = ops.Reshape()
|
||||
self.div = ops.Div()
|
||||
self.sub = ops.Sub()
|
||||
self.shape = ops.Shape()
|
||||
self.print = ops.Print()
|
||||
|
||||
def get(self, rimg, landmark):
|
||||
'''get
|
||||
'''
|
||||
assert landmark.shape[0] == 68 or landmark.shape[0] == 5
|
||||
assert landmark.shape[1] == 2
|
||||
if landmark.shape[0] == 68:
|
||||
landmark5 = np.zeros((5, 2), dtype=np.float32)
|
||||
landmark5[0] = (landmark[36] + landmark[39]) / 2
|
||||
landmark5[1] = (landmark[42] + landmark[45]) / 2
|
||||
landmark5[2] = landmark[30]
|
||||
landmark5[3] = landmark[48]
|
||||
landmark5[4] = landmark[54]
|
||||
else:
|
||||
landmark5 = landmark
|
||||
tform = trans.SimilarityTransform()
|
||||
tform.estimate(landmark5, self.src)
|
||||
M = tform.params[0:2, :]
|
||||
img = cv2.warpAffine(rimg,
|
||||
M, (self.image_size[1], self.image_size[0]),
|
||||
borderValue=0.0)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img_flip = np.fliplr(img)
|
||||
img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB
|
||||
img_flip = np.transpose(img_flip, (2, 0, 1))
|
||||
input_blob = np.zeros(
|
||||
(2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
|
||||
input_blob[0] = img
|
||||
input_blob[1] = img_flip
|
||||
return input_blob
|
||||
|
||||
# @torch.no_grad()
|
||||
def forward_db(self, batch_data):
|
||||
'''forward_db
|
||||
'''
|
||||
imgs = Tensor(batch_data, mstype.float32)
|
||||
imgs = self.div(imgs, 255)
|
||||
imgs = self.sub(imgs, 0.5)
|
||||
imgs = self.div(imgs, 0.5)
|
||||
feat = self.model(imgs)
|
||||
shape = self.shape(feat)
|
||||
feat = self.reshape(feat, (self.batch_size, 2 * shape[1]))
|
||||
return feat.asnumpy()
|
||||
|
||||
|
||||
# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[]
|
||||
def divideIntoNstrand(listTemp, n):
|
||||
twoList = [[] for i in range(n)]
|
||||
for i, e in enumerate(listTemp):
|
||||
twoList[i % n].append(e)
|
||||
return twoList
|
||||
|
||||
|
||||
def read_template_media_list(path):
|
||||
ijb_meta = pd.read_csv(path, sep=' ', header=None).values
|
||||
templates = ijb_meta[:, 1].astype(np.int)
|
||||
media = ijb_meta[:, 2].astype(np.int)
|
||||
return templates, media
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def read_template_pair_list(path):
|
||||
pairs = pd.read_csv(path, sep=' ', header=None).values
|
||||
t1 = pairs[:, 0].astype(np.int)
|
||||
t2 = pairs[:, 1].astype(np.int)
|
||||
label = pairs[:, 2].astype(np.int)
|
||||
return t1, t2, label
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def read_image_feature(path):
|
||||
with open(path, 'rb') as fid:
|
||||
img_feats = pickle.load(fid)
|
||||
return img_feats
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def get_image_feature(img_path, files_list, _model_path, epoch):
|
||||
'''get_image_feature
|
||||
'''
|
||||
batch_size = args.batch_size
|
||||
data_shape = (3, 112, 112)
|
||||
|
||||
files = files_list
|
||||
print('files:', len(files))
|
||||
rare_size = len(files) % batch_size
|
||||
faceness_scores = []
|
||||
batch = 0
|
||||
img_feats = np.empty((len(files), 1024), dtype=np.float32)
|
||||
|
||||
batch_data = np.empty((2 * batch_size, 3, 112, 112))
|
||||
embedding = Embedding(_model_path, data_shape, batch_size)
|
||||
for img_index, each_line in enumerate(files[:len(files) - rare_size]):
|
||||
name_lmk_score = each_line.strip().split(' ')
|
||||
img_name = os.path.join(img_path, name_lmk_score[0])
|
||||
img = cv2.imread(img_name)
|
||||
lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
|
||||
dtype=np.float32)
|
||||
lmk = lmk.reshape((5, 2))
|
||||
input_blob = embedding.get(img, lmk)
|
||||
|
||||
batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
|
||||
batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
|
||||
if (img_index + 1) % batch_size == 0:
|
||||
print('batch', batch)
|
||||
img_feats[batch * batch_size:batch * batch_size +
|
||||
batch_size][:] = embedding.forward_db(batch_data)
|
||||
batch += 1
|
||||
faceness_scores.append(name_lmk_score[-1])
|
||||
|
||||
batch_data = np.empty((2 * rare_size, 3, 112, 112))
|
||||
embedding = Embedding(_model_path, data_shape, rare_size)
|
||||
for img_index, each_line in enumerate(files[len(files) - rare_size:]):
|
||||
name_lmk_score = each_line.strip().split(' ')
|
||||
img_name = os.path.join(img_path, name_lmk_score[0])
|
||||
img = cv2.imread(img_name)
|
||||
lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
|
||||
dtype=np.float32)
|
||||
lmk = lmk.reshape((5, 2))
|
||||
input_blob = embedding.get(img, lmk)
|
||||
batch_data[2 * img_index][:] = input_blob[0]
|
||||
batch_data[2 * img_index + 1][:] = input_blob[1]
|
||||
if (img_index + 1) % rare_size == 0:
|
||||
print('batch', batch)
|
||||
img_feats[len(files) -
|
||||
rare_size:][:] = embedding.forward_db(batch_data)
|
||||
batch += 1
|
||||
faceness_scores.append(name_lmk_score[-1])
|
||||
faceness_scores = np.array(faceness_scores).astype(np.float32)
|
||||
return img_feats, faceness_scores
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def image2template_feature(img_feats=None, templates=None, media=None):
|
||||
'''image2template_feature
|
||||
'''
|
||||
# ==========================================================
|
||||
# 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
|
||||
# 2. compute media feature.
|
||||
# 3. compute template feature.
|
||||
# ==========================================================
|
||||
unique_templates = np.unique(templates)
|
||||
template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
|
||||
|
||||
for count_template, uqt in enumerate(unique_templates):
|
||||
|
||||
(ind_t,) = np.where(templates == uqt)
|
||||
face_norm_feats = img_feats[ind_t]
|
||||
face_media = media[ind_t]
|
||||
unique_media, unique_media_counts = np.unique(face_media,
|
||||
return_counts=True)
|
||||
media_norm_feats = []
|
||||
for u, ct in zip(unique_media, unique_media_counts):
|
||||
(ind_m,) = np.where(face_media == u)
|
||||
if ct == 1:
|
||||
media_norm_feats += [face_norm_feats[ind_m]]
|
||||
else: # image features from the same video will be aggregated into one feature
|
||||
media_norm_feats += [
|
||||
np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
|
||||
]
|
||||
media_norm_feats = np.array(media_norm_feats)
|
||||
template_feats[count_template] = np.sum(media_norm_feats, axis=0)
|
||||
if count_template % 2000 == 0:
|
||||
print('Finish Calculating {} template features.'.format(
|
||||
count_template))
|
||||
template_norm_feats = sklearn.preprocessing.normalize(template_feats)
|
||||
return template_norm_feats, unique_templates
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
def verification(template_norm_feats=None,
|
||||
unique_templates=None,
|
||||
p1=None,
|
||||
p2=None):
|
||||
'''verification
|
||||
'''
|
||||
# ==========================================================
|
||||
# Compute set-to-set Similarity Score.
|
||||
# ==========================================================
|
||||
template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
|
||||
for count_template, uqt in enumerate(unique_templates):
|
||||
template2id[uqt] = count_template
|
||||
|
||||
score = np.zeros((len(p1),)) # save cosine distance between pairs
|
||||
|
||||
total_pairs = np.array(range(len(p1)))
|
||||
# small batchsize instead of all pairs in one batch due to the memory limiation
|
||||
batchsize = 100000
|
||||
sublists = [
|
||||
total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
|
||||
]
|
||||
total_sublists = len(sublists)
|
||||
for c, s in enumerate(sublists):
|
||||
feat1 = template_norm_feats[template2id[p1[s]]]
|
||||
feat2 = template_norm_feats[template2id[p2[s]]]
|
||||
similarity_score = np.sum(feat1 * feat2, -1)
|
||||
score[s] = similarity_score.flatten()
|
||||
if c % 10 == 0:
|
||||
print('Finish {}/{} pairs.'.format(c, total_sublists))
|
||||
return score
|
||||
|
||||
|
||||
# In[ ]:
|
||||
def verification2(template_norm_feats=None,
|
||||
unique_templates=None,
|
||||
p1=None,
|
||||
p2=None):
|
||||
'''verification2
|
||||
'''
|
||||
template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
|
||||
for count_template, uqt in enumerate(unique_templates):
|
||||
template2id[uqt] = count_template
|
||||
score = np.zeros((len(p1),)) # save cosine distance between pairs
|
||||
total_pairs = np.array(range(len(p1)))
|
||||
# small batchsize instead of all pairs in one batch due to the memory limiation
|
||||
batchsize = 100000
|
||||
sublists = [
|
||||
total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
|
||||
]
|
||||
total_sublists = len(sublists)
|
||||
for c, s in enumerate(sublists):
|
||||
feat1 = template_norm_feats[template2id[p1[s]]]
|
||||
feat2 = template_norm_feats[template2id[p2[s]]]
|
||||
similarity_score = np.sum(feat1 * feat2, -1)
|
||||
score[s] = similarity_score.flatten()
|
||||
if c % 10 == 0:
|
||||
print('Finish {}/{} pairs.'.format(c, total_sublists))
|
||||
return score
|
||||
|
||||
|
||||
def read_score(path):
|
||||
with open(path, 'rb') as fid:
|
||||
img_feats = pickle.load(fid)
|
||||
return img_feats
|
||||
|
||||
|
||||
def main():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_id=0)
|
||||
# # Step1: Load Meta Data
|
||||
|
||||
# In[ ]:
|
||||
|
||||
assert target in ('IJBC', 'IJBB')
|
||||
|
||||
# =============================================================
|
||||
# load image and template relationships for template feature embedding
|
||||
# tid --> template id, mid --> media id
|
||||
# format:
|
||||
# image_name tid mid
|
||||
# =============================================================
|
||||
start = timeit.default_timer()
|
||||
templates, media = read_template_media_list(
|
||||
os.path.join('%s/meta' % image_path,
|
||||
'%s_face_tid_mid.txt' % target.lower()))
|
||||
stop = timeit.default_timer()
|
||||
print('Time: %.2f s. ' % (stop - start))
|
||||
|
||||
# In[ ]:
|
||||
|
||||
# =============================================================
|
||||
# load template pairs for template-to-template verification
|
||||
# tid : template id, label : 1/0
|
||||
# format:
|
||||
# tid_1 tid_2 label
|
||||
# =============================================================
|
||||
start = timeit.default_timer()
|
||||
p1, p2, label = read_template_pair_list(
|
||||
os.path.join('%s/meta' % image_path,
|
||||
'%s_template_pair_label.txt' % target.lower()))
|
||||
stop = timeit.default_timer()
|
||||
print('Time: %.2f s. ' % (stop - start))
|
||||
|
||||
# # Step 2: Get Image Features
|
||||
|
||||
# In[ ]:
|
||||
|
||||
# =============================================================
|
||||
# load image features
|
||||
# format:
|
||||
# img_feats: [image_num x feats_dim] (227630, 512)
|
||||
# =============================================================
|
||||
start = timeit.default_timer()
|
||||
img_path = '%s/loose_crop' % image_path
|
||||
img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
|
||||
img_list = open(img_list_path)
|
||||
files = img_list.readlines()
|
||||
files_list = files
|
||||
|
||||
# img_feats
|
||||
img_feats, faceness_scores = get_image_feature(img_path, files_list,
|
||||
model_path, 0)
|
||||
stop = timeit.default_timer()
|
||||
print('Time: %.2f s. ' % (stop - start))
|
||||
print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
|
||||
img_feats.shape[1]))
|
||||
|
||||
# # Step3: Get Template Features
|
||||
|
||||
# In[ ]:
|
||||
|
||||
# =============================================================
|
||||
# compute template features from image features.
|
||||
# =============================================================
|
||||
start = timeit.default_timer()
|
||||
# ==========================================================
|
||||
# Norm feature before aggregation into template feature?
|
||||
# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
|
||||
# ==========================================================
|
||||
# 1. FaceScore (Feature Norm)
|
||||
# 2. FaceScore (Detector)
|
||||
|
||||
if use_flip_test:
|
||||
# concat --- F1
|
||||
# img_input_feats = img_feats
|
||||
# add --- F2
|
||||
img_input_feats = img_feats[:, 0:img_feats.shape[1] //
|
||||
2] + img_feats[:, img_feats.shape[1] // 2:]
|
||||
else:
|
||||
img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
|
||||
|
||||
if use_norm_score:
|
||||
img_input_feats = img_input_feats
|
||||
else:
|
||||
# normalise features to remove norm information
|
||||
img_input_feats = img_input_feats / np.sqrt(
|
||||
np.sum(img_input_feats ** 2, -1, keepdims=True))
|
||||
|
||||
if use_detector_score:
|
||||
print(img_input_feats.shape, faceness_scores.shape)
|
||||
img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
|
||||
else:
|
||||
img_input_feats = img_input_feats
|
||||
|
||||
template_norm_feats, unique_templates = image2template_feature(
|
||||
img_input_feats, templates, media)
|
||||
stop = timeit.default_timer()
|
||||
print('Time: %.2f s. ' % (stop - start))
|
||||
|
||||
# # Step 4: Get Template Similarity Scores
|
||||
|
||||
# In[ ]:
|
||||
|
||||
# =============================================================
|
||||
# compute verification scores between template pairs.
|
||||
# =============================================================
|
||||
start = timeit.default_timer()
|
||||
score = verification(template_norm_feats, unique_templates, p1, p2)
|
||||
stop = timeit.default_timer()
|
||||
print('Time: %.2f s. ' % (stop - start))
|
||||
|
||||
# In[ ]:
|
||||
save_path = os.path.join(result_dir, args.job)
|
||||
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
|
||||
np.save(score_save_file, score)
|
||||
|
||||
# # Step 5: Get ROC Curves and TPR@FPR Table
|
||||
|
||||
# In[ ]:
|
||||
|
||||
files = [score_save_file]
|
||||
methods = []
|
||||
scores = []
|
||||
for file in files:
|
||||
methods.append(Path(file).stem)
|
||||
scores.append(np.load(file))
|
||||
|
||||
methods = np.array(methods)
|
||||
scores = dict(zip(methods, scores))
|
||||
colours = dict(
|
||||
zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
|
||||
x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
|
||||
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
|
||||
fig = plt.figure()
|
||||
for method in methods:
|
||||
fpr, tpr, _ = roc_curve(label, scores[method])
|
||||
roc_auc = auc(fpr, tpr)
|
||||
fpr = np.flipud(fpr)
|
||||
tpr = np.flipud(tpr) # select largest tpr at same fpr
|
||||
plt.plot(fpr,
|
||||
tpr,
|
||||
color=colours[method],
|
||||
lw=1,
|
||||
label=('[%s (AUC = %0.4f %%)]' %
|
||||
(method.split('-')[-1], roc_auc * 100)))
|
||||
tpr_fpr_row = []
|
||||
tpr_fpr_row.append("%s-%s" % (method, target))
|
||||
for fpr_iter in np.arange(len(x_labels)):
|
||||
_, min_index = min(
|
||||
list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
|
||||
tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
|
||||
tpr_fpr_table.add_row(tpr_fpr_row)
|
||||
plt.xlim([10 ** -6, 0.1])
|
||||
plt.ylim([0.3, 1.0])
|
||||
plt.grid(linestyle='--', linewidth=1)
|
||||
plt.xticks(x_labels)
|
||||
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
|
||||
plt.xscale('log')
|
||||
plt.xlabel('False Positive Rate')
|
||||
plt.ylabel('True Positive Rate')
|
||||
plt.title('ROC on IJB')
|
||||
plt.legend(loc="lower right")
|
||||
fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower()))
|
||||
print(tpr_fpr_table)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,63 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh RANK_SIZE DATA_PATH"
|
||||
echo "For example: bash run.sh 8 path/dataset"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
RANK_SIZE=$1
|
||||
DATA_PATH=$2
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
|
||||
test_dist_8pcs()
|
||||
{
|
||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
|
||||
export RANK_SIZE=8
|
||||
}
|
||||
|
||||
test_dist_2pcs()
|
||||
{
|
||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
|
||||
echo "$RANK_TABLE_FILE"
|
||||
export RANK_SIZE=2
|
||||
}
|
||||
|
||||
test_dist_${RANK_SIZE}pcs
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
rm -rf device$i
|
||||
mkdir device$i
|
||||
cp -r ./src/ ./device$i
|
||||
cp train.py ./device$i
|
||||
cd ./device$i
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
echo "start training for device $i"
|
||||
env > env$i.log
|
||||
python train.py \
|
||||
--data_url $DATA_PATH \
|
||||
--device_num RANK_SIZE \
|
||||
> train.log$i 2>&1 &
|
||||
cd ../
|
||||
done
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh EVAL_PATH CKPT_PATH"
|
||||
echo "For example: bash run.sh path/evalset path/ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
EVAL_PATH=$1
|
||||
CKPT_PATH=$2
|
||||
|
||||
python val.py \
|
||||
--ckpt_url "$CKPT_PATH" \
|
||||
--device_id 1 \
|
||||
--eval_url "$EVAL_PATH" \
|
||||
--target lfw,cfp_fp,agedb_30,calfw,cplfw \
|
||||
> eval.log 2>&1 &
|
|
@ -0,0 +1,34 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh EVAL_PATH CKPT_PATH"
|
||||
echo "For example: bash run.sh path/evalset path/ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
EVAL_PATH=$1
|
||||
CKPT_PATH=$2
|
||||
|
||||
python eval_ijbc.py \
|
||||
--model-prefix "$CKPT_PATH" \
|
||||
--image-path "$EVAL_PATH" \
|
||||
--result-dir ms1mv2_arcface_r100 \
|
||||
--batch-size 128 \
|
||||
--job ms1mv2_arcface_r100 \
|
||||
--target IJBC \
|
||||
--network iresnet100 \
|
||||
> eval.log 2>&1 &
|
|
@ -0,0 +1,30 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh DATA_PATH"
|
||||
echo "For example: bash run.sh path/MS1M"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
# shellcheck disable=SC2034
|
||||
DATA_PATH=$1
|
||||
|
||||
python train.py \
|
||||
--data_url DATA_PATH \
|
||||
--device_num 1 \
|
||||
> train.log 2>&1 &
|
|
@ -0,0 +1,63 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh RANK_SIZE DATA_PATH"
|
||||
echo "For example: bash run.sh 8 path/dataset"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
RANK_SIZE=$1
|
||||
DATA_PATH=$2
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
|
||||
test_dist_8pcs()
|
||||
{
|
||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
|
||||
export RANK_SIZE=8
|
||||
}
|
||||
|
||||
test_dist_2pcs()
|
||||
{
|
||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
|
||||
echo "$RANK_TABLE_FILE"
|
||||
export RANK_SIZE=2
|
||||
}
|
||||
|
||||
test_dist_${RANK_SIZE}pcs
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
rm -rf device$i
|
||||
mkdir device$i
|
||||
cp -r ./src/ ./device$i
|
||||
cp train.py ./device$i
|
||||
cd ./device$i
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
echo "start training for device $i"
|
||||
env > env$i.log
|
||||
python train.py \
|
||||
--data_url $DATA_PATH \
|
||||
--device_num $RANK_SIZE \
|
||||
> train.log$i 2>&1 &
|
||||
cd ../
|
||||
done
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh EVAL_PATH CKPT_PATH"
|
||||
echo "For example: bash run.sh path/evalset path/ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
EVAL_PATH=$1
|
||||
CKPT_PATH=$2
|
||||
|
||||
python val.py \
|
||||
--ckpt_url "$CKPT_PATH" \
|
||||
--device_id 1 \
|
||||
--eval_url "$EVAL_PATH" \
|
||||
--target lfw,cfp_fp,agedb_30,calfw,cplfw \
|
||||
> eval.log 2>&1 &
|
|
@ -0,0 +1,34 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh EVAL_PATH CKPT_PATH"
|
||||
echo "For example: bash run.sh path/evalset path/ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
EVAL_PATH=$1
|
||||
CKPT_PATH=$2
|
||||
|
||||
python eval_ijbc.py \
|
||||
--model-prefix "$CKPT_PATH" \
|
||||
--image-path "$EVAL_PATH" \
|
||||
--result-dir ms1mv2_arcface_r100 \
|
||||
--batch-size 128 \
|
||||
--job ms1mv2_arcface_r100 \
|
||||
--target IJBC \
|
||||
--network iresnet100 \
|
||||
> eval.log 2>&1 &
|
|
@ -0,0 +1,30 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run.sh DATA_PATH"
|
||||
echo "For example: bash run.sh path/MS1M"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
# shellcheck disable=SC2034
|
||||
DATA_PATH=$1
|
||||
|
||||
python train.py \
|
||||
--data_url DATA_PATH \
|
||||
--device_num 1 \
|
||||
> train.log 2>&1 &
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
python dataset.py
|
||||
"""
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
|
||||
|
||||
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
|
||||
"""
|
||||
create a train dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
batch_size(int): the batch size of dataset. Default: 32
|
||||
target(str): the device target. Default: Ascend
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
if target == "Ascend":
|
||||
device_num, rank_id = _get_rank_info()
|
||||
else:
|
||||
init("nccl")
|
||||
rank_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
|
||||
if device_num == 1:
|
||||
ds = de.ImageFolderDataset(
|
||||
dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
else:
|
||||
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
image_size = 112
|
||||
mean = [0.5 * 255, 0.5 * 255, 0.5 * 255]
|
||||
std = [0.5 * 255, 0.5 * 255, 0.5 * 255]
|
||||
|
||||
# define map operations
|
||||
if do_train:
|
||||
trans = [
|
||||
# C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||
C.Decode(),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
trans = [
|
||||
C.Decode(),
|
||||
C.Resize(256),
|
||||
C.CenterCrop(image_size),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
|
||||
ds = ds.map(input_columns="image",
|
||||
num_parallel_workers=8, operations=trans)
|
||||
ds = ds.map(input_columns="label", num_parallel_workers=8,
|
||||
operations=type_cast_op)
|
||||
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
|
||||
# apply dataset repeat operation
|
||||
ds = ds.repeat(repeat_num)
|
||||
|
||||
return ds
|
||||
|
||||
|
||||
def _get_rank_info():
|
||||
"""
|
||||
get rank size and rank id
|
||||
"""
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
|
||||
if rank_size > 1:
|
||||
rank_size = int(os.environ.get("RANK_SIZE"))
|
||||
rank_id = int(os.environ.get("RANK_ID"))
|
||||
else:
|
||||
rank_size = 1
|
||||
rank_id = 0
|
||||
|
||||
return rank_size, rank_id
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.iresnet import iresnet100
|
||||
|
||||
parser = argparse.ArgumentParser(description='Classification')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="ibnnet", help="output file name.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
parser.add_argument('--dataset_name', type=str, default='MS1MV2', choices=['MS1MV2'],
|
||||
help='dataset name.')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args.dataset_name != 'MS1MV2':
|
||||
raise ValueError("dataset is not supported.")
|
||||
|
||||
net = iresnet100()
|
||||
|
||||
assert args.ckpt_file is not None, "args.ckpt_file is None."
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.ones([args.batch_size, 3, 112, 112]), mstype.float32)
|
||||
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,221 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
python iresnet.py
|
||||
"""
|
||||
from mindspore import nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100']
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
pad_mode='pad',
|
||||
group=groups,
|
||||
has_bias=False,
|
||||
dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes,
|
||||
out_planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
has_bias=False)
|
||||
|
||||
|
||||
class IBasicBlock(nn.Cell):
|
||||
'''IBasicBlock
|
||||
'''
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||
groups=1, base_width=64, dilation=1):
|
||||
super(IBasicBlock, self).__init__()
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError(
|
||||
'BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError(
|
||||
"Dilation > 1 not supported in BasicBlock")
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(
|
||||
inplanes,
|
||||
eps=1e-05,
|
||||
)
|
||||
self.conv1 = conv3x3(inplanes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(
|
||||
planes,
|
||||
eps=1e-05,
|
||||
)
|
||||
self.prelu = nn.PReLU(planes) # ?
|
||||
self.conv2 = conv3x3(planes, planes, stride)
|
||||
self.bn3 = nn.BatchNorm2d(
|
||||
planes,
|
||||
eps=1e-05,
|
||||
)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def construct(self, x):
|
||||
'''construct
|
||||
'''
|
||||
identity = x
|
||||
|
||||
out = self.bn1(x)
|
||||
out = self.conv1(out)
|
||||
out = self.bn2(out)
|
||||
out = self.prelu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class IResNet(nn.Cell):
|
||||
'''IResNet
|
||||
'''
|
||||
fc_scale = 7 * 7
|
||||
|
||||
def __init__(self,
|
||||
block, layers, dropout=0, num_features=512, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None):
|
||||
super(IResNet, self).__init__()
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3,
|
||||
stride=1, padding=1, pad_mode='pad', has_bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
||||
self.prelu = nn.PReLU(self.inplanes)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
||||
self.layer2 = self._make_layer(block,
|
||||
128,
|
||||
layers[1],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block,
|
||||
256,
|
||||
layers[2],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block,
|
||||
512,
|
||||
layers[3],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
|
||||
self.dropout = nn.Dropout(keep_prob=1.0-dropout)
|
||||
self.fc = nn.Dense(512 * block.expansion * self.fc_scale,
|
||||
num_features)
|
||||
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
||||
self.features.gamma.requires_grad = False
|
||||
self.reshape = ops.Reshape()
|
||||
self.flatten = ops.Flatten()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
'''make_layer
|
||||
'''
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.SequentialCell([
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
nn.BatchNorm2d(planes * block.expansion, eps=1e-05)
|
||||
])
|
||||
layers = []
|
||||
layers.append(
|
||||
block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(
|
||||
block(self.inplanes,
|
||||
planes,
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
dilation=self.dilation))
|
||||
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
'''construct
|
||||
'''
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.prelu(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.bn2(x)
|
||||
x = self.flatten(x)
|
||||
# b, c, _, _ = x.shape
|
||||
# x = self.reshape(x, (b, -1))
|
||||
x = self.dropout(x)
|
||||
x = self.fc(x)
|
||||
x = self.features(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
model = IResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
raise ValueError()
|
||||
return model
|
||||
|
||||
|
||||
def iresnet18(pretrained=False, progress=True, **kwargs):
|
||||
return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def iresnet34(pretrained=False, progress=True, **kwargs):
|
||||
return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def iresnet50(pretrained=False, progress=True, **kwargs):
|
||||
return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def iresnet100(pretrained=False, progress=True, **kwargs):
|
||||
return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
|
||||
progress, **kwargs)
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
python loss.py
|
||||
"""
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Parameter
|
||||
import mindspore.ops as ops
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
|
||||
class ArcFace(nn.Cell):
|
||||
'''
|
||||
Arcface loss
|
||||
'''
|
||||
def __init__(self, world_size, s=64.0, m=0.5):
|
||||
super(ArcFace, self).__init__()
|
||||
self.s = s
|
||||
self.shape = ops.Shape()
|
||||
self.mul = ops.Mul()
|
||||
self.cos = ops.Cos()
|
||||
self.acos = ops.ACos()
|
||||
self.onehot = ops.OneHot().shard(((1, world_size), (), ()))
|
||||
# self.tile = ops.Tile().shard(((8, 1),))
|
||||
self.on_value = Tensor(m, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
|
||||
def construct(self, cosine, label):
|
||||
m_hot = self.onehot(label, self.shape(
|
||||
cosine)[1], self.on_value, self.off_value)
|
||||
|
||||
cosine = self.acos(cosine)
|
||||
cosine += m_hot
|
||||
cosine = self.cos(cosine)
|
||||
cosine = self.mul(cosine, self.s)
|
||||
return cosine
|
||||
|
||||
|
||||
class SoftMaxCE(nn.Cell):
|
||||
'''
|
||||
softmax cross entrophy
|
||||
'''
|
||||
def __init__(self, world_size):
|
||||
super(SoftMaxCE, self).__init__()
|
||||
self.max = ops.ReduceMax(keep_dims=True)
|
||||
self.sum = ops.ReduceSum(keep_dims=True)
|
||||
self.mean = ops.ReduceMean(keep_dims=False)
|
||||
self.exp = ops.Exp()
|
||||
self.div = ops.Div()
|
||||
self.onehot = ops.OneHot().shard(((1, world_size), (), ()))
|
||||
self.mul = ops.Mul()
|
||||
self.log = ops.Log()
|
||||
self.onvalue = Tensor(1.0, mstype.float32)
|
||||
self.offvalue = Tensor(0.0, mstype.float32)
|
||||
self.eps = Tensor(1e-30, mstype.float32)
|
||||
|
||||
def construct(self, logits, total_label):
|
||||
'''construct
|
||||
'''
|
||||
max_fc = self.max(logits, 1)
|
||||
|
||||
logits_exp = self.exp(logits - max_fc)
|
||||
logits_sum_exp = self.sum(logits_exp, 1)
|
||||
|
||||
logits_exp = self.div(logits_exp, logits_sum_exp)
|
||||
|
||||
label = self.onehot(total_label, F.shape(
|
||||
logits)[1], self.onvalue, self.offvalue)
|
||||
|
||||
softmax_result_log = self.log(logits_exp + self.eps)
|
||||
loss = self.sum((self.mul(softmax_result_log, label)), -1)
|
||||
loss = self.mul(ops.scalar_to_array(-1.0), loss)
|
||||
loss_v = self.mean(loss, 0)
|
||||
|
||||
return loss_v
|
||||
|
||||
|
||||
class PartialFC(nn.Cell):
|
||||
'''partialFC
|
||||
'''
|
||||
def __init__(self, num_classes, world_size):
|
||||
super(PartialFC, self).__init__()
|
||||
self.L2Norm = ops.L2Normalize(axis=1)
|
||||
self.weight = Parameter(initializer(
|
||||
"normal", (num_classes, 512)), name="mp_weight")
|
||||
self.sub_weight = self.weight
|
||||
self.linear = ops.MatMul(transpose_b=True).shard(
|
||||
((1, 1), (world_size, 1)))
|
||||
self.margin_softmax = ArcFace(world_size=world_size)
|
||||
self.loss = SoftMaxCE(world_size=world_size)
|
||||
|
||||
def construct(self, features, label):
|
||||
total_label, norm_weight = self.prepare(label)
|
||||
total_features = self.L2Norm(features)
|
||||
logits = self.forward(total_features, norm_weight)
|
||||
logits = self.margin_softmax(logits, total_label)
|
||||
loss_v = self.loss(logits, total_label)
|
||||
return loss_v
|
||||
|
||||
def forward(self, total_features, norm_weight):
|
||||
logits = self.linear(F.cast(total_features, mstype.float16), F.cast(
|
||||
norm_weight, mstype.float16))
|
||||
return F.cast(logits, mstype.float32)
|
||||
|
||||
def prepare(self, label):
|
||||
total_label = label
|
||||
norm_weight = self.L2Norm(self.sub_weight)
|
||||
return total_label, norm_weight
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
python train.py
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor
|
||||
import mindspore.ops as ops
|
||||
from mindspore.train.model import Model, ParallelMode
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.parallel import _cost_model_context as cost_model_context
|
||||
from mindspore.parallel import set_algo_parameters
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.iresnet import iresnet100
|
||||
from src.loss import PartialFC
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
||||
|
||||
# Datasets
|
||||
parser.add_argument('--train_url', default='.', type=str,
|
||||
help='output path')
|
||||
parser.add_argument('--data_url', default='data path', type=str)
|
||||
# Optimization options
|
||||
parser.add_argument('--epochs', default=25, type=int, metavar='N',
|
||||
help='number of total epochs to run')
|
||||
parser.add_argument('--num_classes', default=85742, type=int, metavar='N',
|
||||
help='num of classes')
|
||||
parser.add_argument('--batch_size', default=64, type=int, metavar='N',
|
||||
help='train batchsize (default: 256)')
|
||||
parser.add_argument('--lr', '--learning-rate', default=0.08, type=float,
|
||||
metavar='LR', help='initial learning rate')
|
||||
parser.add_argument('--schedule', type=int, nargs='+', default=[10, 16, 21],
|
||||
help='Decrease learning rate at these epochs.')
|
||||
parser.add_argument('--gamma', type=float, default=0.1,
|
||||
help='LR is multiplied by gamma on schedule.')
|
||||
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||
help='momentum')
|
||||
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
|
||||
metavar='W', help='weight decay (default: 1e-4)')
|
||||
# Device options
|
||||
parser.add_argument('--device_target', type=str,
|
||||
default='Ascend', choices=['GPU', 'Ascend', 'CPU'])
|
||||
parser.add_argument('--device_num', type=int, default=8)
|
||||
parser.add_argument('--device_id', type=int, default=0)
|
||||
parser.add_argument('--modelarts', type=bool, default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def lr_generator(lr_init, total_epochs, steps_per_epoch):
|
||||
'''lr_generator
|
||||
'''
|
||||
lr_each_step = []
|
||||
for i in range(total_epochs):
|
||||
if i in args.schedule:
|
||||
lr_init *= args.gamma
|
||||
for _ in range(steps_per_epoch):
|
||||
lr_each_step.append(lr_init)
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
return Tensor(lr_each_step)
|
||||
|
||||
|
||||
class MyNetWithLoss(nn.Cell):
|
||||
'''
|
||||
WithLossCell
|
||||
'''
|
||||
def __init__(self, backbone, cfg):
|
||||
super(MyNetWithLoss, self).__init__(auto_prefix=False)
|
||||
self._backbone = backbone.to_float(mstype.float16)
|
||||
self._loss_fn = PartialFC(num_classes=cfg.num_classes,
|
||||
world_size=cfg.device_num).to_float(mstype.float32)
|
||||
self.L2Norm = ops.L2Normalize(axis=1)
|
||||
|
||||
def construct(self, data, label):
|
||||
out = self._backbone(data)
|
||||
loss = self._loss_fn(out, label)
|
||||
return loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_epoch = args.epochs
|
||||
target = args.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=target, save_graphs=False)
|
||||
if args.device_num > 1:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
else:
|
||||
context.set_context(device_id=args.device_id)
|
||||
if args.device_num > 1:
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True,
|
||||
)
|
||||
cost_model_context.set_cost_model_context(device_memory_capacity=32.0 * 1024.0 * 1024.0 * 1024.0,
|
||||
costmodel_gamma=0.001,
|
||||
costmodel_beta=280.0)
|
||||
set_algo_parameters(elementwise_op_strategy_follow=True)
|
||||
init()
|
||||
|
||||
if args.modelarts:
|
||||
import moxing as mox
|
||||
|
||||
mox.file.copy_parallel(
|
||||
src_url=args.data_url, dst_url='/cache/data_path_' + os.getenv('DEVICE_ID'))
|
||||
zip_command = "unzip -o -q /cache/data_path_" + os.getenv('DEVICE_ID') \
|
||||
+ "/MS1M.zip -d /cache/data_path_" + \
|
||||
os.getenv('DEVICE_ID')
|
||||
os.system(zip_command)
|
||||
train_dataset = create_dataset(dataset_path='/cache/data_path_' + os.getenv('DEVICE_ID') + '/MS1M/',
|
||||
do_train=True,
|
||||
repeat_num=1, batch_size=args.batch_size, target=target)
|
||||
else:
|
||||
train_dataset = create_dataset(dataset_path=args.data_url, do_train=True,
|
||||
repeat_num=1, batch_size=args.batch_size, target=target)
|
||||
step = train_dataset.get_dataset_size()
|
||||
lr = lr_generator(args.lr, train_epoch, steps_per_epoch=step)
|
||||
net = iresnet100()
|
||||
train_net = MyNetWithLoss(net, args)
|
||||
optimizer = nn.SGD(params=train_net.trainable_params(), learning_rate=lr / 512 * args.batch_size * args.device_num,
|
||||
momentum=args.momentum, weight_decay=args.weight_decay)
|
||||
|
||||
model = Model(train_net, optimizer=optimizer)
|
||||
|
||||
config_ck = CheckpointConfig(
|
||||
save_checkpoint_steps=60, keep_checkpoint_max=20)
|
||||
if args.modelarts:
|
||||
ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
|
||||
directory='/cache/train_output/')
|
||||
else:
|
||||
ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
|
||||
directory=args.train_url)
|
||||
time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
|
||||
loss_cb = LossMonitor()
|
||||
cb = [ckpt_cb, time_cb, loss_cb]
|
||||
if args.device_id == 0 or args.device_num == 1:
|
||||
model.train(train_epoch, train_dataset,
|
||||
callbacks=cb, dataset_sink_mode=True)
|
||||
else:
|
||||
model.train(train_epoch, train_dataset, dataset_sink_mode=True)
|
||||
if args.modelarts:
|
||||
mox.file.copy_parallel(
|
||||
src_url='/cache/train_output', dst_url=args.train_url)
|
|
@ -0,0 +1,339 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
evaluation of lfw, calfw, cfp_fp, agedb_30, cplfw
|
||||
'''
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import sklearn
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.model_selection import KFold
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy import interpolate
|
||||
import mindspore as ms
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import context
|
||||
|
||||
from src.iresnet import iresnet100
|
||||
|
||||
|
||||
class LFold:
|
||||
'''
|
||||
LFold
|
||||
'''
|
||||
def __init__(self, n_splits=2, shuffle=False):
|
||||
self.n_splits = n_splits
|
||||
if self.n_splits > 1:
|
||||
self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
|
||||
|
||||
def split(self, indices):
|
||||
if self.n_splits > 1:
|
||||
return self.k_fold.split(indices)
|
||||
return [(indices, indices)]
|
||||
|
||||
|
||||
def calculate_roc(thresholds,
|
||||
embeddings1,
|
||||
embeddings2,
|
||||
actual_issame,
|
||||
nrof_folds=10,
|
||||
pca=0):
|
||||
'''
|
||||
calculate_roc
|
||||
'''
|
||||
assert embeddings1.shape[0] == embeddings2.shape[0]
|
||||
assert embeddings1.shape[1] == embeddings2.shape[1]
|
||||
nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
|
||||
nrof_thresholds = len(thresholds)
|
||||
k_fold = LFold(n_splits=nrof_folds, shuffle=False)
|
||||
|
||||
tprs = np.zeros((nrof_folds, nrof_thresholds))
|
||||
fprs = np.zeros((nrof_folds, nrof_thresholds))
|
||||
accuracy = np.zeros((nrof_folds))
|
||||
indices = np.arange(nrof_pairs)
|
||||
|
||||
if pca == 0:
|
||||
diff = np.subtract(embeddings1, embeddings2)
|
||||
dist = np.sum(np.square(diff), 1)
|
||||
|
||||
for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
|
||||
if pca > 0:
|
||||
print('doing pca on', fold_idx)
|
||||
embed1_train = embeddings1[train_set]
|
||||
embed2_train = embeddings2[train_set]
|
||||
_embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
|
||||
pca_model = PCA(n_components=pca)
|
||||
pca_model.fit(_embed_train)
|
||||
embed1 = pca_model.transform(embeddings1)
|
||||
embed2 = pca_model.transform(embeddings2)
|
||||
embed1 = sklearn.preprocessing.normalize(embed1)
|
||||
embed2 = sklearn.preprocessing.normalize(embed2)
|
||||
diff = np.subtract(embed1, embed2)
|
||||
dist = np.sum(np.square(diff), 1)
|
||||
|
||||
# Find the best threshold for the fold
|
||||
acc_train = np.zeros((nrof_thresholds))
|
||||
for threshold_idx, threshold in enumerate(thresholds):
|
||||
_, _, acc_train[threshold_idx] = calculate_accuracy(
|
||||
threshold, dist[train_set], actual_issame[train_set])
|
||||
best_threshold_index = np.argmax(acc_train)
|
||||
for threshold_idx, threshold in enumerate(thresholds):
|
||||
tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
|
||||
threshold, dist[test_set],
|
||||
actual_issame[test_set])
|
||||
_, _, accuracy[fold_idx] = calculate_accuracy(
|
||||
thresholds[best_threshold_index], dist[test_set],
|
||||
actual_issame[test_set])
|
||||
|
||||
tpr = np.mean(tprs, 0)
|
||||
fpr = np.mean(fprs, 0)
|
||||
return tpr, fpr, accuracy
|
||||
|
||||
|
||||
def calculate_accuracy(threshold, dist, actual_issame):
|
||||
'''calculate_acc
|
||||
'''
|
||||
predict_issame = np.less(dist, threshold)
|
||||
tp = np.sum(np.logical_and(predict_issame, actual_issame))
|
||||
fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
|
||||
tn = np.sum(
|
||||
np.logical_and(np.logical_not(predict_issame),
|
||||
np.logical_not(actual_issame)))
|
||||
fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
|
||||
|
||||
tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
|
||||
fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
|
||||
acc = float(tp + tn) / dist.size
|
||||
return tpr, fpr, acc
|
||||
|
||||
|
||||
def calculate_val(thresholds,
|
||||
embeddings1,
|
||||
embeddings2,
|
||||
actual_issame,
|
||||
far_target,
|
||||
nrof_folds=10):
|
||||
'''
|
||||
calculate_val
|
||||
'''
|
||||
assert embeddings1.shape[0] == embeddings2.shape[0]
|
||||
assert embeddings1.shape[1] == embeddings2.shape[1]
|
||||
nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
|
||||
nrof_thresholds = len(thresholds)
|
||||
k_fold = LFold(n_splits=nrof_folds, shuffle=False)
|
||||
|
||||
val = np.zeros(nrof_folds)
|
||||
far = np.zeros(nrof_folds)
|
||||
|
||||
diff = np.subtract(embeddings1, embeddings2)
|
||||
dist = np.sum(np.square(diff), 1)
|
||||
indices = np.arange(nrof_pairs)
|
||||
|
||||
for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
|
||||
|
||||
# Find the threshold that gives FAR = far_target
|
||||
far_train = np.zeros(nrof_thresholds)
|
||||
for threshold_idx, threshold in enumerate(thresholds):
|
||||
_, far_train[threshold_idx] = calculate_val_far(
|
||||
threshold, dist[train_set], actual_issame[train_set])
|
||||
if np.max(far_train) >= far_target:
|
||||
f = interpolate.interp1d(far_train, thresholds, kind='slinear')
|
||||
threshold = f(far_target)
|
||||
else:
|
||||
threshold = 0.0
|
||||
|
||||
val[fold_idx], far[fold_idx] = calculate_val_far(
|
||||
threshold, dist[test_set], actual_issame[test_set])
|
||||
|
||||
val_mean = np.mean(val)
|
||||
far_mean = np.mean(far)
|
||||
val_std = np.std(val)
|
||||
return val_mean, val_std, far_mean
|
||||
|
||||
|
||||
def calculate_val_far(threshold, dist, actual_issame):
|
||||
'''calculate_val_far
|
||||
'''
|
||||
predict_issame = np.less(dist, threshold)
|
||||
true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
|
||||
false_accept = np.sum(
|
||||
np.logical_and(predict_issame, np.logical_not(actual_issame)))
|
||||
n_same = np.sum(actual_issame)
|
||||
n_diff = np.sum(np.logical_not(actual_issame))
|
||||
val = float(true_accept) / float(n_same)
|
||||
far = float(false_accept) / float(n_diff)
|
||||
return val, far
|
||||
|
||||
|
||||
def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
|
||||
'''evaluate
|
||||
'''
|
||||
# Calculate evaluation metrics
|
||||
thresholds = np.arange(0, 4, 0.01)
|
||||
embeddings1 = embeddings[0::2]
|
||||
embeddings2 = embeddings[1::2]
|
||||
tpr, fpr, accuracy = calculate_roc(thresholds,
|
||||
embeddings1,
|
||||
embeddings2,
|
||||
np.asarray(actual_issame),
|
||||
nrof_folds=nrof_folds,
|
||||
pca=pca)
|
||||
thresholds = np.arange(0, 4, 0.001)
|
||||
val, val_std, far = calculate_val(thresholds,
|
||||
embeddings1,
|
||||
embeddings2,
|
||||
np.asarray(actual_issame),
|
||||
1e-3,
|
||||
nrof_folds=nrof_folds)
|
||||
return tpr, fpr, accuracy, val, val_std, far
|
||||
|
||||
|
||||
def load_bin(path, image_size):
|
||||
'''load evalset of .bin
|
||||
'''
|
||||
try:
|
||||
with open(path, 'rb') as f:
|
||||
bins, issame_list = pickle.load(f) # py2
|
||||
except UnicodeDecodeError as _:
|
||||
with open(path, 'rb') as f:
|
||||
bins, issame_list = pickle.load(f, encoding='bytes') # py3
|
||||
data_list = []
|
||||
for _ in [0, 1]:
|
||||
data = np.zeros(
|
||||
(len(issame_list) * 2, 3, image_size[0], image_size[1]))
|
||||
data_list.append(data)
|
||||
for idx in range(len(issame_list) * 2):
|
||||
_bin = bins[idx]
|
||||
img = plt.imread(BytesIO(_bin), "jpg")
|
||||
if img.shape[1] != image_size[0]:
|
||||
img = mx.image.resize_short(img, image_size[0])
|
||||
img = np.transpose(img, axes=(2, 0, 1))
|
||||
for flip in [0, 1]:
|
||||
if flip == 1:
|
||||
img = np.flip(img, axis=2)
|
||||
data_list[flip][idx][:] = img
|
||||
return data_list, issame_list
|
||||
|
||||
|
||||
def test(data_set, backbone, batch_size, nfolds=10):
|
||||
'''test
|
||||
'''
|
||||
print('testing verification..')
|
||||
data_list = data_set[0]
|
||||
issame_list = data_set[1]
|
||||
embeddings_list = []
|
||||
time_consumed = 0.0
|
||||
for data in data_list:
|
||||
embeddings = None
|
||||
ba = 0
|
||||
while ba < data.shape[0]:
|
||||
bb = min(ba + batch_size, data.shape[0])
|
||||
count = bb - ba
|
||||
_data = data[bb - batch_size: bb]
|
||||
|
||||
time0 = datetime.datetime.now()
|
||||
img = ((_data / 255) - 0.5) / 0.5
|
||||
net_out = backbone(ms.Tensor(img, ms.float32))
|
||||
_embeddings = net_out.asnumpy()
|
||||
time_now = datetime.datetime.now()
|
||||
diff = time_now - time0
|
||||
time_consumed += diff.total_seconds()
|
||||
if embeddings is None:
|
||||
embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
|
||||
embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
|
||||
ba = bb
|
||||
embeddings_list.append(embeddings)
|
||||
_xnorm = 0.0
|
||||
_xnorm_cnt = 0
|
||||
for embed in embeddings_list:
|
||||
for i in range(embed.shape[0]):
|
||||
_em = embed[i]
|
||||
_norm = np.linalg.norm(_em)
|
||||
_xnorm += _norm
|
||||
_xnorm_cnt += 1
|
||||
_xnorm /= _xnorm_cnt
|
||||
|
||||
embeddings = embeddings_list[0].copy()
|
||||
embeddings = sklearn.preprocessing.normalize(embeddings)
|
||||
_, _, acc, _, _, _ = evaluate(embeddings, issame_list, nrof_folds=nfolds)
|
||||
acc1 = np.mean(acc)
|
||||
std1 = np.std(acc)
|
||||
embeddings = embeddings_list[0] + embeddings_list[1]
|
||||
embeddings = sklearn.preprocessing.normalize(embeddings)
|
||||
print(embeddings.shape)
|
||||
print('infer time', time_consumed)
|
||||
_, _, accuracy, _, _, _ = evaluate(
|
||||
embeddings, issame_list, nrof_folds=nfolds)
|
||||
acc2, std2 = np.mean(accuracy), np.std(accuracy)
|
||||
return acc1, std1, acc2, std2, _xnorm, embeddings_list
|
||||
|
||||
|
||||
def main():
|
||||
'''r
|
||||
main function
|
||||
'''
|
||||
parser = argparse.ArgumentParser(description='do verification')
|
||||
# general
|
||||
parser.add_argument(
|
||||
'--eval_url', default='/opt_data/zjc/arcface/', help='')
|
||||
parser.add_argument('--device_id', default=0, type=int, help='device id')
|
||||
parser.add_argument('--target',
|
||||
default='lfw,cfp_fp,agedb_30',
|
||||
help='test targets.')
|
||||
parser.add_argument(
|
||||
'--ckpt_url', default="/cache/ArcFace--25_11372.ckpt", type=str, help='ckpt path')
|
||||
parser.add_argument('--gpu', default=0, type=int, help='gpu id')
|
||||
parser.add_argument('--batch-size', default=64, type=int, help='')
|
||||
parser.add_argument('--max', default='', type=str, help='')
|
||||
parser.add_argument('--nfolds', default=10, type=int, help='')
|
||||
args = parser.parse_args()
|
||||
context.set_context(device_id=args.device_id, mode=context.GRAPH_MODE)
|
||||
image_size = [112, 112]
|
||||
time0 = datetime.datetime.now()
|
||||
model = iresnet100()
|
||||
param_dict = load_checkpoint(args.ckpt_url)
|
||||
load_param_into_net(model, param_dict)
|
||||
time_now = datetime.datetime.now()
|
||||
diff = time_now - time0
|
||||
print('model loading time', diff.total_seconds())
|
||||
|
||||
ver_list = []
|
||||
ver_name_list = []
|
||||
for name in args.target.split(','):
|
||||
path = os.path.join(args.eval_url, name + ".bin")
|
||||
if os.path.exists(path):
|
||||
print('loading.. ', name)
|
||||
data_set = load_bin(path, image_size)
|
||||
ver_list.append(data_set)
|
||||
ver_name_list.append(name)
|
||||
|
||||
length = len(ver_list)
|
||||
for i in range(length):
|
||||
acc1, std1, acc2, std2, xnorm, _ = test(
|
||||
ver_list[i], model, args.batch_size, args.nfolds)
|
||||
print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
|
||||
print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
|
||||
print('[%s]Accuracy-Flip: %1.5f+-%1.5f' %
|
||||
(ver_name_list[i], acc2, std2))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue