!19075 Add LightCNN to master
Merge pull request !19075 from 王治坤/test-master
This commit is contained in:
commit
d7a37b877f
|
@ -0,0 +1,445 @@
|
|||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [LightCNN描述](#lightcnn描述)
|
||||
- [描述](#描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [特性](#特性)
|
||||
- [混合精度](#混合精度)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [用法](#用法)
|
||||
- [Ascend处理器环境运行](#ascend处理器环境运行)
|
||||
- [结果](#结果)
|
||||
- [Ascend处理器环境运行](#ascend处理器环境运行-1)
|
||||
- [评估过程](#评估过程)
|
||||
- [用法](#用法-1)
|
||||
- [Ascend处理器环境运行](#ascend处理器环境运行-2)
|
||||
- [结果](#结果-1)
|
||||
- [训练准确率](#训练准确率)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# LightCNN描述
|
||||
|
||||
## 描述
|
||||
|
||||
LightCNN适用于有大量噪声的人脸识别数据集,提出了maxout 的变体,命名为Max-Feature-Map (MFM) 。与maxout 使用多个特征图进行任意凸激活函数的线性近似,MFM 使用一种竞争关系选择凸激活函数,可以将噪声与有用的信息分隔开,也可以在两个特征图之间进行特征选择。
|
||||
|
||||
有关网络详细信息,请参阅[论文][1]`Wu, Xiang, et al. "A light cnn for deep face representation with noisy labels." IEEE Transactions on Information Forensics and Security 13.11 (2018): 2884-2896.`
|
||||
|
||||
# 模型架构
|
||||
|
||||
轻量级的CNN网络结构,可在包含大量噪声的训练样本中训练人脸识别任务:
|
||||
|
||||
- 在CNN的每层卷积层中引入了maxout激活概念,得到一个具有少量参数的Max-Feature-Map(MFM)。与ReLU通过阈值或偏置来抑制神经元的做法不同,MFM是通过竞争关系来抑制。不仅可以将噪声信号和有用信号分离开来,还可以在特征选择上起来关键作用。
|
||||
- 该网络基于MFM,有5个卷积层和4个Network in Network(NIN)层。小的卷积核与NIN是为了减少参数,提升性能。
|
||||
- 采用通过预训练模型的一种semantic bootstrapping的方法,提高模型在噪声样本中的稳定性。错误样本可以通过预测的概率被检测出来。
|
||||
- 实验证明该网络可以在包含大量噪声的训练样本中训练轻量级的模型,而单模型输出256维特征向量,在5个人脸测试集上达到state-of-art的效果。且在CPU上速度达到67ms。
|
||||
|
||||
# 数据集
|
||||
|
||||
训练集:微软人脸识别数据库(MS-Celeb-1M)。MS-Celeb-1M原数据集包含800多万张图像,LightCNN原作者提供了一份清洗后的文件清单MS-Celeb-1M_clean_list.txt,共包含79077个人,5049824张人脸图像。原数据集因侵权问题被微软官方删除,提供一个可用的[第三方下载链接][4]。通过该连接下载数据集后,应使用`FaceImageCroppedWithAlignment.tsv`,即对齐后的数据。
|
||||
|
||||
训练集列表:原作者将清洗后的训练列表`MS-Celeb-1M_clean_list.txt`上传至[Baidu Yun][2], [Google Drive][3],以供下载。
|
||||
|
||||
测试集:LFW人脸数据集(Labeled Faces in the Wild)。LFW数据集共包含来自5749个人的13233张人脸图像。LightCNN原作者提供的对齐后的[测试集链接][5]。
|
||||
|
||||
测试集列表:原作者并未提供测试集列表,只能根据原作者给出的测试结果反推测试集列表。首先下载[blufr官方测试文件包][7]和[原作者测试结果][9],将`blufr官方测试文件包`解压的文件夹,与`原作者测试结果--LightenedCNN_B_lfw.mat`、`LightCNN/src/get_list.py(本脚本提供)`放在同一个目录内,运行`python get_list.py`,即可在`LightCNN/src/`下生成`image_list_for_lfw.txt`和`image_list_for_blufr.txt`。
|
||||
|
||||
- 下载训练集、训练集列表、测试集和生成测试集列表。
|
||||
|
||||
- 将下载的训练集(tsv文件)转为图片集。运行脚本: `bash scripts/convert.sh FILE_PATH OUTPUT_PATH`,其中`FILE_PATH`为tsv文件位置,`OUTPUT_PATH`为输出文件夹,需要用户自行创建,推荐名称为`FaceImageCroppedWithAlignment`。
|
||||
|
||||
- 数据集结构
|
||||
|
||||
```shell
|
||||
.
|
||||
└──data
|
||||
├── FaceImageCroppedWithAlignment # 训练数据集 MS-Celeb-1M
|
||||
│ ├── m.0_0zl
|
||||
│ ├── m.0_0zy
|
||||
│ ├── m.01_06j
|
||||
│ ├── m.0107_f
|
||||
│ ...
|
||||
│
|
||||
├── lfw # 测试数据集 LFW
|
||||
│ ├── image
|
||||
│ │ ├── Aaron_Eckhart
|
||||
│ │ ├── Aaron_Guiel
|
||||
│ │ ├── Aaron_Patterson
|
||||
│ │ ├── Aaron_Peirsol
|
||||
│ │ ├── Aaron_Pena
|
||||
│ │ ...
|
||||
│ │
|
||||
│ ├── image_list_for_blufr.txt # lfw BLUFR protocols 测试集列表,需用户生成,方法见上文
|
||||
│ └── image_list_for_lfw.txt # lfw 6,000 pairs 测试集列表,需用户生成,方法见上文
|
||||
│
|
||||
└── MS-Celeb-1M_clean_list.txt # 清洗后的训练集列表
|
||||
```
|
||||
|
||||
# 特性
|
||||
|
||||
## 混合精度
|
||||
|
||||
采用[混合精度][6]的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
|
||||
以FP16算子为例,如果输入数据类型为FP32,MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志,搜索“reduce precision”查看精度降低的算子。
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend)
|
||||
- 准备Ascend处理器搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||
- 生成config json文件用于8卡训练。
|
||||
- [简易教程](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)
|
||||
- 详细配置方法请参照[官网教程](https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/advanced_use/distributed_training_ascend.html#id4)。
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- 运行前准备
|
||||
|
||||
修改配置文件`src/config.py`,特别是要修改正确的[数据集](#数据集)路径。
|
||||
|
||||
```python
|
||||
from easydict import EasyDict as edict
|
||||
lightcnn_cfg = edict({
|
||||
# training setting
|
||||
'network_type': 'LightCNN_9Layers',
|
||||
'epochs': 80,
|
||||
'lr': 0.01,
|
||||
'num_classes': 79077,
|
||||
'momentum': 0.9,
|
||||
'weight_decay': 1e-4,
|
||||
'batch_size': 128,
|
||||
'image_size': 128,
|
||||
'save_checkpoint_steps': 60000,
|
||||
'keep_checkpoint_max': 40,
|
||||
# train data location
|
||||
'data_path': '/data/MS-Celeb-1M/FaceImageCroppedWithAlignment', # 绝对路径(需要修改)
|
||||
'train_list': '/data/MS-Celeb-1M_clean_list.txt', # 绝对路径(需要修改)
|
||||
# test data location
|
||||
'root_path': '/data/lfw/image', # 绝对路径(需要修改)
|
||||
'lfw_img_list': 'image_list_for_lfw.txt', # 文件名
|
||||
'lfw_pairs_mat_path': 'mat_files/lfw_pairs.mat', # 运行测试脚本位置的相对路径
|
||||
'blufr_img_list': 'image_list_for_blufr.txt', # 文件名
|
||||
'blufr_config_mat_path': 'mat_files/blufr_lfw_config.mat' # 运行测试脚本位置的相对路径
|
||||
})
|
||||
|
||||
```
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
在LightCNN原始论文的基础上,我们对MS-Celeb-1M数据集进行了训练实验,并对LFW数据集进行了评估。
|
||||
|
||||
运行以下训练脚本配置单卡训练参数:
|
||||
|
||||
```bash
|
||||
# 进入根目录
|
||||
cd LightCNN/
|
||||
|
||||
# 运行单卡训练
|
||||
# DEVICE_ID: Ascend处理器的id,需用户指定
|
||||
sh scripts/train_standalone.sh DEVICE_ID
|
||||
```
|
||||
|
||||
运行一下训练脚本配置多卡训练参数:
|
||||
|
||||
```bash
|
||||
cd LightCNN/scripts
|
||||
|
||||
# 运行2卡或4卡训练
|
||||
# hccl.json: Ascend配置信息,需用户自行配置,与八卡不同,详见官网教程
|
||||
# DEVICE_NUM应与train_distribute.sh中修改device_ids的长度相同
|
||||
# 需进入train_distribute.sh 修改device_ids=(id1 id2) 或 device_ids=(id1 id2 id3 id4)
|
||||
sh train_distribute.sh hccl.json DEVICE_NUM
|
||||
|
||||
# 运行8卡训练
|
||||
# hccl.json: Ascend配置信息,需用户自行配置
|
||||
sh train_distribute_8p.sh hccl.json
|
||||
```
|
||||
|
||||
评估步骤如下:
|
||||
|
||||
```bash
|
||||
# 进入根目录
|
||||
cd LightCNN/
|
||||
|
||||
# 评估LightCNN在lfw 6,000 pairs上的表现
|
||||
# DEVICE_ID: Ascend处理器id
|
||||
# CKPT_FILE: checkpoint权重文件
|
||||
sh scripts/eval_lfw.sh DEVICE_ID CKPT_FILE
|
||||
|
||||
# 评估LightCNN在lfw BLUFR protocols上的表现
|
||||
# DEVICE_ID: Ascend处理器id
|
||||
# CKPT_FILE: checkpoint权重文件
|
||||
sh scripts/eval_blufr.sh DEVICE_ID CKPT_FILE
|
||||
```
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本及样例代码
|
||||
|
||||
```shell
|
||||
.
|
||||
├── mat_files
|
||||
│ ├── blufr_lfw_config.mat # lfw 6,000 pairs测试配置文件
|
||||
│ └── lfw_pairs.mat # lfw BLUFR protocols测试配置文件
|
||||
├── scripts
|
||||
│ ├── eval_blufr.sh # lfw BLUFR protocols测试脚本
|
||||
│ ├── eval_lfw.sh # lfw 6,000 pairs测试脚本
|
||||
│ ├── convert.sh # 训练数据集格式转换脚本
|
||||
│ ├── train_distribute_8p.sh # 8卡并行训练脚本
|
||||
│ ├── train_distribute.sh # 多卡(2卡/4卡)并行训练脚本
|
||||
│ └── train_standalone.sh # 单卡训练脚本
|
||||
├── src
|
||||
│ ├── config.py # 训练参数配置文件
|
||||
│ ├── convert.py # 训练数据集转换脚本
|
||||
│ ├── dataset.py # 加载训练数据集
|
||||
│ ├── get_list.py # 获取测试集列表
|
||||
│ ├── lightcnn.py # LightCNN模型文件
|
||||
│ └── lr_generator.py # 动态学习率生成脚本
|
||||
│
|
||||
├── eval_blufr.py # lfw BLUFR protocols测试脚本
|
||||
├── eval_lfw.py # lfw 6,000 pairs测试脚本
|
||||
├── train.py # 训练脚本
|
||||
└── README.md
|
||||
```
|
||||
|
||||
注:`mat_files`文件夹中的两个mat文件需要用户自行下载。`blufr_lfw_config.mat`是由[Benchmark of Large-scale Unconstrained Face Recognition][7]下载,解压后文件位置在`/BLUFR/config/lfw/blufr_lfw_config.mat`;`lfw_pairs.mat`由原作者官方代码提供,可[点此][8]跳转下载。
|
||||
|
||||
## 脚本参数
|
||||
|
||||
默认训练配置
|
||||
|
||||
```bash
|
||||
'network_type': 'LightCNN_9Layers', # 模型名称
|
||||
'epochs': 80, # 总训练epoch数
|
||||
'lr': 0.01, # 训练学习率
|
||||
'num_classes': 79077, # 分类总类别数量
|
||||
'momentum': 0.9, # 动量
|
||||
'weight_decay': 1e-4, # 权重衰减
|
||||
'batch_size': 128, # batch size
|
||||
'image_size': 128, # 输入模型的图像尺寸
|
||||
'save_checkpoint_steps': 60000, # 保存checkpoint的间隔step数
|
||||
'keep_checkpoint_max': 40, # 只保存最后一个keep_checkpoint_max检查点
|
||||
```
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 用法
|
||||
|
||||
#### Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
# trian_standalone.sh
|
||||
python3 train.py \
|
||||
--device_target Ascend \
|
||||
--device_id "$DEVICE_ID" \
|
||||
--ckpt_path ./ckpt_files > train_standalone_log.log 2>&1 &
|
||||
```
|
||||
|
||||
```bash
|
||||
# train_distribute_8p.sh
|
||||
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
python3 train.py \
|
||||
--device_target Ascend \
|
||||
--device_id "$DEVICE_ID" \
|
||||
--run_distribute 1 \
|
||||
--ckpt_path ./ckpt_files > train_distribute_8p.log 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
```
|
||||
|
||||
```bash
|
||||
# train_distribute.sh
|
||||
|
||||
# distributed devices id
|
||||
device_ids=(0 1 2 3)
|
||||
|
||||
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||
export DEVICE_ID=${device_ids[i]}
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
python3 train.py \
|
||||
--device_target Ascend \
|
||||
--device_id $DEVICE_ID \
|
||||
--run_distribute 1 \
|
||||
--ckpt_path ./ckpt_files > train_distribute.log 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
```
|
||||
|
||||
### 结果
|
||||
|
||||
#### Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
# 单卡训练结果
|
||||
epoch: 1 step: 39451, loss is 4.6629214
|
||||
epoch time: 4850141.061 ms, per step time: 122.941 ms
|
||||
epoch: 2 step: 39451, loss is 3.6382508
|
||||
epoch time: 4148247.801 ms, per step time: 105.149 ms
|
||||
epoch: 3 step: 39451, loss is 2.9592063
|
||||
epoch time: 4146129.041 ms, per step time: 105.096 ms
|
||||
epoch: 4 step: 39451, loss is 3.6300964
|
||||
epoch time: 4128986.449 ms, per step time: 104.661 ms
|
||||
epoch: 5 step: 39451, loss is 2.9682
|
||||
epoch time: 4117678.376 ms, per step time: 104.374 ms
|
||||
epoch: 6 step: 39451, loss is 3.2115498
|
||||
epoch time: 4139044.713 ms, per step time: 104.916 ms
|
||||
...
|
||||
```
|
||||
|
||||
```bash
|
||||
# 分布式训练结果(8P)
|
||||
epoch: 1 step: 4931, loss is 8.716646
|
||||
epoch time: 1215603.837 ms, per step time: 246.523 ms
|
||||
epoch: 2 step: 4931, loss is 3.6822505
|
||||
epoch time: 1038280.276 ms, per step time: 210.562 ms
|
||||
epoch: 3 step: 4931, loss is 1.8040423
|
||||
epoch time: 1033455.542 ms, per step time: 209.583 ms
|
||||
epoch: 4 step: 4931, loss is 1.6634097
|
||||
epoch time: 1047134.763 ms, per step time: 212.357 ms
|
||||
epoch: 5 step: 4931, loss is 1.369437
|
||||
epoch time: 1053151.674 ms, per step time: 213.578 ms
|
||||
epoch: 6 step: 4931, loss is 1.3599608
|
||||
epoch time: 1064338.712 ms, per step time: 215.846 ms
|
||||
...
|
||||
```
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 用法
|
||||
|
||||
#### Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
# 进入根目录
|
||||
cd LightCNN/
|
||||
|
||||
# 评估LightCNN在lfw 6,000 pairs上的表现
|
||||
# DEVICE_ID: Ascend处理器id
|
||||
# CKPT_FILE: checkpoint权重文件
|
||||
sh scripts/eval_lfw.sh DEVICE_ID CKPT_FILE
|
||||
|
||||
# 评估LightCNN在lfw BLUFR protocols上的表现
|
||||
# DEVICE_ID: Ascend处理器id
|
||||
# CKPT_FILE: checkpoint权重文件
|
||||
sh scripts/eval_blufr.sh DEVICE_ID CKPT_FILE
|
||||
```
|
||||
|
||||
测试脚本示例如下:
|
||||
|
||||
```bash
|
||||
# eval_lfw.sh
|
||||
# ${DEVICE_ID}: Ascend处理器id
|
||||
# ${ckpt_file}: checkpoint权重文件,由用户输入
|
||||
# eval_lfw.log:保存的测试结果
|
||||
python3 eval_lfw.py \
|
||||
--device_target Ascend \
|
||||
--device_id "${DEVICE_ID}" \
|
||||
--resume "${ckpt_file}" > eval_lfw.log 2>&1 &
|
||||
```
|
||||
|
||||
```bash
|
||||
# eval_blufr.sh
|
||||
# ${DEVICE_ID}: Ascend处理器id
|
||||
# ${ckpt_file}: checkpoint权重文件,由用户输入
|
||||
# eval_blufr.log:保存的测试结果
|
||||
# Tips:在eval_blufr.py中,可以使用numba库加速计算。如果引入了numba库,可以用'@jit'语法糖进行加速,去掉注释即可
|
||||
python3 eval_blfur.py \
|
||||
--device_target Ascend \
|
||||
--device_id "${DEVICE_ID}" \
|
||||
--resume "${ckpt_file}" > eval_blufr.log 2>&1 &
|
||||
```
|
||||
|
||||
### 结果
|
||||
|
||||
运行适用的训练脚本获取结果。要获得相同的结果,请按照快速入门中的步骤操作。
|
||||
|
||||
#### 训练准确率
|
||||
|
||||
> 注:该部分展示的是Ascend单卡训练结果。
|
||||
|
||||
- 在lfw 6,000 pairs上的评估结果
|
||||
|
||||
| **网络** | 100% - EER | TPR@RAR=1% | TPR@FAR=0.1% | TPR@FAR|
|
||||
| :----------: | :-----: | :----: | :----: | :-----:|
|
||||
| LightCNN-9(MindSpore版本)| 98.57%| 98.47% | 95.5% | 89.87% |
|
||||
| LightCNN-9(PyTorch版本)| 98.53%| 98.47% | 94.67% | 77.13% |
|
||||
|
||||
- 在lfw BLUFR protoclos上的评估结果
|
||||
|
||||
| **网络** | VR@FAR=0.1% | DIR@RAR=1% |
|
||||
| :----------: | :-----: | :----: |
|
||||
| LightCNN-9(MindSpore版本) | 96.26% | 81.66%|
|
||||
| LightCNN-9(PyTorch版本) | 95.56% | 79.77%|
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 评估性能
|
||||
|
||||
| 参数 | Ascend 910|
|
||||
| -------------------------- | -------------------------------------- |
|
||||
| 模型版本 | LightCNN |
|
||||
| 资源 | Ascend 910 |
|
||||
| 上传日期 | 2021-05-16 |
|
||||
| MindSpore版本 | 1.1.1 |
|
||||
| 数据集 | MS-Celeb-1M, LFW |
|
||||
| 训练参数 | epoch = 80, batch_size = 128, lr = 0.01 |
|
||||
| 优化器 | SGD |
|
||||
| 损失函数 | Softmax交叉熵 |
|
||||
| 输出 | 概率 |
|
||||
| 损失 | 0.10905003 |
|
||||
| 性能 | 369,144,120.56 ms(单卡)<br> 85,369,778.48 ms(八卡) |
|
||||
| 脚本 | [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/LightCNN) |
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
||||
|
||||
[1]: https://arxiv.org/pdf/1511.02683
|
||||
[2]: http://pan.baidu.com/s/1gfxB0iB
|
||||
[3]: https://drive.google.com/file/d/0ByNaVHFekDPRbFg1YTNiMUxNYXc/view?usp=sharing
|
||||
[4]: https://hyper.ai/datasets/5543
|
||||
[5]: https://pan.baidu.com/s/1eR6vHFO
|
||||
[6]: https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html
|
||||
[7]: http://www.cbsr.ia.ac.cn/users/scliao/projects/blufr/BLUFR.zip
|
||||
[8]: https://github.com/AlfredXiangWu/face_verification_experiment/blob/master/code/lfw_pairs.mat
|
||||
[9]: https://github.com/AlfredXiangWu/face_verification_experiment/blob/master/results/LightenedCNN_B_lfw.mat
|
|
@ -0,0 +1,408 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval blfur"""
|
||||
import os
|
||||
import argparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
# from numba import jit
|
||||
|
||||
from mindspore import context, load_param_into_net, load_checkpoint, Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from src.lightcnn import lightCNN_9Layers4Test
|
||||
from src.config import lightcnn_cfg as cfg
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Feature Extracting')
|
||||
parser.add_argument('--device_target', default='Ascend', choices=['Ascend', 'GPU', 'CPU'], type=str)
|
||||
parser.add_argument('--device_id', default=0, type=int)
|
||||
parser.add_argument('--resume', default='',
|
||||
type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--num_classes', default=79077, type=int,
|
||||
metavar='N', help='number of classes (default: 79077)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def extract_feature(img_list):
|
||||
"""extract features from model's predictions"""
|
||||
model = lightCNN_9Layers4Test(num_classes=args.num_classes)
|
||||
model.set_train(False)
|
||||
|
||||
if os.path.isfile(args.resume):
|
||||
print("=> loading checkpoint '{}'".format(args.resume))
|
||||
params_dict = load_checkpoint(args.resume)
|
||||
load_param_into_net(model, params_dict)
|
||||
else:
|
||||
print("=> ERROR: No checkpoint found at '{}'".format(args.resume))
|
||||
exit(0)
|
||||
|
||||
features_shape = (len(img_list), 256)
|
||||
features = np.empty(features_shape, dtype='float32', order='C')
|
||||
|
||||
for idx, img_name in enumerate(img_list):
|
||||
print('%d images processed' % (idx + 1,))
|
||||
img = cv2.imread(os.path.join(cfg.root_path, img_name), cv2.IMREAD_GRAYSCALE)
|
||||
if img.shape != (128, 128):
|
||||
img = cv2.resize(img, (128, 128))
|
||||
img = np.reshape(img, (1, 1, 128, 128))
|
||||
inputs = img.astype(np.float32) / 255.0
|
||||
inputs = Tensor(inputs, mstype.float32)
|
||||
_, feature = model(inputs)
|
||||
features[idx:idx + 1, :] = feature.asnumpy()
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def load_image_list(img_dir, list_file_name):
|
||||
"""get image list"""
|
||||
img_dir_cp = img_dir.replace('/image', '')
|
||||
list_file_path = os.path.join(img_dir_cp, list_file_name)
|
||||
f = open(list_file_path, 'r')
|
||||
image_list = []
|
||||
for line in f:
|
||||
img_name = line[:-4]
|
||||
person_name = line[:img_name.rfind('_')]
|
||||
path = person_name + '/' + img_name + 'bmp'
|
||||
image_list.append(path)
|
||||
return image_list
|
||||
|
||||
|
||||
def string_list_to_cells(lst):
|
||||
"""
|
||||
Uses numpy.ndarray with dtype=object. Convert list to np.ndarray().
|
||||
"""
|
||||
cells = np.ndarray(len(lst), dtype='object')
|
||||
for idx, ele in enumerate(lst):
|
||||
cells[idx] = ele
|
||||
return cells
|
||||
|
||||
|
||||
def extract_features_to_dic(image_dir, list_file):
|
||||
"""extract features and save them in dict"""
|
||||
img_list = load_image_list(image_dir, list_file)
|
||||
ftr = extract_feature(img_list)
|
||||
dic = {'Descriptors': ftr}
|
||||
return dic
|
||||
|
||||
|
||||
def compute_cosine_score(feature1, feature2):
|
||||
"""compute cosine score"""
|
||||
feature1_norm = np.linalg.norm(feature1)
|
||||
feature2_norm = np.linalg.norm(feature2)
|
||||
score = np.dot(feature1, feature2) / (feature1_norm * feature2_norm)
|
||||
return score
|
||||
|
||||
|
||||
def normr(data):
|
||||
"""compute normr"""
|
||||
ratio = np.sqrt(np.sum(np.power(data, 2)))
|
||||
return data / ratio
|
||||
|
||||
|
||||
# @jit(nopython=True)
|
||||
def bsxfun_eq(galLabels, probLabels, binaryLabels):
|
||||
"""get bsxfun_eq"""
|
||||
for idx1, ele1 in enumerate(galLabels):
|
||||
for idx2, ele2 in enumerate(probLabels):
|
||||
binaryLabels[idx1, idx2] = 1 if ele1 == ele2 else 0
|
||||
return binaryLabels
|
||||
|
||||
|
||||
# @jit(nopython=True)
|
||||
def bsxfun_eq2(galLabels, probLabels, binaryLabels):
|
||||
"""get bsxfun_eq2"""
|
||||
for i, _ in enumerate(galLabels):
|
||||
for j, ele in enumerate(probLabels):
|
||||
binaryLabels[i, j] = 1 if galLabels[i, j] == ele else 0
|
||||
return binaryLabels
|
||||
|
||||
|
||||
# @jit(nopython=True)
|
||||
def bsxfun_ge(genScore, thresholds):
|
||||
"""get bsxfun_ge"""
|
||||
temp = np.zeros((len(genScore), len(thresholds)))
|
||||
for i, ele1 in enumerate(genScore):
|
||||
for j, ele2 in enumerate(thresholds):
|
||||
temp[i, j] = 1 if ele1 >= ele2 else 0
|
||||
return temp
|
||||
|
||||
|
||||
# @jit(nopython=True)
|
||||
def bsxfun_le(genScore, thresholds):
|
||||
"""get bsxfun_le"""
|
||||
temp = np.zeros((len(genScore), len(thresholds)))
|
||||
for i, ele1 in enumerate(genScore):
|
||||
for j, ele2 in enumerate(thresholds):
|
||||
temp[i, j] = 1 if ele1 <= ele2 else 0
|
||||
return temp
|
||||
|
||||
|
||||
# @jit(nopython=True)
|
||||
def bsxfun_and(T1, T2):
|
||||
"""get bsxfun_and"""
|
||||
temp = np.zeros((T2.shape[0], T2.shape[1], T1.shape[1]))
|
||||
for i in range(temp.shape[0]):
|
||||
for j in range(temp.shape[1]):
|
||||
for k in range(temp.shape[2]):
|
||||
temp[i, j, k] = 1 if T1[i, k] * T2[i, j] != 0 else 0
|
||||
return temp
|
||||
|
||||
|
||||
def ismember(a, b):
|
||||
"""get bsxfun_and"""
|
||||
tf = np.in1d(a, b)
|
||||
index = np.array([(np.where(b == i))[0][-1] if t else 0 for i, t in zip(a, tf)])
|
||||
return tf, index
|
||||
|
||||
|
||||
def EvalROC(score, galLabels, farPoints):
|
||||
"""eval ROC"""
|
||||
probLabels = galLabels
|
||||
scoreMask = np.tril(np.ones_like(score), k=-1)
|
||||
binaryLabels = np.zeros_like(score)
|
||||
binaryLabels = bsxfun_eq(galLabels, probLabels, binaryLabels)
|
||||
|
||||
score_ = score[scoreMask == 1]
|
||||
binaryLabels_ = binaryLabels[scoreMask == 1]
|
||||
|
||||
genScore = score_[binaryLabels_ == 1]
|
||||
impScore = score_[binaryLabels_ == 0]
|
||||
del score, score_, binaryLabels, binaryLabels_
|
||||
|
||||
Nimp = len(impScore)
|
||||
falseAlarms = np.round(farPoints * Nimp)
|
||||
|
||||
impScore = np.sort(impScore)
|
||||
impScore = impScore[::-1]
|
||||
|
||||
isZeroFAR = np.zeros_like(falseAlarms)
|
||||
isZeroFAR[np.squeeze(np.where(falseAlarms == 0))] = 1
|
||||
|
||||
isOneFAR = np.zeros_like(falseAlarms)
|
||||
isOneFAR[np.squeeze(np.where(falseAlarms == Nimp))] = 1
|
||||
|
||||
thresholds = np.zeros_like(falseAlarms)
|
||||
for i, _ in enumerate(isZeroFAR):
|
||||
thresholds[i] = impScore[int(falseAlarms[i]) - 1] if isZeroFAR[i] != 1 and isOneFAR[i] != 1 else 0
|
||||
|
||||
highGenScore = genScore[genScore > impScore[0]]
|
||||
eps = 1.490116119384766e-08
|
||||
if highGenScore.size:
|
||||
thresholds[isZeroFAR == 1] = (impScore[0] + np.min(highGenScore)) / 2
|
||||
else:
|
||||
thresholds[isZeroFAR == 1] = impScore[0] + eps
|
||||
|
||||
thresholds[isOneFAR == 1] = np.minimum(impScore[-1], np.min(genScore)) - np.sqrt(eps)
|
||||
|
||||
FAR = falseAlarms / Nimp
|
||||
VR = np.mean(bsxfun_ge(genScore, thresholds), axis=0)
|
||||
|
||||
return VR, FAR
|
||||
|
||||
|
||||
def OpenSetROC(score, galLabels, probLabels, farPoints):
|
||||
"""open set ROC"""
|
||||
rankPoints = np.zeros(19)
|
||||
for i in range(10):
|
||||
rankPoints[i] = i + 1
|
||||
rankPoints[i + 9] = (i + 1) * 10
|
||||
probLabels = probLabels.T
|
||||
|
||||
binaryLabels = np.zeros_like(score)
|
||||
binaryLabels = bsxfun_eq(galLabels, probLabels, binaryLabels)
|
||||
|
||||
t = np.any(binaryLabels, axis=0)
|
||||
genProbIndex = np.squeeze(np.where(t))
|
||||
impProbIndex = np.squeeze(np.where(~t))
|
||||
# Ngen = len(genProbIndex)
|
||||
Nimp = len(impProbIndex)
|
||||
falseAlarms = np.round(farPoints * Nimp)
|
||||
|
||||
# get detection scores and matching ranks of each probe
|
||||
impScore = [np.max(score[:, i]) for i in impProbIndex]
|
||||
impScore = np.sort(impScore)
|
||||
impScore = impScore[::-1]
|
||||
|
||||
S = np.zeros((score.shape[0], len(genProbIndex)))
|
||||
for i, ele in enumerate(genProbIndex):
|
||||
S[:, i] = score[:, ele]
|
||||
sortedIndex = np.argsort(S, axis=0)
|
||||
sortedIndex = np.flipud(sortedIndex)
|
||||
M = np.zeros((binaryLabels.shape[0], len(genProbIndex)))
|
||||
for i, ele in enumerate(genProbIndex):
|
||||
M[:, i] = binaryLabels[:, ele]
|
||||
del binaryLabels
|
||||
S[M == 0] = -np.Inf
|
||||
del M
|
||||
genScore, genGalIndex = np.max(S, axis=0), np.argmax(S, axis=0)
|
||||
del S
|
||||
temp = np.zeros_like(sortedIndex)
|
||||
temp = bsxfun_eq2(sortedIndex, genGalIndex, temp)
|
||||
probRanks = (temp != 0).argmax(axis=0)
|
||||
del sortedIndex
|
||||
|
||||
# compute thresholds
|
||||
isZeroFAR = np.zeros_like(falseAlarms)
|
||||
isZeroFAR[np.squeeze(np.where(falseAlarms == 0))] = 1
|
||||
|
||||
isOneFAR = np.zeros_like(falseAlarms)
|
||||
isOneFAR[np.squeeze(np.where(falseAlarms == Nimp))] = 1
|
||||
|
||||
thresholds = np.zeros_like(falseAlarms)
|
||||
for i, _ in enumerate(isZeroFAR):
|
||||
thresholds[i] = impScore[int(falseAlarms[i]) - 1] if isZeroFAR[i] != 1 and isOneFAR[i] != 1 else 0
|
||||
|
||||
highGenScore = genScore[genScore > impScore[0]]
|
||||
eps = 1.490116119384766e-08
|
||||
if highGenScore.size:
|
||||
thresholds[isZeroFAR == 1] = (impScore[0] + np.min(highGenScore)) / 2
|
||||
else:
|
||||
thresholds[isZeroFAR == 1] = impScore[0] + eps
|
||||
|
||||
thresholds[isOneFAR == 1] = np.minimum(impScore[-1], np.min(genScore)) - np.sqrt(eps)
|
||||
|
||||
# evaluate
|
||||
genScore = genScore.T
|
||||
T1 = bsxfun_ge(genScore, thresholds)
|
||||
T2 = bsxfun_le(probRanks, rankPoints)
|
||||
T = bsxfun_and(T1, T2)
|
||||
DIR = np.squeeze(np.mean(T, axis=0))
|
||||
FAR = falseAlarms / Nimp
|
||||
return DIR, FAR
|
||||
|
||||
|
||||
def blufr_eval(lightcnn_result, config_file_path):
|
||||
"""eval blufr"""
|
||||
Descriptors = lightcnn_result['Descriptors']
|
||||
config_file = scipy.io.loadmat(config_file_path)
|
||||
testIndex = config_file['testIndex']
|
||||
galIndex = config_file['galIndex']
|
||||
probIndex = config_file['probIndex']
|
||||
labels = config_file['labels']
|
||||
|
||||
veriFarPoints = [0]
|
||||
for i in range(1, 9):
|
||||
for j in range(1, 10):
|
||||
veriFarPoints.append(round(j * pow(10, i - 9), 9 - i))
|
||||
veriFarPoints.append(1)
|
||||
veriFarPoints = np.array(veriFarPoints)
|
||||
|
||||
osiFarPoints = [0]
|
||||
for i in range(1, 5):
|
||||
for j in range(1, 10):
|
||||
osiFarPoints.append(round(j * pow(10, i - 5), 5 - i))
|
||||
osiFarPoints.append(1)
|
||||
osiFarPoints = np.array(osiFarPoints)
|
||||
|
||||
rankPoints = []
|
||||
for i in range(0, 2):
|
||||
for j in range(1, 10):
|
||||
rankPoints.append(j * pow(10, i))
|
||||
rankPoints.append(100)
|
||||
rankPoints = np.array(rankPoints)
|
||||
|
||||
reportVeriFar = 0.001
|
||||
reportOsiFar = 0.01
|
||||
reportRank = 1
|
||||
|
||||
numTrials = len(testIndex)
|
||||
numVeriFarPoints = len(veriFarPoints)
|
||||
|
||||
VR = np.zeros((numTrials, numVeriFarPoints))
|
||||
veriFAR = np.zeros((numTrials, numVeriFarPoints))
|
||||
|
||||
numOsiFarPoints = len(osiFarPoints)
|
||||
numRanks = len(rankPoints)
|
||||
|
||||
DIR = np.zeros((numRanks, numOsiFarPoints, numTrials))
|
||||
osiFAR = np.zeros((numTrials, numOsiFarPoints))
|
||||
|
||||
veriFarIndex = np.squeeze(np.where(veriFarPoints == reportVeriFar))
|
||||
osiFarIndex = np.squeeze(np.where(osiFarPoints == reportOsiFar))
|
||||
rankIndex = np.squeeze(np.where(rankPoints == reportRank))
|
||||
|
||||
for t in range(numTrials):
|
||||
print('Processing with trail %s ...' % str(t + 1))
|
||||
idx_list = testIndex[t][0]
|
||||
X = np.zeros((len(idx_list), 256))
|
||||
for k, ele in enumerate(idx_list):
|
||||
data = Descriptors[np.squeeze(ele) - 1, :]
|
||||
X[k, :] = normr(data)
|
||||
score = np.dot(X, X.T)
|
||||
|
||||
testLabels = np.zeros(len(idx_list), dtype=np.int)
|
||||
for k, ele in enumerate(idx_list):
|
||||
testLabels[k] = labels[np.squeeze(ele) - 1]
|
||||
|
||||
VR[t, :], veriFAR[t, :] = EvalROC(score, testLabels, veriFarPoints)
|
||||
|
||||
_, gIdx = ismember(galIndex[t][0], testIndex[t][0])
|
||||
_, pIdx = ismember(probIndex[t][0], testIndex[t][0])
|
||||
|
||||
score_sub = np.zeros((len(gIdx), len(pIdx)))
|
||||
for i, ele1 in enumerate(gIdx):
|
||||
for j, ele2 in enumerate(pIdx):
|
||||
score_sub[i, j] = score[ele1, ele2]
|
||||
|
||||
testLabels_gIdx = np.zeros(len(gIdx), dtype=np.int)
|
||||
for i, ele in enumerate(gIdx):
|
||||
testLabels_gIdx[i] = testLabels[ele]
|
||||
|
||||
testLabels_pIdx = np.zeros(len(pIdx), dtype=np.int)
|
||||
for i, ele in enumerate(pIdx):
|
||||
testLabels_pIdx[i] = testLabels[ele]
|
||||
|
||||
DIR[:, :, t], osiFAR[t, :] = OpenSetROC(score_sub, testLabels_gIdx, testLabels_pIdx, osiFarPoints)
|
||||
|
||||
print('Verification:')
|
||||
print('\t@ FAR = %s%%: VR = %.4f%%' % (reportVeriFar * 100, VR[t, veriFarIndex] * 100))
|
||||
|
||||
print('Open-set Identification:')
|
||||
print('\t@ Rank = %d, FAR = %s%%: DIR = %.4f%%\n'
|
||||
% (reportRank, reportOsiFar * 100, DIR[rankIndex, osiFarIndex, t] * 100))
|
||||
|
||||
del X, score
|
||||
|
||||
# meanVerFAR = np.mean(veriFAR, axis=0)
|
||||
meanVR = np.mean(VR, axis=0)
|
||||
stdVR = np.std(VR, axis=0)
|
||||
reportMeanVR = meanVR[veriFarIndex]
|
||||
reportStdVR = stdVR[veriFarIndex]
|
||||
|
||||
# meanOsiFAR = np.mean(osiFAR, axis=0)
|
||||
meanDIR = np.mean(DIR, axis=2)
|
||||
stdDIR = np.std(DIR, axis=2)
|
||||
reportMeanDIR = meanDIR[rankIndex, osiFarIndex]
|
||||
reportStdDIR = stdDIR[rankIndex, osiFarIndex]
|
||||
|
||||
# Get the mu - sigma performance measures
|
||||
# fusedVR = (meanVR - stdVR) * 100
|
||||
reportVR = (reportMeanVR - reportStdVR) * 100
|
||||
# fusedDIR = (meanDIR - stdDIR) * 100
|
||||
reportDIR = (reportMeanDIR - reportStdDIR) * 100
|
||||
|
||||
# Display the benchmark performance
|
||||
print('Verification:')
|
||||
print('\t@ FAR = %s%%: VR = %.2f%%' % (reportVeriFar * 100, reportVR))
|
||||
print('\t@ Rank = %d, FAR = %s%%: DIR = %.2f%%.' % (reportRank, reportOsiFar * 100, reportDIR))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=device_id)
|
||||
feature_dict = extract_features_to_dic(image_dir=cfg.root_path, list_file=cfg.blufr_img_list)
|
||||
blufr_eval(feature_dict, config_file_path=cfg.blufr_config_mat_path)
|
|
@ -0,0 +1,172 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval lfw"""
|
||||
import os
|
||||
import argparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, load_param_into_net, load_checkpoint, Tensor
|
||||
from sklearn.metrics import roc_curve
|
||||
|
||||
from src.lightcnn import lightCNN_9Layers4Test
|
||||
from src.config import lightcnn_cfg as cfg
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Feature Extracting')
|
||||
parser.add_argument('--device_target', default='Ascend', choices=['Ascend', 'GPU', 'CPU'], type=str)
|
||||
parser.add_argument('--device_id', default=0, type=int)
|
||||
parser.add_argument('--resume', default='',
|
||||
type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--num_classes', default=79077, type=int, # !!!
|
||||
metavar='N', help='number of classes (default: 79077)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def extract_feature(img_list):
|
||||
"""extra features"""
|
||||
model = lightCNN_9Layers4Test(num_classes=args.num_classes)
|
||||
model.set_train(False)
|
||||
|
||||
if os.path.isfile(args.resume):
|
||||
print("=> loading checkpoint '{}'".format(args.resume))
|
||||
params_dict = load_checkpoint(args.resume)
|
||||
load_param_into_net(model, params_dict)
|
||||
else:
|
||||
print("=> ERROR: No checkpoint found at '{}'".format(args.resume))
|
||||
exit(0)
|
||||
|
||||
features_shape = (len(img_list), 256)
|
||||
features = np.empty(features_shape, dtype='float32', order='C')
|
||||
|
||||
for idx, img_name in enumerate(img_list):
|
||||
print('%d images processed' % (idx + 1,))
|
||||
img = cv2.imread(os.path.join(cfg.root_path, img_name), cv2.IMREAD_GRAYSCALE)
|
||||
if img.shape != (128, 128):
|
||||
img = cv2.resize(img, (128, 128))
|
||||
img = np.reshape(img, (1, 1, 128, 128))
|
||||
inputs = img.astype(np.float32) / 255.0
|
||||
inputs = Tensor(inputs, mstype.float32)
|
||||
_, feature = model(inputs)
|
||||
features[idx:idx + 1, :] = feature.asnumpy()
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def load_image_list(img_dir, list_file_name):
|
||||
"""load image list"""
|
||||
img_dir_cp = img_dir.replace('/image', '')
|
||||
list_file_path = os.path.join(img_dir_cp, list_file_name)
|
||||
f = open(list_file_path, 'r')
|
||||
image_list = []
|
||||
labels = []
|
||||
for line in f:
|
||||
items = line.split()
|
||||
image_list.append(items[0].strip())
|
||||
labels.append(items[1].strip())
|
||||
return labels, image_list
|
||||
|
||||
|
||||
def labels_list_to_int(labels):
|
||||
"""convert type of labels to integer"""
|
||||
int_labels = []
|
||||
for e in labels:
|
||||
try:
|
||||
inte = int(e)
|
||||
except ValueError:
|
||||
print('Labels are not int numbers. A mapping will be used.')
|
||||
break
|
||||
int_labels.append(inte)
|
||||
if len(int_labels) == len(labels):
|
||||
return int_labels
|
||||
return None
|
||||
|
||||
|
||||
def string_list_to_cells(lst):
|
||||
"""
|
||||
Uses numpy.ndarray with dtype=object. Convert list to np.ndarray().
|
||||
"""
|
||||
cells = np.ndarray(len(lst), dtype='object')
|
||||
for idx, ele in enumerate(lst):
|
||||
cells[idx] = ele
|
||||
return cells
|
||||
|
||||
|
||||
def extract_features_to_dict(image_dir, list_file):
|
||||
"""extract features and save them with dictionary"""
|
||||
labels, img_list = load_image_list(image_dir, list_file)
|
||||
ftr = extract_feature(img_list)
|
||||
integer_labels = labels_list_to_int(labels)
|
||||
feature_dict = {'features': ftr,
|
||||
'labels': integer_labels,
|
||||
'labels_original': string_list_to_cells(labels),
|
||||
'image_path': string_list_to_cells(img_list)}
|
||||
return feature_dict
|
||||
|
||||
|
||||
def compute_cosine_score(feature1, feature2):
|
||||
"""compute cosine score"""
|
||||
feature1_norm = np.linalg.norm(feature1)
|
||||
feature2_norm = np.linalg.norm(feature2)
|
||||
score = np.dot(feature1, feature2) / (feature1_norm * feature2_norm)
|
||||
return score
|
||||
|
||||
|
||||
def lfw_eval(lightcnn_result, lfw_pairs_mat_path):
|
||||
"""eval lfw"""
|
||||
features = lightcnn_result['features']
|
||||
lfw_pairs_mat = scipy.io.loadmat(lfw_pairs_mat_path)
|
||||
pos_pair = lfw_pairs_mat['pos_pair']
|
||||
neg_pair = lfw_pairs_mat['neg_pair']
|
||||
|
||||
pos_scores = np.zeros(len(pos_pair[1]))
|
||||
|
||||
for idx, _ in enumerate(pos_pair[1]):
|
||||
feat1 = features[pos_pair[0, idx] - 1, :]
|
||||
feat2 = features[pos_pair[1, idx] - 1, :]
|
||||
pos_scores[idx] = compute_cosine_score(feat1, feat2)
|
||||
pos_label = np.ones(len(pos_pair[1]))
|
||||
|
||||
neg_scores = np.zeros(len(neg_pair[1]))
|
||||
for idx, _ in enumerate(neg_pair[1]):
|
||||
feat1 = features[neg_pair[0, idx] - 1, :]
|
||||
feat2 = features[neg_pair[1, idx] - 1, :]
|
||||
neg_scores[idx] = compute_cosine_score(feat1, feat2)
|
||||
neg_label = -1 * np.ones(len(neg_pair[1]))
|
||||
|
||||
scores = np.concatenate((pos_scores, neg_scores), axis=0)
|
||||
label = np.concatenate((pos_label, neg_label), axis=0)
|
||||
|
||||
fpr, tpr, _ = roc_curve(label, scores, pos_label=1)
|
||||
res = tpr - (1 - fpr)
|
||||
|
||||
eer = tpr[np.squeeze(np.where(res >= 0))[0]] * 100
|
||||
far_10 = tpr[np.squeeze(np.where(fpr <= 0.01))[-1]] * 100
|
||||
far_01 = tpr[np.squeeze(np.where(fpr <= 0.001))[-1]] * 100
|
||||
far_00 = tpr[np.squeeze(np.where(fpr <= 0.0))[-1]] * 100
|
||||
|
||||
print('100%eer: ', round(eer, 2))
|
||||
print('tpr@far=1%: ', round(far_10, 2))
|
||||
print('tpr@far=0.1%: ', round(far_01, 2))
|
||||
print('tpr@far=0%: ', round(far_00, 2))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=device_id)
|
||||
dic = extract_features_to_dict(image_dir=cfg.root_path, list_file=cfg.lfw_img_list)
|
||||
lfw_eval(dic, lfw_pairs_mat_path=cfg.lfw_pairs_mat_path)
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air, onnx, mindir models"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
from src.config import lightcnn_cfg as cfg
|
||||
from src.lightcnn import lightCNN_9Layers4Test, lightCNN_9Layers
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore LightCNN Example')
|
||||
parser.add_argument("--device_id", type=int, default=4, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="lightcnn", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# define LightCNN network
|
||||
network = lightCNN_9Layers(cfg.num_classes)
|
||||
network4Test = lightCNN_9Layers4Test(cfg.num_classes)
|
||||
|
||||
# load network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
load_param_into_net(network, param_dict)
|
||||
load_param_into_net(network4Test, param_dict)
|
||||
|
||||
# export network
|
||||
inputs = Tensor(np.ones([args.batch_size, 1, cfg.image_size, cfg.image_size]), mindspore.float32)
|
||||
export(network, inputs, file_name=args.file_name, file_format=args.file_format)
|
||||
export(network4Test, inputs, file_name=args.file_name + '4Test', file_format=args.file_format)
|
|
@ -0,0 +1,23 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
file_path=$1
|
||||
output_dir=$2
|
||||
|
||||
python3 src/convert.py \
|
||||
--file_path "${file_path}" \
|
||||
--output_dir "${output_dir}" > convert.log 2>&1 &
|
||||
|
||||
echo "running convert.py, convert [${file_path}] to [${output_dir}]"
|
|
@ -0,0 +1,24 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
export DEVICE_ID=$1
|
||||
ckpt_file=$2
|
||||
|
||||
python3 eval_blfur.py \
|
||||
--device_target Ascend \
|
||||
--device_id "${DEVICE_ID}" \
|
||||
--resume "${ckpt_file}" > eval_blfur.log 2>&1 &
|
||||
|
||||
echo "run standalone test on device ${DEVICE_ID}"
|
|
@ -0,0 +1,24 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
export DEVICE_ID=$1
|
||||
ckpt_file=$2
|
||||
|
||||
python3 eval_lfw.py \
|
||||
--device_target Ascend \
|
||||
--device_id "${DEVICE_ID}" \
|
||||
--resume "${ckpt_file}" > eval_lfw.log 2>&1 &
|
||||
|
||||
echo "run standalone test on device ${DEVICE_ID}"
|
|
@ -0,0 +1,63 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh train_distribute.sh [RANK_TABLE_FILE] [DEVICE_NUM]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
|
||||
if [ ! -f $PATH1 ]; then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=$2
|
||||
export RANK_SIZE=$2
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
# distributed devices id
|
||||
device_ids=(0 1 2 3)
|
||||
|
||||
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||
export DEVICE_ID=${device_ids[i]}
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
python3 train.py \
|
||||
--device_target Ascend \
|
||||
--device_id $DEVICE_ID \
|
||||
--run_distribute 1 \
|
||||
--ckpt_path ./ckpt_files > train_distribute.log 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: sh train_distribute_8p [RANK_TABLE_FILE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
|
||||
if [ ! -f "$PATH1" ]; then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
|
||||
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
python3 train.py \
|
||||
--device_target Ascend \
|
||||
--device_id "$DEVICE_ID" \
|
||||
--run_distribute 1 \
|
||||
--ckpt_path ./ckpt_files > train_distribute_8p.log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,31 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: sh train_standalone.sh [DEVICE_ID]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_ID=$1
|
||||
export DEVICE_NUM=1
|
||||
export RANK_ID=0
|
||||
|
||||
python3 train.py \
|
||||
--device_target Ascend \
|
||||
--device_id "$DEVICE_ID" \
|
||||
--ckpt_path ./ckpt_files > train_standalone_log.log 2>&1 &
|
||||
|
||||
echo "run standalone training on device ${DEVICE_ID}"
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
lightcnn_cfg = edict({
|
||||
# training setting
|
||||
'network_type': 'LightCNN_9Layers',
|
||||
'epochs': 80,
|
||||
'lr': 0.01,
|
||||
'num_classes': 79077,
|
||||
'momentum': 0.9,
|
||||
'weight_decay': 1e-4,
|
||||
'batch_size': 128,
|
||||
'image_size': 128,
|
||||
'save_checkpoint_steps': 60000,
|
||||
'keep_checkpoint_max': 40,
|
||||
# train data location
|
||||
'data_path': '/opt/data/lightcnn_data/FaceImageCroppedWithAlignment/',
|
||||
'train_list': '/opt/data/lightcnn_data/MS-Celeb-1M_clean_list.txt',
|
||||
# test data location
|
||||
'root_path': '/opt/data/lightcnn_data/lfw/image',
|
||||
'lfw_img_list': 'image_list_for_lfw.txt',
|
||||
'lfw_pairs_mat_path': 'mat_files/lfw_pairs.mat',
|
||||
'blufr_img_list': 'image_list_for_blufr.txt',
|
||||
'blufr_config_mat_path': 'mat_files/blufr_lfw_config.mat'
|
||||
})
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""convert tsv file to images"""
|
||||
import base64
|
||||
import struct
|
||||
import os
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='convert tsv file to images')
|
||||
parser.add_argument('--file_path', default='./FaceImageCroppedWithAlignment.tsv', type=str,
|
||||
help='the path of csv file')
|
||||
parser.add_argument('--output_dir', default='./FaceImageCroppedWithAlignment/', type=str,
|
||||
help='the path of converted images')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def read_line(line):
|
||||
"""read line"""
|
||||
m_id, image_search_rank, image_url, page_url, face_id, face_rectangle, face_data = line.split("\t")
|
||||
rect = struct.unpack("ffff", base64.b64decode(face_rectangle))
|
||||
result = {
|
||||
'm_id': m_id,
|
||||
'image_search_rank': image_search_rank,
|
||||
'image_url': image_url,
|
||||
'page_url': page_url,
|
||||
'face_id': face_id,
|
||||
'rect': rect,
|
||||
'face_data': base64.b64decode(face_data)
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def write_image(filename, data):
|
||||
"""write image"""
|
||||
with open(filename, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
|
||||
def unpack(file_name, output_dir):
|
||||
"""unpack file"""
|
||||
i = 0
|
||||
with open(file_name, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
result = read_line(line)
|
||||
img_dir = os.path.join(output_dir, result['m_id'])
|
||||
if not os.path.exists(img_dir):
|
||||
os.mkdir(img_dir)
|
||||
img_name = "%s-%s" % (result['image_search_rank'], result['face_id']) + ".jpg"
|
||||
write_image(os.path.join(img_dir, img_name), result['face_data'])
|
||||
i += 1
|
||||
if i % 1000 == 0:
|
||||
print(i, "images finished")
|
||||
print("all finished")
|
||||
|
||||
|
||||
def main(file_name, output_dir):
|
||||
"""main function"""
|
||||
unpack(file_name, output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(file_name=args.file_path, output_dir=args.output_dir)
|
|
@ -0,0 +1,136 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""get dataset loader"""
|
||||
import os
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.py_transforms as py_vision
|
||||
from mindspore.dataset.transforms.py_transforms import Compose
|
||||
|
||||
|
||||
def img_loader(path):
|
||||
"""load image"""
|
||||
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
|
||||
return img
|
||||
|
||||
|
||||
def list_reader(fileList):
|
||||
"""read image list"""
|
||||
imgList = []
|
||||
with open(fileList, 'r') as f:
|
||||
for line in f.readlines():
|
||||
imgPath, label = line.strip().split(' ')
|
||||
imgList.append((imgPath, int(label)))
|
||||
return imgList
|
||||
|
||||
|
||||
class ImageList:
|
||||
"""
|
||||
class for load dataset
|
||||
"""
|
||||
def __init__(self, root, fileList):
|
||||
self.root = root
|
||||
self.loader = img_loader
|
||||
self.imgList = list_reader(fileList)
|
||||
|
||||
def __getitem__(self, index):
|
||||
imgPath, target = self.imgList[index]
|
||||
img = self.loader(os.path.join(self.root, imgPath))
|
||||
return img, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imgList)
|
||||
|
||||
|
||||
def create_dataset(mode, data_url, data_list, batch_size, resize_size=144,
|
||||
input_size=128, num_of_workers=8, is_distributed=False,
|
||||
rank=0, group_size=1, seed=0):
|
||||
"""
|
||||
create dataset for train or test
|
||||
"""
|
||||
image_ops, shuffle, drop_last = None, None, None
|
||||
if mode == 'Train':
|
||||
shuffle = True
|
||||
drop_last = True
|
||||
image_ops = Compose([py_vision.ToPIL(),
|
||||
py_vision.Resize(resize_size),
|
||||
py_vision.RandomCrop(input_size),
|
||||
py_vision.RandomHorizontalFlip(),
|
||||
py_vision.ToTensor()])
|
||||
|
||||
elif mode == 'Val':
|
||||
shuffle = False
|
||||
drop_last = False
|
||||
image_ops = Compose([py_vision.ToPIL(),
|
||||
py_vision.Resize(resize_size),
|
||||
py_vision.CenterCrop(input_size),
|
||||
py_vision.ToTensor()])
|
||||
|
||||
dataset_generator = ImageList(root=data_url, fileList=data_list)
|
||||
|
||||
sampler = None
|
||||
if is_distributed:
|
||||
sampler = DistributedSampler(dataset=dataset_generator, rank=rank,
|
||||
group_size=group_size, shuffle=shuffle, seed=seed)
|
||||
|
||||
dataset = ds.GeneratorDataset(dataset_generator, ["image", "label"],
|
||||
shuffle=shuffle, sampler=sampler,
|
||||
num_parallel_workers=num_of_workers)
|
||||
|
||||
dataset = dataset.map(input_columns=["image"],
|
||||
operations=image_ops,
|
||||
num_parallel_workers=num_of_workers)
|
||||
|
||||
dataset = dataset.batch(batch_size, num_parallel_workers=num_of_workers, drop_remainder=drop_last)
|
||||
dataset = dataset.repeat(1)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
class DistributedSampler:
|
||||
"""
|
||||
Distributed sampler
|
||||
"""
|
||||
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
|
||||
self.dataset = dataset
|
||||
self.rank = rank
|
||||
self.group_size = group_size
|
||||
self.dataset_length = len(self.dataset)
|
||||
self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
|
||||
self.total_size = self.num_samples * self.group_size
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
self.seed = (self.seed + 1) & 0xffffffff
|
||||
np.random.seed(self.seed)
|
||||
indices = np.random.permutation(self.dataset_length).tolist()
|
||||
else:
|
||||
indices = list(range(len(self.dataset.classes)))
|
||||
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
indices = indices[self.rank::self.group_size]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Generate test lists"""
|
||||
import scipy.io as io
|
||||
import numpy as np
|
||||
|
||||
f1 = 'image_list_for_lfw.txt'
|
||||
|
||||
mat_lfw = io.loadmat('LightenedCNN_B_lfw.mat')
|
||||
lfw_path_list = mat_lfw['image_path']
|
||||
lfw_path_list = np.transpose(lfw_path_list)
|
||||
|
||||
lfw_label_list = mat_lfw['labels_original']
|
||||
lfw_label_list = np.transpose(lfw_label_list)
|
||||
|
||||
for idx, ele in enumerate(lfw_path_list):
|
||||
print(ele[0][0][10:], lfw_label_list[idx][0][0])
|
||||
with open(f1, 'a') as f:
|
||||
line = ele[0][0][10:] + ' ' + lfw_label_list[idx][0][0]
|
||||
f.write(line + '\n')
|
||||
|
||||
|
||||
f2 = 'image_list_for_blufr.txt'
|
||||
|
||||
mat_blufr = io.loadmat('BLUFR/config/lfw/blufr_lfw_config.mat')
|
||||
blufr_path_list = mat_blufr['imageList']
|
||||
|
||||
for _, ele in enumerate(blufr_path_list):
|
||||
print(ele[0][0])
|
||||
with open(f2, 'a') as f:
|
||||
f.write(ele[0][0] + '\n')
|
|
@ -0,0 +1,140 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""LightCNN network"""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import Normal
|
||||
from mindspore.common.initializer import XavierUniform
|
||||
|
||||
|
||||
class Mfm(nn.Cell):
|
||||
"""Mfn module"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad_mode='valid', mode=1):
|
||||
super(Mfm, self).__init__()
|
||||
self.out_channels = out_channels
|
||||
if mode == 1:
|
||||
self.filter = nn.Conv2d(in_channels, 2 * out_channels, kernel_size=kernel_size, stride=stride,
|
||||
pad_mode=pad_mode, weight_init=XavierUniform(), has_bias=True)
|
||||
elif mode == 0:
|
||||
self.filter = nn.Dense(in_channels, 2 * out_channels, weight_init=Normal(0.02))
|
||||
self.maximum = P.Maximum()
|
||||
self.split = P.Split(axis=1, output_num=2)
|
||||
|
||||
def construct(self, x):
|
||||
"""Mfn construct"""
|
||||
x = self.filter(x)
|
||||
out = self.split(x)
|
||||
out = self.maximum(out[0], out[1])
|
||||
return out
|
||||
|
||||
|
||||
class Group(nn.Cell):
|
||||
"""group module"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride):
|
||||
super(Group, self).__init__()
|
||||
self.conv_a = Mfm(in_channels, in_channels, 1, 1, pad_mode='same')
|
||||
self.conv = Mfm(in_channels, out_channels, kernel_size, stride, pad_mode='same')
|
||||
|
||||
def construct(self, x):
|
||||
"""Group construct"""
|
||||
x = self.conv_a(x)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Resblock(nn.Cell):
|
||||
"""res block"""
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(Resblock, self).__init__()
|
||||
self.conv1 = Mfm(in_channels, out_channels, kernel_size=3, stride=1, pad_mode='same')
|
||||
self.conv2 = Mfm(out_channels, out_channels, kernel_size=3, stride=1, pad_mode='same')
|
||||
self.add = P.Add()
|
||||
|
||||
def construct(self, x):
|
||||
"""Resblock construct"""
|
||||
res = x
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
out = self.add(x, res)
|
||||
return out
|
||||
|
||||
|
||||
def clip_gradient(dx):
|
||||
"""clip gradient"""
|
||||
ret = dx
|
||||
if ret > 5.0:
|
||||
ret = 5.0
|
||||
if ret < 0.05:
|
||||
ret = 0.05
|
||||
return ret
|
||||
|
||||
|
||||
class Network9Layers(nn.Cell):
|
||||
"""9layer LightCNN network for train"""
|
||||
|
||||
def __init__(self, num_classes):
|
||||
super(Network9Layers, self).__init__()
|
||||
self.features = nn.SequentialCell([
|
||||
Mfm(1, 48, kernel_size=5, stride=1, pad_mode='same'),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
Group(48, 96, 3, 1),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
Group(96, 192, 3, 1),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
Group(192, 128, 3, 1),
|
||||
Group(128, 128, 3, 1),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
])
|
||||
self.fc1 = Mfm(8 * 8 * 128, 256, mode=0)
|
||||
self.fc2 = nn.Dense(256, num_classes, weight_init=Normal(0.02))
|
||||
self.flatten = nn.Flatten()
|
||||
self.dropout = nn.Dropout(keep_prob=0.5)
|
||||
|
||||
def construct(self, x):
|
||||
"""network construct"""
|
||||
x = self.features(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc1(x)
|
||||
x = self.dropout(x)
|
||||
out = self.fc2(x)
|
||||
return out
|
||||
|
||||
|
||||
class Network9Layers4Test(Network9Layers):
|
||||
"""9layer LightCNN network for test"""
|
||||
|
||||
def construct(self, x):
|
||||
"""network construct"""
|
||||
x = self.features(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc1(x)
|
||||
x = self.dropout(x)
|
||||
out = self.fc2(x)
|
||||
return out, x
|
||||
|
||||
|
||||
def lightCNN_9Layers(num_classes):
|
||||
"""get 9layers model for train"""
|
||||
model = Network9Layers(num_classes)
|
||||
return model
|
||||
|
||||
|
||||
def lightCNN_9Layers4Test(num_classes):
|
||||
"""get 9layers model for test"""
|
||||
model = Network9Layers4Test(num_classes)
|
||||
return model
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate generator"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_lr(epoch_max, lr_base, steps_per_epoch, step=10, scale=0.457305051927326):
|
||||
"""generate learning rate"""
|
||||
lr_list = []
|
||||
for epoch in range(epoch_max):
|
||||
for _ in range(steps_per_epoch):
|
||||
lr_list.append(lr_base * (scale ** (epoch // step)))
|
||||
return np.array(lr_list)
|
|
@ -0,0 +1,176 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""""train LightCNN."""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.train import Model
|
||||
from mindspore.common import set_seed
|
||||
from mindspore import context, Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.model import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.nn.metrics import Accuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy
|
||||
|
||||
from src.lr_generator import get_lr
|
||||
from src.config import lightcnn_cfg as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.lightcnn import lightCNN_9Layers
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""parse train parameters."""
|
||||
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--device_id', default=0, type=int)
|
||||
parser.add_argument('--ckpt_path', type=str, default="", help='if is test, must provide\
|
||||
path where the trained mat_files file')
|
||||
parser.add_argument('--run_distribute', type=int, default=0, help='0 -- run standalone, 1 -- run distribute')
|
||||
parser.add_argument('--resume', type=str, default='', help="resume model's checkpoint, please use \
|
||||
checkpoint file name")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entrance for training"""
|
||||
args = parse_args()
|
||||
set_seed(1)
|
||||
|
||||
# context parameters
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
device_num = 1
|
||||
rank_id = 0
|
||||
|
||||
# init environment(distribute or not)
|
||||
if args.run_distribute:
|
||||
device_num = int(os.getenv('DEVICE_NUM'))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target=args.device_target, device_id=device_id)
|
||||
init()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target=args.device_target, device_id=device_id)
|
||||
|
||||
# define save checkpoint flag
|
||||
is_save_checkpoint = True
|
||||
if rank_id != 0:
|
||||
is_save_checkpoint = False
|
||||
|
||||
# define dataset
|
||||
if args.run_distribute:
|
||||
ds_train = create_dataset(mode='Train',
|
||||
data_url=cfg.data_path,
|
||||
data_list=cfg.train_list,
|
||||
batch_size=cfg.batch_size,
|
||||
num_of_workers=8,
|
||||
is_distributed=True,
|
||||
group_size=get_group_size(),
|
||||
rank=get_rank(),
|
||||
seed=0
|
||||
)
|
||||
else:
|
||||
ds_train = create_dataset(mode='Train', data_url=cfg.data_path, data_list=cfg.train_list,
|
||||
batch_size=cfg.batch_size, num_of_workers=8)
|
||||
|
||||
# define network
|
||||
network = lightCNN_9Layers(cfg.num_classes)
|
||||
|
||||
# resume network
|
||||
if args.resume:
|
||||
if os.path.isfile(args.resume):
|
||||
net_parameters = load_checkpoint(args.resume)
|
||||
load_param_into_net(net_parameters, network)
|
||||
else:
|
||||
raise RuntimeError('No such file {}'.format(args.resume))
|
||||
|
||||
# define dynamic learning rate
|
||||
steps_per_epoch = ds_train.get_dataset_size()
|
||||
bias_fc2_lr = get_lr(epoch_max=cfg.epochs, lr_base=cfg.lr * 20, steps_per_epoch=steps_per_epoch)
|
||||
bias_nfc2_lr = get_lr(epoch_max=cfg.epochs, lr_base=cfg.lr * 2, steps_per_epoch=steps_per_epoch)
|
||||
fc2_lr = get_lr(epoch_max=cfg.epochs, lr_base=cfg.lr * 10, steps_per_epoch=steps_per_epoch)
|
||||
nfc2_lr = get_lr(epoch_max=cfg.epochs, lr_base=cfg.lr, steps_per_epoch=steps_per_epoch)
|
||||
|
||||
bias_fc2_lr = Tensor(bias_fc2_lr, mstype.float32)
|
||||
bias_nfc2_lr = Tensor(bias_nfc2_lr, mstype.float32)
|
||||
fc2_lr = Tensor(fc2_lr, mstype.float32)
|
||||
nfc2_lr = Tensor(nfc2_lr, mstype.float32)
|
||||
|
||||
# define optimizer parameter
|
||||
params_dict = dict(network.parameters_and_names())
|
||||
bias_fc2 = []
|
||||
bias_nfc2 = []
|
||||
fc2 = []
|
||||
nfc2 = []
|
||||
for k, param in params_dict.items():
|
||||
if 'bias' in k:
|
||||
if 'fc2' in k:
|
||||
bias_fc2.append(param)
|
||||
else:
|
||||
bias_nfc2.append(param)
|
||||
else:
|
||||
if 'fc2' in k:
|
||||
fc2.append(param)
|
||||
else:
|
||||
nfc2.append(param)
|
||||
params = [
|
||||
{'params': bias_fc2, 'lr': bias_fc2_lr, 'weight_decay': 0},
|
||||
{'params': bias_nfc2, 'lr': bias_nfc2_lr, 'weight_decay': 0},
|
||||
{'params': fc2, 'lr': fc2_lr},
|
||||
{'params': nfc2},
|
||||
]
|
||||
|
||||
# define loss function
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
|
||||
# define optimizer
|
||||
net_opt = nn.SGD(params, nfc2_lr, cfg.momentum, weight_decay=cfg.weight_decay)
|
||||
|
||||
# define model
|
||||
model = Model(network, net_loss, net_opt,
|
||||
metrics={"Accuracy": Accuracy(),
|
||||
"Top1": Top1CategoricalAccuracy(),
|
||||
"Top5": Top5CategoricalAccuracy()},
|
||||
amp_level="O3")
|
||||
|
||||
# define callbacks
|
||||
callbacks = []
|
||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
callbacks.append(time_cb)
|
||||
callbacks.append(LossMonitor())
|
||||
if is_save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lightcnn", directory=args.ckpt_path, config=config_ck)
|
||||
callbacks.append(ckpoint_cb)
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train(cfg['epochs'], ds_train, callbacks=callbacks)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue