forked from mindspore-Ecosystem/mindspore
!17633 [线上贡献]黄金赛段FishNet99网络精度性能调优提交PR+GPU网络模型征集活动
Merge pull request !17633 from huicui/FishNet99_bold
This commit is contained in:
commit
a656f40f41
|
@ -0,0 +1,313 @@
|
||||||
|
# 目录
|
||||||
|
|
||||||
|
<!-- TOC -->
|
||||||
|
|
||||||
|
- [目录](#目录)
|
||||||
|
- [FishNet99描述](#FishNet99描述)
|
||||||
|
- [模型架构](#模型架构)
|
||||||
|
- [数据集](#数据集)
|
||||||
|
- [特性](#特性)
|
||||||
|
- [混合精度](#混合精度)
|
||||||
|
- [环境要求](#环境要求)
|
||||||
|
- [快速入门](#快速入门)
|
||||||
|
- [脚本说明](#脚本说明)
|
||||||
|
- [脚本及样例代码](#脚本及样例代码)
|
||||||
|
- [脚本参数](#脚本参数)
|
||||||
|
- [训练过程](#训练过程)
|
||||||
|
- [训练](#训练)
|
||||||
|
- [分布式训练](#分布式训练)
|
||||||
|
- [评估过程](#评估过程)
|
||||||
|
- [评估](#评估)
|
||||||
|
- [导出过程](#导出过程)
|
||||||
|
- [导出](#导出)
|
||||||
|
- [推理过程](#推理过程)
|
||||||
|
- [推理](#推理)
|
||||||
|
- [模型描述](#模型描述)
|
||||||
|
- [性能](#性能)
|
||||||
|
- [评估性能](#评估性能)
|
||||||
|
- [ImageNet-1k上的FishNet99](#ImageNet-1k上的FishNet99)
|
||||||
|
- [推理性能](#推理性能)
|
||||||
|
- [ImageNet-1k上的FishNet99](#ImageNet-1k上的FishNet99)
|
||||||
|
- [ModelZoo主页](#modelzoo主页)
|
||||||
|
|
||||||
|
<!-- /TOC -->
|
||||||
|
|
||||||
|
# FishNet99描述
|
||||||
|
|
||||||
|
这是第一个统一为像素级、区域级和图像级任务设计的骨干网络;可以将梯度从非常深的层直接传播到较浅的层;可以保留并互相细化不同深度的特征。
|
||||||
|
|
||||||
|
[论文](http://papers.nips.cc/paper/7356-fishnet-a-versatile-backbone-for-image-region-and-pixel-level-prediction.pdf) :FishNet: a versatile backbone for image, region, and pixel level prediction. In Proceedings of the 32nd International Conference on Neural Information Processing Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 762–772.
|
||||||
|
|
||||||
|
# 模型架构
|
||||||
|
|
||||||
|
整个网络分为tail、body和head三个部分,其中tail是现有的如ResNet等CNN,随着网络的深入,特征分辨率会逐渐减小;body部分有多个上采样和细化块的结构,主要用来细化来自tail和body的特征;head则是有着数个下采样和细化块的结构,用来保留和细化来自tail和body的特征,最后一个卷积层的细化特征被用来处理图像分类等最终任务。
|
||||||
|
|
||||||
|
# 数据集
|
||||||
|
|
||||||
|
使用的数据集:[ImageNet2012](http://www.image-net.org/)
|
||||||
|
|
||||||
|
- 数据集大小:125G,共1000个类、125万张彩色图像
|
||||||
|
- 训练集:120G,共1,281,167张图像
|
||||||
|
- 测试集:5G,共50,000张图像
|
||||||
|
- 数据格式:RGB
|
||||||
|
- 注:数据将在src/dataset.py中处理
|
||||||
|
- 下载数据集,目录结构如下:
|
||||||
|
|
||||||
|
```text
|
||||||
|
└─dataset
|
||||||
|
├─ILSVRC2012_train # 训练数据集
|
||||||
|
└─ILSVRC2012_val # 评估数据集
|
||||||
|
```
|
||||||
|
|
||||||
|
# 特性
|
||||||
|
|
||||||
|
## 混合精度
|
||||||
|
|
||||||
|
采用[混合精度](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/enable_mixed_precision.html) 的训练方法,使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
|
||||||
|
|
||||||
|
# 环境要求
|
||||||
|
|
||||||
|
- 硬件(Ascend)
|
||||||
|
- 使用Ascend来搭建硬件环境。
|
||||||
|
- 框架
|
||||||
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
|
- 如需查看详情,请参见如下资源:
|
||||||
|
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
|
||||||
|
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
|
||||||
|
|
||||||
|
# 快速入门
|
||||||
|
|
||||||
|
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||||
|
|
||||||
|
- Ascend处理器环境运行
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 运行训练示例
|
||||||
|
python train.py --device_id=0 --device_type='Ascend' > train.log 2>&1 &
|
||||||
|
|
||||||
|
# 运行分布式训练示例
|
||||||
|
bash ./scripts/run_train_ascend.sh [RANK_TABLE_FILE]
|
||||||
|
|
||||||
|
# 运行评估示例
|
||||||
|
python eval.py --checkpoint_path ./ckpt_0 --device_type='Ascend' > ./eval.log 2>&1 &
|
||||||
|
|
||||||
|
# 运行推理示例
|
||||||
|
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
|
||||||
|
```
|
||||||
|
|
||||||
|
对于分布式训练,需要提前创建JSON格式的hccl配置文件。
|
||||||
|
|
||||||
|
请遵循以下链接中的说明:
|
||||||
|
|
||||||
|
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
|
||||||
|
|
||||||
|
- GPU处理器环境运行
|
||||||
|
|
||||||
|
为了在GPU处理器环境运行,请将配置文件src/config.py中的device_target从Ascend改为GPU。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 运行训练示例
|
||||||
|
python train.py --device_id=0 --device_type='GPU' > train_gpu.log 2>&1 &
|
||||||
|
|
||||||
|
# 运行分布式训练示例
|
||||||
|
bash ./scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
|
||||||
|
|
||||||
|
# 运行评估示例
|
||||||
|
python eval.py --checkpoint_path ./ckpt_0 --device_type='GPU' > ./eval_gpu.log 2>&1 &
|
||||||
|
```
|
||||||
|
|
||||||
|
# 脚本说明
|
||||||
|
|
||||||
|
## 脚本及样例代码
|
||||||
|
|
||||||
|
```bash
|
||||||
|
├── model_zoo
|
||||||
|
├── README.md // 所有模型相关说明
|
||||||
|
├── fishnet99
|
||||||
|
├── README_CN.md // FishNet99相关说明
|
||||||
|
├── ascend310_infer // 实现310推理源代码
|
||||||
|
├── scripts
|
||||||
|
│ ├──run_eval_gpu.sh // GPU评估的shell脚本
|
||||||
|
│ ├──run_infer_310.sh // Ascend推理的shell脚本
|
||||||
|
│ ├──run_train_ascend.sh // 分布式到Ascend的shell脚本
|
||||||
|
│ ├──run_train_gpu.sh // 分布式到GPU处理器的shell脚本
|
||||||
|
├── src
|
||||||
|
│ ├──config.py // 参数配置
|
||||||
|
│ ├──dataset.py // 创建数据集
|
||||||
|
│ ├──fishnet.py // FishNet99架构
|
||||||
|
├── eval.py // 评估脚本
|
||||||
|
├── export.py // 将checkpoint文件导出到air/mindir
|
||||||
|
├── postprocess.py // 310推理后处理脚本
|
||||||
|
├── train.py // 训练脚本
|
||||||
|
```
|
||||||
|
|
||||||
|
## 脚本参数
|
||||||
|
|
||||||
|
在config.py中可以同时配置训练参数和评估参数。
|
||||||
|
|
||||||
|
- 配置FishNet99和ImageNet-1k数据集。
|
||||||
|
|
||||||
|
```python
|
||||||
|
'name':'imagenet' # 数据集
|
||||||
|
'pre_trained':'False' # 是否基于预训练模型训练
|
||||||
|
'num_classes':1000 # 数据集类数
|
||||||
|
'lr_init':0.05 # 初始学习率,Ascned单卡训练时设置为0.05,Ascned八卡并行训练时设置为0.4,GPU单卡训练时设置为0.05,GPU双卡并行训练时设置为0.1
|
||||||
|
'batch_size':128 # 训练批次大小
|
||||||
|
'epoch_size':160 # 总计训练epoch数,其中GPU双卡并行训练时设置为110
|
||||||
|
'T_max':150 # 学习率衰减相关参数,其中GPU双卡并行训练时设置为100
|
||||||
|
'momentum':0.9 # 动量
|
||||||
|
'weight_decay':1e-4 # 权重衰减值
|
||||||
|
'image_height':224 # 输入到模型的图像高度
|
||||||
|
'image_width':224 # 输入到模型的图像宽度
|
||||||
|
'data_path':'/data/ILSVRC2012_train/' # 训练数据集的绝对全路径
|
||||||
|
'val_data_path':'/data/ILSVRC2012_val/' # 评估数据集的绝对全路径
|
||||||
|
'device_target':'Ascend' # 运行设备
|
||||||
|
'device_id':0 # 用于训练或评估数据集的设备ID使用run_train.sh进行分布式训练时可以忽略。
|
||||||
|
'keep_checkpoint_max':25 # 最多保存25个ckpt模型文件
|
||||||
|
'checkpoint_path':'./ckpt/train_fishnet99_imagenet-146_10009.ckpt' # checkpoint文件保存的绝对全路径
|
||||||
|
```
|
||||||
|
|
||||||
|
更多配置细节请参考脚本`config.py`。
|
||||||
|
|
||||||
|
## 训练过程
|
||||||
|
|
||||||
|
### 训练
|
||||||
|
|
||||||
|
- Ascend处理器环境运行
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train.py --device_id=0 --device_type='Ascend' > train.log 2>&1 &
|
||||||
|
```
|
||||||
|
|
||||||
|
上述python命令将在后台运行,可以通过生成的train.log文件查看结果。
|
||||||
|
|
||||||
|
训练结束后,可以在默认脚本文件夹下找到检查点文件,采用以下方式得到损失值:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# grep "loss is " train.log
|
||||||
|
...
|
||||||
|
epoch: 8 step: 10009, loss is 3.0276418
|
||||||
|
epoch: 9 step: 10009, loss is 3.0397775
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
- GPU处理器环境运行
|
||||||
|
|
||||||
|
为了在GPU处理器环境运行,请将配置文件src/config.py中的device_target从Ascend改为GPU。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train.py --device_id=0 --device_type='GPU' > train_gpu.log 2>&1 &
|
||||||
|
```
|
||||||
|
|
||||||
|
上述python命令将在后台运行,可以通过生成的train_gpu.log文件查看结果。
|
||||||
|
|
||||||
|
训练结束后,可以在默认./ckpt_0/脚本文件夹下找到检查点文件。
|
||||||
|
|
||||||
|
### 分布式训练
|
||||||
|
|
||||||
|
- Ascend处理器环境运行
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash ./scripts/run_train_ascend.sh [RANK_TABLE_FILE]
|
||||||
|
```
|
||||||
|
|
||||||
|
上述shell脚本将在后台运行分布训练。
|
||||||
|
|
||||||
|
- GPU处理器环境运行
|
||||||
|
|
||||||
|
为了在GPU处理器环境运行,请将配置文件src/config.py中的device_target从Ascend改为GPU。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash ./scripts/run_train_gpu.sh 2 0,1
|
||||||
|
```
|
||||||
|
|
||||||
|
上述shell脚本将在后台运行分布训练。可以在生成的train文件夹中查看结果。
|
||||||
|
|
||||||
|
## 评估过程
|
||||||
|
|
||||||
|
### 评估
|
||||||
|
|
||||||
|
- 在Ascend环境运行时评估ImageNet-1k数据集
|
||||||
|
|
||||||
|
“./ckpt_0”是保存了训练好的.ckpt模型文件的目录。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python eval.py --checkpoint_path ./ckpt_0 --device_type='Ascend' > ./eval.log 2>&1 &
|
||||||
|
```
|
||||||
|
|
||||||
|
- 在GPU处理器环境运行时评估ImageNet-1k数据集
|
||||||
|
|
||||||
|
“./ckpt_0”是保存了训练好的.ckpt模型文件的目录。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python eval.py --checkpoint_path ./ckpt_0 --device_type='GPU' > ./eval_gpu.log 2>&1 &
|
||||||
|
OR
|
||||||
|
bash ./scripts/run_eval.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
## 导出过程
|
||||||
|
|
||||||
|
### 导出
|
||||||
|
|
||||||
|
将checkpoint文件导出成mindir格式模型。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python export.py --ckpt_file [CKPT_FILE]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 推理过程
|
||||||
|
|
||||||
|
### 推理
|
||||||
|
|
||||||
|
在进行推理之前我们需要先导出模型。mindir可以在任意环境上导出,air模型只能在昇腾910环境上导出。以下展示了使用mindir模型执行推理的示例。
|
||||||
|
|
||||||
|
- 在昇腾310上使用ImageNet-1k数据集进行推理
|
||||||
|
|
||||||
|
推理的结果保存在scripts目录下,在acc.log日志文件中可以找到类似以下的结果。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# Ascend310 inference
|
||||||
|
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
|
||||||
|
Total data: 50000, top1 accuracy: 0.78242, top5 accuracy: 0.94042.
|
||||||
|
```
|
||||||
|
|
||||||
|
# 模型描述
|
||||||
|
|
||||||
|
## 性能
|
||||||
|
|
||||||
|
### 评估性能
|
||||||
|
|
||||||
|
#### ImageNet-1k上的FishNet99
|
||||||
|
|
||||||
|
| 参数 | Ascend | GPU |
|
||||||
|
| -------------------------- | ----------------------------------------------------------- | -------------------------- |
|
||||||
|
| 模型版本 | FishNet99 | FishNet99 |
|
||||||
|
| 资源 | Ascend 910 | Tesla V100-32G |
|
||||||
|
| 上传日期 | 2021-06-02 | 2021-07-24 |
|
||||||
|
| MindSpore版本 | 1.2.0 | 1.2.0 |
|
||||||
|
| 数据集 | ImageNet2012 | ImageNet2012 |
|
||||||
|
| 训练参数 | epoch=160, batch_size=128, lr_init=0.05(单卡为0.05,八卡为0.4) | epoch=160(单卡160,双卡110), T_max=150(单卡150,双卡100), batch_size=128, lr_init=0.05(单卡0.05,双卡0.1) |
|
||||||
|
| 优化器 | Momentum | Momentum |
|
||||||
|
| 损失函数 | Softmax交叉熵 | Softmax交叉熵 |
|
||||||
|
| 输出 | 概率 | 概率 |
|
||||||
|
| 分类准确率 | 单卡:top1:78.24%, top5:94.03%;八卡:top1:78.33%, top5:93.96% | 单卡:top1:78.12%, top5:94.13%;双卡:top1:77.97%, top5:93.98% |
|
||||||
|
| 速度 | 单卡:132毫秒/步;八卡:135毫秒/步 | 单卡:227毫秒/步;双卡:450毫秒/步 |
|
||||||
|
| 总时长 | 单卡:58.5小时/160轮;八卡:7.7小时/160轮 | 单卡:109.6小时/160轮;双卡:69.1小时/110轮 |
|
||||||
|
|
||||||
|
### 推理性能
|
||||||
|
|
||||||
|
#### ImageNet-1k上的FishNet99
|
||||||
|
|
||||||
|
| 参数 | Ascend |
|
||||||
|
| -------------------------- | ----------------------------------------------------------- |
|
||||||
|
| 模型版本 | FishNet99 |
|
||||||
|
| 资源 | Ascend 310 |
|
||||||
|
| 上传日期 | 2021-06-16 |
|
||||||
|
| MindSpore版本 | 1.2.0 |
|
||||||
|
| 数据集 | ImageNet2012 |
|
||||||
|
| 分类准确率 | top1:78.24%,top5:94.04% |
|
||||||
|
| 速度 | Average time 5.17187 ms of infer_count 50000 |
|
||||||
|
|
||||||
|
# ModelZoo主页
|
||||||
|
|
||||||
|
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,35 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_INFERENCE_UTILS_H_
|
||||||
|
#define MINDSPORE_INFERENCE_UTILS_H_
|
||||||
|
|
||||||
|
#include <sys/stat.h>
|
||||||
|
#include <dirent.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/api/types.h"
|
||||||
|
|
||||||
|
std::vector<std::string> GetAllFiles(std::string_view dirName);
|
||||||
|
DIR *OpenDir(std::string_view dirName);
|
||||||
|
std::string RealPath(std::string_view path);
|
||||||
|
mindspore::MSTensor ReadFileToTensor(const std::string &file);
|
||||||
|
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
|
||||||
|
std::vector<std::string> GetAllFiles(std::string dir_name);
|
||||||
|
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name);
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,14 @@
|
||||||
|
cmake_minimum_required(VERSION 3.14.1)
|
||||||
|
project(MindSporeCxxTestcase[CXX])
|
||||||
|
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
|
||||||
|
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
|
||||||
|
option(MINDSPORE_PATH "mindspore install path" "")
|
||||||
|
include_directories(${MINDSPORE_PATH})
|
||||||
|
include_directories(${MINDSPORE_PATH}/include)
|
||||||
|
include_directories(${PROJECT_SRC_ROOT}/../)
|
||||||
|
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
|
||||||
|
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
|
||||||
|
|
||||||
|
add_executable(main main.cc utils.cc)
|
||||||
|
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
|
|
@ -0,0 +1,18 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
cmake . -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
|
||||||
|
make
|
|
@ -0,0 +1,146 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <sys/time.h>
|
||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <dirent.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <iosfwd>
|
||||||
|
#include <vector>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "include/api/model.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/types.h"
|
||||||
|
#include "include/api/serialization.h"
|
||||||
|
#include "minddata/dataset/include/vision_ascend.h"
|
||||||
|
#include "minddata/dataset/include/execute.h"
|
||||||
|
#include "minddata/dataset/include/transforms.h"
|
||||||
|
#include "minddata/dataset/include/vision.h"
|
||||||
|
#include "inc/utils.h"
|
||||||
|
|
||||||
|
using mindspore::dataset::vision::Decode;
|
||||||
|
using mindspore::dataset::vision::Resize;
|
||||||
|
using mindspore::dataset::vision::CenterCrop;
|
||||||
|
using mindspore::dataset::vision::Normalize;
|
||||||
|
using mindspore::dataset::vision::HWC2CHW;
|
||||||
|
using mindspore::dataset::TensorTransform;
|
||||||
|
using mindspore::Context;
|
||||||
|
using mindspore::Serialization;
|
||||||
|
using mindspore::Model;
|
||||||
|
using mindspore::Status;
|
||||||
|
using mindspore::ModelType;
|
||||||
|
using mindspore::GraphCell;
|
||||||
|
using mindspore::kSuccess;
|
||||||
|
using mindspore::MSTensor;
|
||||||
|
using mindspore::dataset::Execute;
|
||||||
|
|
||||||
|
|
||||||
|
DEFINE_string(mindir_path, "", "mindir path");
|
||||||
|
DEFINE_string(dataset_path, ".", "dataset path");
|
||||||
|
DEFINE_int32(device_id, 0, "device id");
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
if (RealPath(FLAGS_mindir_path).empty()) {
|
||||||
|
std::cout << "Invalid mindir" << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto context = std::make_shared<Context>();
|
||||||
|
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||||
|
ascend310->SetDeviceID(FLAGS_device_id);
|
||||||
|
context->MutableDeviceInfo().push_back(ascend310);
|
||||||
|
mindspore::Graph graph;
|
||||||
|
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
||||||
|
Model model;
|
||||||
|
Status ret = model.Build(GraphCell(graph), context);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
std::cout << "ERROR: Build failed." << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto all_files = GetAllInputData(FLAGS_dataset_path);
|
||||||
|
if (all_files.empty()) {
|
||||||
|
std::cout << "ERROR: no input data." << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<double, double> costTime_map;
|
||||||
|
size_t size = all_files.size();
|
||||||
|
// Define transform
|
||||||
|
std::vector<int32_t> crop_paras = {224};
|
||||||
|
std::vector<int32_t> resize_paras = {256};
|
||||||
|
std::vector<float> mean = {0.485 * 255, 0.456 * 255, 0.406 * 255};
|
||||||
|
std::vector<float> std = {0.229 * 255, 0.224 * 255, 0.225 * 255};
|
||||||
|
|
||||||
|
auto decode = Decode();
|
||||||
|
auto resize = Resize(resize_paras);
|
||||||
|
auto centercrop = CenterCrop(crop_paras);
|
||||||
|
auto normalize = Normalize(mean, std);
|
||||||
|
auto hwc2chw = HWC2CHW();
|
||||||
|
|
||||||
|
mindspore::dataset::Execute SingleOp({decode, resize, centercrop, normalize, hwc2chw});
|
||||||
|
|
||||||
|
for (size_t i = 0; i < size; ++i) {
|
||||||
|
for (size_t j = 0; j < all_files[i].size(); ++j) {
|
||||||
|
struct timeval start = {0};
|
||||||
|
struct timeval end = {0};
|
||||||
|
double startTimeMs;
|
||||||
|
double endTimeMs;
|
||||||
|
std::vector<MSTensor> inputs;
|
||||||
|
std::vector<MSTensor> outputs;
|
||||||
|
std::cout << "Start predict input files:" << all_files[i][j] <<std::endl;
|
||||||
|
auto imgDvpp = std::make_shared<MSTensor>();
|
||||||
|
SingleOp(ReadFileToTensor(all_files[i][j]), imgDvpp.get());
|
||||||
|
|
||||||
|
inputs.emplace_back(imgDvpp->Name(), imgDvpp->DataType(), imgDvpp->Shape(),
|
||||||
|
imgDvpp->Data().get(), imgDvpp->DataSize());
|
||||||
|
gettimeofday(&start, nullptr);
|
||||||
|
ret = model.Predict(inputs, &outputs);
|
||||||
|
gettimeofday(&end, nullptr);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
std::cout << "Predict " << all_files[i][j] << " failed." << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
|
||||||
|
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
|
||||||
|
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
|
||||||
|
WriteResult(all_files[i][j], outputs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
double average = 0.0;
|
||||||
|
int inferCount = 0;
|
||||||
|
|
||||||
|
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
|
||||||
|
double diff = 0.0;
|
||||||
|
diff = iter->second - iter->first;
|
||||||
|
average += diff;
|
||||||
|
inferCount++;
|
||||||
|
}
|
||||||
|
average = average / inferCount;
|
||||||
|
std::stringstream timeCost;
|
||||||
|
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
|
||||||
|
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
|
||||||
|
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
|
||||||
|
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
|
||||||
|
fileStream << timeCost.str();
|
||||||
|
fileStream.close();
|
||||||
|
costTime_map.clear();
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,185 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <iostream>
|
||||||
|
#include "inc/utils.h"
|
||||||
|
|
||||||
|
using mindspore::MSTensor;
|
||||||
|
using mindspore::DataType;
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name) {
|
||||||
|
std::vector<std::vector<std::string>> ret;
|
||||||
|
|
||||||
|
DIR *dir = OpenDir(dir_name);
|
||||||
|
if (dir == nullptr) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
struct dirent *filename;
|
||||||
|
/* read all the files in the dir ~ */
|
||||||
|
std::vector<std::string> sub_dirs;
|
||||||
|
while ((filename = readdir(dir)) != nullptr) {
|
||||||
|
std::string d_name = std::string(filename->d_name);
|
||||||
|
// get rid of "." and ".."
|
||||||
|
if (d_name == "." || d_name == ".." || d_name.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::string dir_path = RealPath(std::string(dir_name) + "/" + filename->d_name);
|
||||||
|
struct stat s;
|
||||||
|
lstat(dir_path.c_str(), &s);
|
||||||
|
if (!S_ISDIR(s.st_mode)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_dirs.emplace_back(dir_path);
|
||||||
|
}
|
||||||
|
std::sort(sub_dirs.begin(), sub_dirs.end());
|
||||||
|
|
||||||
|
(void)std::transform(sub_dirs.begin(), sub_dirs.end(), std::back_inserter(ret),
|
||||||
|
[](const std::string &d) { return GetAllFiles(d); });
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<std::string> GetAllFiles(std::string dir_name) {
|
||||||
|
struct dirent *filename;
|
||||||
|
DIR *dir = OpenDir(dir_name);
|
||||||
|
if (dir == nullptr) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> res;
|
||||||
|
while ((filename = readdir(dir)) != nullptr) {
|
||||||
|
std::string d_name = std::string(filename->d_name);
|
||||||
|
if (d_name == "." || d_name == ".." || d_name.size() <= 3) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
|
||||||
|
}
|
||||||
|
std::sort(res.begin(), res.end());
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<std::string> GetAllFiles(std::string_view dirName) {
|
||||||
|
struct dirent *filename;
|
||||||
|
DIR *dir = OpenDir(dirName);
|
||||||
|
if (dir == nullptr) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
std::vector<std::string> res;
|
||||||
|
while ((filename = readdir(dir)) != nullptr) {
|
||||||
|
std::string dName = std::string(filename->d_name);
|
||||||
|
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
|
||||||
|
}
|
||||||
|
std::sort(res.begin(), res.end());
|
||||||
|
for (auto &f : res) {
|
||||||
|
std::cout << "image file: " << f << std::endl;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
|
||||||
|
std::string homePath = "./result_Files";
|
||||||
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||||
|
size_t outputSize;
|
||||||
|
std::shared_ptr<const void> netOutput;
|
||||||
|
netOutput = outputs[i].Data();
|
||||||
|
outputSize = outputs[i].DataSize();
|
||||||
|
int pos = imageFile.rfind('/');
|
||||||
|
std::string fileName(imageFile, pos + 1);
|
||||||
|
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
|
||||||
|
std::string outFileName = homePath + "/" + fileName;
|
||||||
|
FILE *outputFile = fopen(outFileName.c_str(), "wb");
|
||||||
|
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
|
||||||
|
fclose(outputFile);
|
||||||
|
outputFile = nullptr;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
|
||||||
|
if (file.empty()) {
|
||||||
|
std::cout << "Pointer file is nullptr" << std::endl;
|
||||||
|
return mindspore::MSTensor();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ifstream ifs(file);
|
||||||
|
if (!ifs.good()) {
|
||||||
|
std::cout << "File: " << file << " is not exist" << std::endl;
|
||||||
|
return mindspore::MSTensor();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ifs.is_open()) {
|
||||||
|
std::cout << "File: " << file << "open failed" << std::endl;
|
||||||
|
return mindspore::MSTensor();
|
||||||
|
}
|
||||||
|
|
||||||
|
ifs.seekg(0, std::ios::end);
|
||||||
|
size_t size = ifs.tellg();
|
||||||
|
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
|
||||||
|
|
||||||
|
ifs.seekg(0, std::ios::beg);
|
||||||
|
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
|
||||||
|
ifs.close();
|
||||||
|
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DIR *OpenDir(std::string_view dirName) {
|
||||||
|
if (dirName.empty()) {
|
||||||
|
std::cout << " dirName is null ! " << std::endl;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::string realPath = RealPath(dirName);
|
||||||
|
struct stat s;
|
||||||
|
lstat(realPath.c_str(), &s);
|
||||||
|
if (!S_ISDIR(s.st_mode)) {
|
||||||
|
std::cout << "dirName is not a valid directory !" << std::endl;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
DIR *dir;
|
||||||
|
dir = opendir(realPath.c_str());
|
||||||
|
if (dir == nullptr) {
|
||||||
|
std::cout << "Can not open dir " << dirName << std::endl;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::cout << "Successfully opened the dir " << dirName << std::endl;
|
||||||
|
return dir;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string RealPath(std::string_view path) {
|
||||||
|
char realPathMem[PATH_MAX] = {0};
|
||||||
|
char *realPathRet = nullptr;
|
||||||
|
realPathRet = realpath(path.data(), realPathMem);
|
||||||
|
if (realPathRet == nullptr) {
|
||||||
|
std::cout << "File: " << path << " is not exist.";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string realPath(realPathMem);
|
||||||
|
std::cout << path << " realpath is: " << realPath << std::endl;
|
||||||
|
return realPath;
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# less required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""create_imagenet2012_label"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="resnet imagenet2012 label")
|
||||||
|
parser.add_argument("--img_path", type=str, required=True, help="imagenet2012 file path.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def create_label(file_path):
|
||||||
|
"""
|
||||||
|
create label
|
||||||
|
"""
|
||||||
|
print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!")
|
||||||
|
dirs = os.listdir(file_path)
|
||||||
|
file_list = []
|
||||||
|
for file in dirs:
|
||||||
|
file_list.append(file)
|
||||||
|
file_list = sorted(file_list)
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
img_label = {}
|
||||||
|
for i, file_dir in enumerate(file_list):
|
||||||
|
files = os.listdir(os.path.join(file_path, file_dir))
|
||||||
|
for f in files:
|
||||||
|
img_label[f] = i
|
||||||
|
total += len(files)
|
||||||
|
|
||||||
|
with open("imagenet_label.json", "w+") as label:
|
||||||
|
json.dump(img_label, label)
|
||||||
|
|
||||||
|
print("[INFO] Completed! Total {} data.".format(total))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
create_label(args.img_path)
|
|
@ -0,0 +1,94 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
python eval.py
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.nn.loss.loss import _Loss
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
from src.config import imagenet_cfg
|
||||||
|
from src.dataset import create_dataset_imagenet
|
||||||
|
|
||||||
|
import src.fishnet as net_ms
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='fishnet99')
|
||||||
|
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'],
|
||||||
|
help='dataset name.')
|
||||||
|
parser.add_argument('--device_type', type=str, default=None, help='GPU or Ascend. (Default: None)')
|
||||||
|
parser.add_argument('--checkpoint_path', type=str, default='./ckpt_0', help='Checkpoint file path')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
class CrossEntropySmooth(_Loss):
|
||||||
|
"""CrossEntropy"""
|
||||||
|
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
|
||||||
|
super(CrossEntropySmooth, self).__init__()
|
||||||
|
self.onehot = P.OneHot()
|
||||||
|
self.sparse = sparse
|
||||||
|
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||||
|
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||||
|
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
|
||||||
|
|
||||||
|
def construct(self, logit, label):
|
||||||
|
if self.sparse:
|
||||||
|
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||||
|
loss_ = self.ce(logit, label)
|
||||||
|
return loss_
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
if args_opt.dataset_name == "imagenet":
|
||||||
|
cfg = imagenet_cfg
|
||||||
|
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
|
||||||
|
if not cfg.use_label_smooth:
|
||||||
|
cfg.label_smooth_factor = 0.0
|
||||||
|
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||||
|
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
|
||||||
|
net = net_ms.fish99()
|
||||||
|
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("dataset is not support.")
|
||||||
|
|
||||||
|
if not args_opt.device_type:
|
||||||
|
device_target = args_opt.device_type
|
||||||
|
else:
|
||||||
|
device_target = cfg.device_target
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||||
|
context.set_context(device_id=cfg.device_id)
|
||||||
|
|
||||||
|
file_list = os.listdir(args_opt.checkpoint_path)
|
||||||
|
for filename in file_list:
|
||||||
|
de_path = os.path.join(args_opt.checkpoint_path, filename)
|
||||||
|
if de_path.endswith('.ckpt'):
|
||||||
|
param_dict = load_checkpoint(de_path)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
net.set_train(False)
|
||||||
|
|
||||||
|
acc = model.eval(dataset)
|
||||||
|
print(f"model {de_path}'s accuracy is {acc}")
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
##############export checkpoint file into air, onnx or mindir model#################
|
||||||
|
python export.py
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||||
|
|
||||||
|
import src.fishnet as net_ms
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='FishNet99 export')
|
||||||
|
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||||
|
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||||
|
parser.add_argument("--file_name", type=str, default="FishNet99", help="output file name.")
|
||||||
|
parser.add_argument('--width', type=int, default=224, help='input width')
|
||||||
|
parser.add_argument('--height', type=int, default=224, help='input height')
|
||||||
|
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
net = net_ms.fish99()
|
||||||
|
|
||||||
|
assert args.ckpt_file is not None, "checkpoint_path is None."
|
||||||
|
|
||||||
|
param_dict = load_checkpoint(args.ckpt_file)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
|
||||||
|
input_arr = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
|
||||||
|
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,54 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# less required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""post process for 310 inference"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from src.config import imagenet_cfg
|
||||||
|
|
||||||
|
batch_size = 1
|
||||||
|
parser = argparse.ArgumentParser(description="resnet inference")
|
||||||
|
parser.add_argument("--result_path", type=str, required=True, help="result files path.")
|
||||||
|
parser.add_argument("--label_path", type=str, required=True, help="image file path.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def get_result(result_path, label_path):
|
||||||
|
"""
|
||||||
|
get result.
|
||||||
|
"""
|
||||||
|
files = os.listdir(result_path)
|
||||||
|
with open(label_path, "r") as label:
|
||||||
|
labels = json.load(label)
|
||||||
|
|
||||||
|
top1 = 0
|
||||||
|
top5 = 0
|
||||||
|
total_data = len(files)
|
||||||
|
for file in files:
|
||||||
|
img_ids_name = file.split('_0.')[0]
|
||||||
|
data_path = os.path.join(result_path, img_ids_name + "_0.bin")
|
||||||
|
result = np.fromfile(data_path, dtype=np.float32).reshape(batch_size, imagenet_cfg.num_classes)
|
||||||
|
for batch in range(batch_size):
|
||||||
|
predict = np.argsort(-result[batch], axis=-1)
|
||||||
|
if labels[img_ids_name+".JPEG"] == predict[0]:
|
||||||
|
top1 += 1
|
||||||
|
if labels[img_ids_name+".JPEG"] in predict[:5]:
|
||||||
|
top5 += 1
|
||||||
|
print(f"Total data: {total_data}, top1 accuracy: {top1/total_data}, top5 accuracy: {top5/total_data}.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
get_result(args.result_path, args.label_path)
|
|
@ -0,0 +1,29 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 1 ]
|
||||||
|
then
|
||||||
|
echo "Usage: bash run_eval.sh [CKPT_LOCATION]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $1 ]
|
||||||
|
then
|
||||||
|
echo "error: CKPT_LOCATION=$1 is not files' location"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
python eval.py --checkpoint_path=$1 --device_type='GPU' > ./eval.log 2>&1 &
|
|
@ -0,0 +1,99 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [[ $# -lt 2 || $# -gt 3 ]]; then
|
||||||
|
echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
|
||||||
|
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
model=$(get_real_path $1)
|
||||||
|
data_path=$(get_real_path $2)
|
||||||
|
|
||||||
|
device_id=0
|
||||||
|
if [ $# == 3 ]; then
|
||||||
|
device_id=$3
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "mindir name: "$model
|
||||||
|
echo "dataset path: "$data_path
|
||||||
|
echo "device id: "$device_id
|
||||||
|
|
||||||
|
export ASCEND_HOME=/usr/local/Ascend/
|
||||||
|
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
|
||||||
|
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
|
||||||
|
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||||
|
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
|
||||||
|
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
|
||||||
|
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
|
||||||
|
else
|
||||||
|
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
|
||||||
|
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||||
|
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
|
||||||
|
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
|
||||||
|
fi
|
||||||
|
|
||||||
|
function compile_app()
|
||||||
|
{
|
||||||
|
cd ../ascend310_infer/src/ || exit
|
||||||
|
if [ -f "Makefile" ]; then
|
||||||
|
make clean
|
||||||
|
fi
|
||||||
|
sh build.sh &> build.log
|
||||||
|
}
|
||||||
|
|
||||||
|
function infer()
|
||||||
|
{
|
||||||
|
cd - || exit
|
||||||
|
if [ -d result_Files ]; then
|
||||||
|
rm -rf ./result_Files
|
||||||
|
fi
|
||||||
|
if [ -d time_Result ]; then
|
||||||
|
rm -rf ./time_Result
|
||||||
|
fi
|
||||||
|
mkdir result_Files
|
||||||
|
mkdir time_Result
|
||||||
|
../ascend310_infer/src/main --mindir_path=$model --dataset_path=$data_path --device_id=$device_id &> infer.log
|
||||||
|
}
|
||||||
|
|
||||||
|
function cal_acc()
|
||||||
|
{
|
||||||
|
python3.7 ../create_imagenet2012_label.py --img_path=$data_path
|
||||||
|
python3.7 ../postprocess.py --result_path=./result_Files --label_path=./imagenet_label.json &> acc.log &
|
||||||
|
}
|
||||||
|
|
||||||
|
compile_app
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "compile app code failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
infer
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo " execute inference failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
cal_acc
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "calculate accuracy failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
|
@ -0,0 +1,51 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 1 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_train.sh [RANK_TABLE_FILE]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $1 ]
|
||||||
|
then
|
||||||
|
echo "error: RANK_TABLE_FILE=$1 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_NUM=8
|
||||||
|
export RANK_SIZE=8
|
||||||
|
RANK_TABLE_FILE=$(realpath $1)
|
||||||
|
export RANK_TABLE_FILE
|
||||||
|
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
|
||||||
|
|
||||||
|
rank_start=$(DEVICE_NUM)
|
||||||
|
for((i=0; i<${DEVICE_NUM}; i++))
|
||||||
|
do
|
||||||
|
export DEVICE_ID=$i
|
||||||
|
export RANK_ID=$((rank_start + i))
|
||||||
|
rm -rf ./train_parallel$i
|
||||||
|
mkdir ./train_parallel$i
|
||||||
|
cp -r ./src ./train_parallel$i
|
||||||
|
cp ./train.py ./train_parallel$i
|
||||||
|
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||||
|
cd ./train_parallel$i ||exit
|
||||||
|
env > env.log
|
||||||
|
python train.py --device_id=$i --device_type='Ascend' > log 2>&1 &
|
||||||
|
cd ..
|
||||||
|
done
|
|
@ -0,0 +1,52 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# -lt 2 ]
|
||||||
|
then
|
||||||
|
echo "Usage: \
|
||||||
|
bash run_train_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)]\
|
||||||
|
"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $1 -lt 1 ] && [ $1 -gt 8 ]
|
||||||
|
then
|
||||||
|
echo "error: DEVICE_NUM=$1 is not in (1-8)"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
export DEVICE_NUM=$1
|
||||||
|
export RANK_SIZE=$1
|
||||||
|
|
||||||
|
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
|
||||||
|
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
|
||||||
|
if [ -d "./train" ];
|
||||||
|
then
|
||||||
|
rm -rf ./train
|
||||||
|
fi
|
||||||
|
mkdir ./train
|
||||||
|
cd ./train || exit
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES="$2"
|
||||||
|
|
||||||
|
|
||||||
|
if [ $1 -gt 1 ]
|
||||||
|
then
|
||||||
|
mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
|
||||||
|
python3 ${BASEPATH}/../train.py --device_type='GPU' > train_gpu.log 2>&1 &
|
||||||
|
else
|
||||||
|
python3 ${BASEPATH}/../train.py --device_type='GPU' > train_gpu.log 2>&1 &
|
||||||
|
fi
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""from googlenet"""
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
|
||||||
|
imagenet_cfg = edict({
|
||||||
|
'name': 'imagenet',
|
||||||
|
'pre_trained': False,
|
||||||
|
'num_classes': 1000,
|
||||||
|
'lr_init': 0.05, # Ascend_1P: 0.05, Ascend_8P: 0.4, GPU_1P: 0.05, GPU_2P: 0.1
|
||||||
|
'batch_size': 128,
|
||||||
|
'epoch_size': 160, # GPU_2P: 110
|
||||||
|
'momentum': 0.9,
|
||||||
|
'weight_decay': 1e-4,
|
||||||
|
'image_height': 224,
|
||||||
|
'image_width': 224,
|
||||||
|
'data_path': '/data/ILSVRC2012_train/',
|
||||||
|
'val_data_path': '/data/ILSVRC2012_val/',
|
||||||
|
'device_target': 'Ascend',
|
||||||
|
'device_id': 0,
|
||||||
|
'keep_checkpoint_max': 30,
|
||||||
|
'checkpoint_path': None,
|
||||||
|
'onnx_filename': 'fishnet99',
|
||||||
|
'air_filename': 'fishnet99',
|
||||||
|
|
||||||
|
# optimizer and lr related
|
||||||
|
'lr_scheduler': 'cosine_annealing',
|
||||||
|
'lr_epochs': [30, 60, 90, 120],
|
||||||
|
'lr_gamma': 0.3,
|
||||||
|
'eta_min': 0.0,
|
||||||
|
'T_max': 150, # GPU_2P: 100
|
||||||
|
'warmup_epochs': 0,
|
||||||
|
|
||||||
|
# loss related
|
||||||
|
'is_dynamic_loss_scale': 0,
|
||||||
|
'loss_scale': 1024,
|
||||||
|
'label_smooth_factor': 0.1,
|
||||||
|
'use_label_smooth': True,
|
||||||
|
})
|
|
@ -0,0 +1,100 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Data operations, will be used in train.py and eval.py
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.transforms.c_transforms as C
|
||||||
|
import mindspore.dataset.vision.c_transforms as vision
|
||||||
|
from src.config import imagenet_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset_imagenet(dataset_path, repeat_num=1, training=True, num_parallel_workers=16, shuffle=None):
|
||||||
|
"""
|
||||||
|
create a train or eval imagenet2012 dataset for resnet50
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_path(string): the path of dataset.
|
||||||
|
do_train(bool): whether dataset is used for train or eval.
|
||||||
|
repeat_num(int): the repeat times of dataset. Default: 1
|
||||||
|
batch_size(int): the batch size of dataset. Default: 32
|
||||||
|
target(str): the device target. Default: Ascend
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dataset
|
||||||
|
"""
|
||||||
|
|
||||||
|
device_num, rank_id = _get_rank_info()
|
||||||
|
|
||||||
|
if device_num == 1:
|
||||||
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
|
||||||
|
else:
|
||||||
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
|
||||||
|
num_shards=device_num, shard_id=rank_id)
|
||||||
|
|
||||||
|
assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
|
||||||
|
image_size = imagenet_cfg.image_height
|
||||||
|
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||||
|
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
||||||
|
|
||||||
|
# define map operations
|
||||||
|
if training:
|
||||||
|
transform_img = [
|
||||||
|
vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||||
|
vision.RandomHorizontalFlip(prob=0.5),
|
||||||
|
vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1),
|
||||||
|
vision.Normalize(mean=mean, std=std),
|
||||||
|
vision.HWC2CHW()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
transform_img = [
|
||||||
|
vision.Decode(),
|
||||||
|
vision.Resize(256),
|
||||||
|
vision.CenterCrop(image_size),
|
||||||
|
vision.Normalize(mean=mean, std=std),
|
||||||
|
vision.HWC2CHW()
|
||||||
|
]
|
||||||
|
|
||||||
|
transform_label = [C.TypeCast(mstype.int32)]
|
||||||
|
|
||||||
|
data_set = data_set.map(input_columns="image", num_parallel_workers=12, operations=transform_img)
|
||||||
|
data_set = data_set.map(input_columns="label", num_parallel_workers=4, operations=transform_label)
|
||||||
|
|
||||||
|
# apply batch operations
|
||||||
|
data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True)
|
||||||
|
|
||||||
|
# apply dataset repeat operation
|
||||||
|
data_set = data_set.repeat(repeat_num)
|
||||||
|
|
||||||
|
return data_set
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rank_info():
|
||||||
|
"""
|
||||||
|
get rank size and rank id
|
||||||
|
"""
|
||||||
|
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||||
|
|
||||||
|
if rank_size > 1:
|
||||||
|
from mindspore.communication.management import get_rank, get_group_size
|
||||||
|
rank_size = get_group_size()
|
||||||
|
rank_id = get_rank()
|
||||||
|
else:
|
||||||
|
rank_size = rank_id = None
|
||||||
|
|
||||||
|
return rank_size, rank_id
|
|
@ -0,0 +1,599 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
FishNet model of MindSpore-1.2.0.
|
||||||
|
"""
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops.operations as P
|
||||||
|
|
||||||
|
conv_weight_init = 'HeUniform'
|
||||||
|
|
||||||
|
|
||||||
|
class adaptiveavgpool2d_ms(nn.Cell):
|
||||||
|
"""adaptiveavgpool2d_ms"""
|
||||||
|
def __init__(self):
|
||||||
|
super(adaptiveavgpool2d_ms, self).__init__()
|
||||||
|
self.ada_pool = P.ReduceMean(keep_dims=True)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.ada_pool(x, (2, 3))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _bn_relu_conv(in_c, out_c, **conv_kwargs):
|
||||||
|
return nn.SequentialCell([nn.BatchNorm2d(num_features=in_c, momentum=0.9),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(in_channels=in_c, out_channels=out_c, pad_mode='pad',
|
||||||
|
weight_init=conv_weight_init, **conv_kwargs),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock_with_shortcut(nn.Cell):
|
||||||
|
"""
|
||||||
|
Construct Basic Bottle-necked Residual Block module.
|
||||||
|
Args:
|
||||||
|
in_c : Number of channels in the input image
|
||||||
|
out_c : Number of channels in the output image
|
||||||
|
shortcut : Specific function for skip-connection
|
||||||
|
Examples)
|
||||||
|
'bn_relu_conv' for DownRefinementBlock
|
||||||
|
'bn_relu_conv with channel reduction' for UpRefinementBlock
|
||||||
|
'identity mapping' for regular connection
|
||||||
|
stride : Stride of middle conv layer
|
||||||
|
dilation : Dilation rate of middle conv layer
|
||||||
|
Forwarding Path:
|
||||||
|
⎡ (shortcut) ⎤
|
||||||
|
input image - (BN-ReLU-Conv) * 3 - (add) -output
|
||||||
|
"""
|
||||||
|
def __init__(self, in_c, out_c, shortcut, stride=1, dilation=1):
|
||||||
|
super(ResBlock_with_shortcut, self).__init__()
|
||||||
|
|
||||||
|
mid_c = out_c // 4
|
||||||
|
self.layers = nn.SequentialCell([
|
||||||
|
_bn_relu_conv(in_c, mid_c, kernel_size=1, has_bias=False),
|
||||||
|
_bn_relu_conv(mid_c, mid_c, kernel_size=3, stride=stride, padding=dilation, dilation=dilation,
|
||||||
|
has_bias=False),
|
||||||
|
_bn_relu_conv(mid_c, out_c, kernel_size=1, has_bias=False),
|
||||||
|
])
|
||||||
|
self.shortcut = shortcut
|
||||||
|
self.add_1 = P.Add()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
return self.add_1(self.layers(x), self.shortcut(x))
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock_without_shortcut(nn.Cell):
|
||||||
|
"""
|
||||||
|
Construct Basic Bottle-necked Residual Block module.
|
||||||
|
Args:
|
||||||
|
in_c : Number of channels in the input image
|
||||||
|
out_c : Number of channels in the output image
|
||||||
|
stride : Stride of middle conv layer
|
||||||
|
dilation : Dilation rate of middle conv layer
|
||||||
|
Forwarding Path:
|
||||||
|
⎡ (shortcut) ⎤
|
||||||
|
input image - (BN-ReLU-Conv) * 3 - (add) -output
|
||||||
|
"""
|
||||||
|
def __init__(self, in_c, out_c, stride=1, dilation=1):
|
||||||
|
super(ResBlock_without_shortcut, self).__init__()
|
||||||
|
|
||||||
|
mid_c = out_c // 4
|
||||||
|
self.layers_ = nn.SequentialCell([
|
||||||
|
_bn_relu_conv(in_c, mid_c, kernel_size=1, has_bias=False),
|
||||||
|
_bn_relu_conv(mid_c, mid_c, kernel_size=3, stride=stride, padding=dilation, dilation=dilation,
|
||||||
|
has_bias=False),
|
||||||
|
_bn_relu_conv(mid_c, out_c, kernel_size=1, has_bias=False),
|
||||||
|
])
|
||||||
|
self.add_2 = P.Add()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
return self.add_2(self.layers_(x), x)
|
||||||
|
|
||||||
|
|
||||||
|
class TransferBlock(nn.Cell):
|
||||||
|
"""
|
||||||
|
Construct Transfer Block module.
|
||||||
|
Args:
|
||||||
|
ch : Number of channels in the input and output image
|
||||||
|
num_blk : Number of Residual Blocks
|
||||||
|
Forwarding Path:
|
||||||
|
input image - (ResBlock_without_shortcut) * num_blk - output
|
||||||
|
"""
|
||||||
|
def __init__(self, ch, num_blk):
|
||||||
|
super(TransferBlock, self).__init__()
|
||||||
|
|
||||||
|
self.layers_TransferBlock = nn.SequentialCell([*[ResBlock_without_shortcut(ch, ch)
|
||||||
|
for _ in range(0, num_blk)]])
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
return self.layers_TransferBlock(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DownStage(nn.Cell):
|
||||||
|
"""
|
||||||
|
Construct a stage for each resolution.
|
||||||
|
A DownStage is consisted of one DownRefinementBlock and several residual regular connection blocks.
|
||||||
|
(Note: In fact, DownRefinementBlock is not used in FishHead according to original implementation.
|
||||||
|
However, it seems needed to be used according to original paper.
|
||||||
|
In this version, we followed original implementation, not original paper.)
|
||||||
|
Args:in_c : Number of channels in the input image
|
||||||
|
out_c : Number of channels in the output image
|
||||||
|
num_blk : Number of Residual Blocks
|
||||||
|
stride : Stride of shortcut conv layer
|
||||||
|
Forwarding Path:input image - (ResBlock with Shortcut) - (ResBlock) * num_blk - (MaxPool) - output
|
||||||
|
"""
|
||||||
|
def __init__(self, in_c, out_c, num_blk, stride=1):
|
||||||
|
super(DownStage, self).__init__()
|
||||||
|
|
||||||
|
shortcut = _bn_relu_conv(in_c, out_c, kernel_size=1, stride=stride, has_bias=False)
|
||||||
|
self.layer_DownStage = nn.SequentialCell([
|
||||||
|
ResBlock_with_shortcut(in_c, out_c, shortcut),
|
||||||
|
*[ResBlock_without_shortcut(out_c, out_c) for _ in range(1, num_blk)],
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
|
||||||
|
])
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
return self.layer_DownStage(x)
|
||||||
|
|
||||||
|
|
||||||
|
class UpStage(nn.Cell):
|
||||||
|
"""
|
||||||
|
Construct a stage for each resolution.
|
||||||
|
A DownStage is consisted of one DownRefinementBlock and several residual regular connection blocks.
|
||||||
|
Not like DownStage, this module reduces the number of channels of concatenated feature maps in the shortcut path.
|
||||||
|
Args:in_c : Number of channels in the input image
|
||||||
|
out_c : Number of channels in the output image
|
||||||
|
num_blk : Number of Residual Blocks
|
||||||
|
stride : Stride of shortcut conv layer
|
||||||
|
Forwarding Path:input image - (ResBlock with Channel Reduction) - (ResBlock) * num_blk - (UpSample) - output
|
||||||
|
"""
|
||||||
|
def __init__(self, in_c, out_c, num_blk, dilation=1):
|
||||||
|
super(UpStage, self).__init__()
|
||||||
|
|
||||||
|
self.k = in_c // out_c
|
||||||
|
self.redece_sum = P.ReduceSum(keep_dims=False)
|
||||||
|
self.shape_ = P.Shape()
|
||||||
|
self.reshape_ = P.Reshape()
|
||||||
|
self.layer_UpStage = nn.SequentialCell([
|
||||||
|
ResBlock_with_shortcut(in_c, out_c, channel_reduction_ms(in_c // out_c), dilation=dilation),
|
||||||
|
*[ResBlock_without_shortcut(out_c, out_c, dilation=dilation) for _ in range(1, num_blk)],
|
||||||
|
])
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
return self.layer_UpStage(x)
|
||||||
|
|
||||||
|
|
||||||
|
class channel_reduction_ms(nn.Cell):
|
||||||
|
"""channel_reduction_ms"""
|
||||||
|
def __init__(self, kk):
|
||||||
|
super(channel_reduction_ms, self).__init__()
|
||||||
|
self.shape_ = P.Shape()
|
||||||
|
self.kk = kk
|
||||||
|
self.reshape_ = P.Reshape()
|
||||||
|
self.redece_sum = P.ReduceSum(keep_dims=False)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
n, c, h_, w_ = self.shape_(x)
|
||||||
|
x = self.redece_sum(self.reshape_(x, (n, c // self.kk, self.kk, h_, w_)), 2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FishTail(nn.Cell):
|
||||||
|
"""
|
||||||
|
Construct FishTail module.
|
||||||
|
Each instances corresponds to each stages.
|
||||||
|
Args:
|
||||||
|
in_c : Number of channels in the input image
|
||||||
|
out_c : Number of channels in the output image
|
||||||
|
num_blk : Number of Residual Blocks
|
||||||
|
Forwarding Path:
|
||||||
|
input image - (DownStage) - output
|
||||||
|
"""
|
||||||
|
def __init__(self, in_c, out_c, num_blk):
|
||||||
|
super(FishTail, self).__init__()
|
||||||
|
|
||||||
|
self.layer_FishTail = DownStage(in_c, out_c, num_blk)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.layer_FishTail(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Bridge(nn.Cell):
|
||||||
|
"""
|
||||||
|
Construct Bridge module.
|
||||||
|
This module bridges the last FishTail stage and first FishBody stage.
|
||||||
|
Args:ch : Number of channels in the input and output image
|
||||||
|
num_blk : Number of Residual Blocks
|
||||||
|
Forwarding Path:
|
||||||
|
r (SEBlock) ㄱ
|
||||||
|
input image - (stem) - (ResBlock with Shortcut) - (ResBlock) * num_blk - (mul & sum) - output
|
||||||
|
"""
|
||||||
|
def __init__(self, ch, num_blk):
|
||||||
|
super(Bridge, self).__init__()
|
||||||
|
|
||||||
|
self.stem = nn.SequentialCell([
|
||||||
|
nn.BatchNorm2d(num_features=ch, momentum=0.9),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(in_channels=ch, out_channels=ch // 2, kernel_size=1, pad_mode='pad', has_bias=False,
|
||||||
|
weight_init=conv_weight_init),
|
||||||
|
nn.BatchNorm2d(num_features=ch // 2, momentum=0.9),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(in_channels=ch // 2, out_channels=ch * 2, kernel_size=1, pad_mode='pad', has_bias=True,
|
||||||
|
weight_init=conv_weight_init)
|
||||||
|
])
|
||||||
|
shortcut = _bn_relu_conv(ch * 2, ch, kernel_size=1, has_bias=False)
|
||||||
|
self.layers_Bridge = nn.SequentialCell([
|
||||||
|
ResBlock_with_shortcut(ch * 2, ch, shortcut),
|
||||||
|
*[ResBlock_without_shortcut(ch, ch) for _ in range(1, num_blk)],
|
||||||
|
])
|
||||||
|
|
||||||
|
self.se_block = nn.SequentialCell([
|
||||||
|
nn.BatchNorm2d(num_features=ch * 2, momentum=0.9),
|
||||||
|
nn.ReLU(),
|
||||||
|
adaptiveavgpool2d_ms(),
|
||||||
|
nn.Conv2d(in_channels=ch * 2, out_channels=ch // 16, kernel_size=1, pad_mode='pad', has_bias=True
|
||||||
|
, weight_init=conv_weight_init),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(in_channels=ch // 16, out_channels=ch, kernel_size=1, pad_mode='pad', has_bias=True
|
||||||
|
, weight_init=conv_weight_init),
|
||||||
|
nn.Sigmoid()
|
||||||
|
])
|
||||||
|
self.add_3 = P.Add()
|
||||||
|
self.mul_ = P.Mul()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.stem(x)
|
||||||
|
att = self.se_block(x)
|
||||||
|
out = self.layers_Bridge(x)
|
||||||
|
return self.add_3(self.mul_(out, att), att)
|
||||||
|
|
||||||
|
|
||||||
|
class FishBody_0(nn.Cell):
|
||||||
|
"""Construct FishBody module.
|
||||||
|
Each instances corresponds to each stages.
|
||||||
|
Args:in_c : Number of channels in the input image
|
||||||
|
out_c : Number of channels in the output image
|
||||||
|
num_blk : Number of Residual Blocks
|
||||||
|
trans_in_c : Number of channels in the transferred image
|
||||||
|
num_trans : Number of Transfer Blocks
|
||||||
|
dilation : Dilation rate of Conv in UpRefinementBlock
|
||||||
|
Forwarding Path:
|
||||||
|
input image - (UpStage) ㄱ
|
||||||
|
trans image - (transfer) --(concat)-- output
|
||||||
|
"""
|
||||||
|
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans, dilation=1):
|
||||||
|
super(FishBody_0, self).__init__()
|
||||||
|
self.layer_FishBody = UpStage(in_c, out_c, num_blk, dilation=dilation)
|
||||||
|
self.add_up = par_ms_0()
|
||||||
|
self.transfer = TransferBlock(trans_in_c, num_trans)
|
||||||
|
self.concat = P.Concat(1)
|
||||||
|
|
||||||
|
def construct(self, x, trans_x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.layer_FishBody(x)
|
||||||
|
x = self.add_up(x)
|
||||||
|
trans_x = self.transfer(trans_x)
|
||||||
|
return self.concat((x, trans_x))
|
||||||
|
|
||||||
|
|
||||||
|
class par_ms_0(nn.Cell):
|
||||||
|
"""par_ms_0"""
|
||||||
|
def __init__(self):
|
||||||
|
super(par_ms_0, self).__init__()
|
||||||
|
self.P_up = P.ResizeNearestNeighbor((14, 14))
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.P_up(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FishBody_1(nn.Cell):
|
||||||
|
"""Construct FishBody module.
|
||||||
|
Each instances corresponds to each stages.
|
||||||
|
Args:in_c : Number of channels in the input image
|
||||||
|
out_c : Number of channels in the output image
|
||||||
|
num_blk : Number of Residual Blocks
|
||||||
|
trans_in_c : Number of channels in the transferred image
|
||||||
|
num_trans : Number of Transfer Blocks
|
||||||
|
dilation : Dilation rate of Conv in UpRefinementBlock
|
||||||
|
Forwarding Path:
|
||||||
|
input image - (UpStage) ㄱ
|
||||||
|
trans image - (transfer) --(concat)-- output
|
||||||
|
"""
|
||||||
|
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans, dilation=1):
|
||||||
|
super(FishBody_1, self).__init__()
|
||||||
|
self.layer_FishBody = UpStage(in_c, out_c, num_blk, dilation=dilation)
|
||||||
|
self.add_up = par_ms_1()
|
||||||
|
self.transfer = TransferBlock(trans_in_c, num_trans)
|
||||||
|
self.concat = P.Concat(1)
|
||||||
|
|
||||||
|
def construct(self, x, trans_x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.layer_FishBody(x)
|
||||||
|
x = self.add_up(x)
|
||||||
|
trans_x = self.transfer(trans_x)
|
||||||
|
return self.concat((x, trans_x))
|
||||||
|
|
||||||
|
|
||||||
|
class par_ms_1(nn.Cell):
|
||||||
|
"""par_ms_1"""
|
||||||
|
def __init__(self):
|
||||||
|
super(par_ms_1, self).__init__()
|
||||||
|
self.P_up = P.ResizeNearestNeighbor((28, 28))
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.P_up(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FishBody_2(nn.Cell):
|
||||||
|
"""Construct FishBody module.
|
||||||
|
Each instances corresponds to each stages.
|
||||||
|
Args:in_c : Number of channels in the input image
|
||||||
|
out_c : Number of channels in the output image
|
||||||
|
num_blk : Number of Residual Blocks
|
||||||
|
trans_in_c : Number of channels in the transferred image
|
||||||
|
num_trans : Number of Transfer Blocks
|
||||||
|
dilation : Dilation rate of Conv in UpRefinementBlock
|
||||||
|
Forwarding Path:
|
||||||
|
input image - (UpStage) ㄱ
|
||||||
|
trans image - (transfer) --(concat)-- output
|
||||||
|
"""
|
||||||
|
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans, dilation=1):
|
||||||
|
super(FishBody_2, self).__init__()
|
||||||
|
self.layer_FishBody = UpStage(in_c, out_c, num_blk, dilation=dilation)
|
||||||
|
self.add_up = par_ms_2()
|
||||||
|
self.transfer = TransferBlock(trans_in_c, num_trans)
|
||||||
|
self.concat = P.Concat(1)
|
||||||
|
|
||||||
|
def construct(self, x, trans_x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.layer_FishBody(x)
|
||||||
|
x = self.add_up(x)
|
||||||
|
trans_x = self.transfer(trans_x)
|
||||||
|
return self.concat((x, trans_x))
|
||||||
|
|
||||||
|
|
||||||
|
class par_ms_2(nn.Cell):
|
||||||
|
"""par_ms_2"""
|
||||||
|
def __init__(self):
|
||||||
|
super(par_ms_2, self).__init__()
|
||||||
|
self.P_up = P.ResizeNearestNeighbor((56, 56))
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.P_up(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FishHead(nn.Cell):
|
||||||
|
"""Construct FishHead module.
|
||||||
|
Each instances corresponds to each stages.
|
||||||
|
Different with Official Code : we used shortcut layer in this Module. (shortcut layer is used according to the
|
||||||
|
original paper)
|
||||||
|
Args:in_c : Number of channels in the input image
|
||||||
|
out_c : Number of channels in the output image
|
||||||
|
num_blk : Number of Residual Blocks
|
||||||
|
trans_in_c : Number of channels in the transferred image
|
||||||
|
num_trans : Number of Transfer Blocks
|
||||||
|
Forwarding Path:
|
||||||
|
input image - (ResBlock) * num_blk - pool ㄱ
|
||||||
|
trans image - (transfer) --(concat)-- output
|
||||||
|
"""
|
||||||
|
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans):
|
||||||
|
super(FishHead, self).__init__()
|
||||||
|
|
||||||
|
self.layer_FishHead = nn.SequentialCell([
|
||||||
|
ResBlock_without_shortcut(in_c, out_c),
|
||||||
|
*[ResBlock_without_shortcut(out_c, out_c) for _ in range(1, num_blk)],
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
|
||||||
|
])
|
||||||
|
self.transfer = TransferBlock(trans_in_c, num_trans)
|
||||||
|
self.concat_ = P.Concat(1)
|
||||||
|
|
||||||
|
def construct(self, x, trans_x):
|
||||||
|
"""construct"""
|
||||||
|
x = self.layer_FishHead(x)
|
||||||
|
trans_x = self.transfer(trans_x)
|
||||||
|
return self.concat_((x, trans_x))
|
||||||
|
|
||||||
|
|
||||||
|
def _conv_bn_relu(in_ch, out_ch, stride=1):
|
||||||
|
return nn.SequentialCell([nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=stride,
|
||||||
|
pad_mode='pad', padding=1, has_bias=False, weight_init=conv_weight_init),
|
||||||
|
nn.BatchNorm2d(num_features=out_ch, momentum=0.9),
|
||||||
|
nn.ReLU()])
|
||||||
|
|
||||||
|
|
||||||
|
class Fishnet(nn.Cell):
|
||||||
|
"""
|
||||||
|
Construct entire networks
|
||||||
|
Args:
|
||||||
|
start_c : Number of channels of input image.
|
||||||
|
Note that it is NOT the number of channels in initial input image, and it IS the number of output
|
||||||
|
channel of stem.
|
||||||
|
num_cls : Number of classes
|
||||||
|
tail_num_blk : list of the numbers of Conv blocks in each FishTail stages
|
||||||
|
body_num_blk : list of the numbers of Conv blocks in each FishBody stages
|
||||||
|
head_num_blk : list of the numbers of Conv blocks in each FishHead stages
|
||||||
|
(Note : `*_num_blk` includes 1 Residual blocks in the start of each stages)
|
||||||
|
body_num_trans : list of the numbers of Conv blocks in transfer paths in each FishTail stages
|
||||||
|
head_num_trans : list of the numbers of Conv blocks in transfer paths in each FishHead stages
|
||||||
|
tail_channels : list of the number of in, out channel of each stages
|
||||||
|
body_channels : list of the number of in, out channel of each stages
|
||||||
|
head_channels : list of the number of in, out channel of each stages
|
||||||
|
"""
|
||||||
|
def __init__(self, start_c=64, num_cls=1000,
|
||||||
|
tail_num_blk=1, bridge_num_blk=2,
|
||||||
|
body_num_blk=1, body_num_trans=1,
|
||||||
|
head_num_blk=1, head_num_trans=1,
|
||||||
|
tail_channels=1, body_channels=1, head_channels=1):
|
||||||
|
super(Fishnet, self).__init__()
|
||||||
|
|
||||||
|
self.stem = nn.SequentialCell([
|
||||||
|
_conv_bn_relu(3, start_c // 2, stride=2),
|
||||||
|
_conv_bn_relu(start_c // 2, start_c // 2),
|
||||||
|
_conv_bn_relu(start_c // 2, start_c),
|
||||||
|
nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
|
||||||
|
])
|
||||||
|
print("FishNet Initialization Start")
|
||||||
|
|
||||||
|
self.tail_layer = []
|
||||||
|
for i, num_blk in enumerate(tail_num_blk):
|
||||||
|
layer = FishTail(tail_channels[i], tail_channels[i + 1], num_blk)
|
||||||
|
self.tail_layer.append(layer)
|
||||||
|
self.tail_layer = nn.CellList(self.tail_layer)
|
||||||
|
self.bridge = Bridge(tail_channels[-1], bridge_num_blk)
|
||||||
|
|
||||||
|
self.body_layer = []
|
||||||
|
for i, (num_blk, num_trans) in enumerate(zip(body_num_blk, body_num_trans)):
|
||||||
|
if i == 0:
|
||||||
|
layer = FishBody_0(body_channels[i][0], body_channels[i][1], num_blk,
|
||||||
|
tail_channels[-i - 2], num_trans, dilation=2 ** i)
|
||||||
|
elif i == 1:
|
||||||
|
layer = FishBody_1(body_channels[i][0], body_channels[i][1], num_blk,
|
||||||
|
tail_channels[-i - 2], num_trans, dilation=2 ** i)
|
||||||
|
else:
|
||||||
|
layer = FishBody_2(body_channels[i][0], body_channels[i][1], num_blk,
|
||||||
|
tail_channels[-i - 2], num_trans, dilation=2 ** i)
|
||||||
|
self.body_layer.append(layer)
|
||||||
|
self.body_layer = nn.CellList(self.body_layer)
|
||||||
|
|
||||||
|
self.head_layer = []
|
||||||
|
for i, (num_blk, num_trans) in enumerate(zip(head_num_blk, head_num_trans)):
|
||||||
|
layer = FishHead(head_channels[i][0], head_channels[i][1], num_blk,
|
||||||
|
body_channels[-i - 1][0], num_trans)
|
||||||
|
self.head_layer.append(layer)
|
||||||
|
self.head_layer = nn.CellList(self.head_layer)
|
||||||
|
|
||||||
|
last_c = head_channels[-1][1]
|
||||||
|
self.classifier = nn.SequentialCell([
|
||||||
|
nn.BatchNorm2d(num_features=last_c, momentum=0.9),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(in_channels=last_c, out_channels=last_c // 2, kernel_size=1, pad_mode='pad', has_bias=False
|
||||||
|
, weight_init=conv_weight_init),
|
||||||
|
nn.BatchNorm2d(num_features=last_c // 2, momentum=0.9),
|
||||||
|
nn.ReLU(),
|
||||||
|
adaptiveavgpool2d_ms(),
|
||||||
|
nn.Conv2d(in_channels=last_c // 2, out_channels=num_cls, kernel_size=1, pad_mode='pad', has_bias=True
|
||||||
|
, weight_init=conv_weight_init)
|
||||||
|
])
|
||||||
|
self.squeeze_ = P.Squeeze(2)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""construct"""
|
||||||
|
stem = self.stem(x)
|
||||||
|
tail_features = [stem]
|
||||||
|
for t in self.tail_layer:
|
||||||
|
last_feature = tail_features[-1]
|
||||||
|
tail_features.append(t(last_feature))
|
||||||
|
|
||||||
|
bridge = self.bridge(tail_features[-1])
|
||||||
|
|
||||||
|
body_features = [bridge]
|
||||||
|
for b, tail in zip(self.body_layer, [tail_features[2], tail_features[1], tail_features[0]]):
|
||||||
|
last_feature = body_features[-1]
|
||||||
|
body_features.append(b(last_feature, tail))
|
||||||
|
|
||||||
|
head_features = [body_features[-1]]
|
||||||
|
for h, body in zip(self.head_layer, [body_features[2], body_features[1], body_features[0]]):
|
||||||
|
last_feature = head_features[-1]
|
||||||
|
head_features.append(h(last_feature, body))
|
||||||
|
|
||||||
|
out = self.classifier(head_features[-1])
|
||||||
|
out = self.squeeze_(out)
|
||||||
|
out = self.squeeze_(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_channel(start_c, num_blk):
|
||||||
|
"""
|
||||||
|
Calculate the number of in and out channels of each stages in FishNet.
|
||||||
|
Example:
|
||||||
|
fish150 : start channel=64, num_blk=3,
|
||||||
|
tail channels : Grow double in each stages,
|
||||||
|
[64, 128, 256 ...] = [start channel ** (2**num_blk) ....]
|
||||||
|
body channels : In first stage, in_channel and out_channel is the same,
|
||||||
|
but the other layers, the number of output channels is half of the number of input channel
|
||||||
|
Add the number of transfer channels to the number of output channels
|
||||||
|
The numbers of transfer channels are reverse of the tail channel[:-2]
|
||||||
|
[(512, 512), + 256
|
||||||
|
(768, 384), + 128
|
||||||
|
(512, 256)] + 64
|
||||||
|
head channels : The number of input channels and output channels is the same.
|
||||||
|
Add the number of transfer channels to the number of output channels
|
||||||
|
The numbers of transfer channels are reverse of the tail channel[:-2]
|
||||||
|
[(320, 320), + 512
|
||||||
|
(832, 832), + 768
|
||||||
|
(1600, 1600)] + 512
|
||||||
|
"""
|
||||||
|
# tail channels
|
||||||
|
tail_channels = [start_c]
|
||||||
|
for i in range(num_blk):
|
||||||
|
tail_channels.append(tail_channels[-1] * 2)
|
||||||
|
print("Tail Channels : ", tail_channels)
|
||||||
|
|
||||||
|
# body channels
|
||||||
|
in_c, transfer_c = tail_channels[-1], tail_channels[-2]
|
||||||
|
body_channels = [(in_c, in_c), (in_c + transfer_c, (in_c + transfer_c) // 2)]
|
||||||
|
for i in range(1, num_blk - 1):
|
||||||
|
transfer_c = tail_channels[-i - 2]
|
||||||
|
in_c = body_channels[-1][1] + transfer_c
|
||||||
|
body_channels.append((in_c, in_c // 2))
|
||||||
|
print("Body Channels : ", body_channels)
|
||||||
|
|
||||||
|
# head channels
|
||||||
|
in_c = body_channels[-1][1] + tail_channels[0]
|
||||||
|
head_channels = [(in_c, in_c)]
|
||||||
|
for i in range(num_blk):
|
||||||
|
transfer_c = body_channels[-i - 1][0]
|
||||||
|
in_c = head_channels[-1][1] + transfer_c
|
||||||
|
head_channels.append((in_c, in_c))
|
||||||
|
print("Head Channels : ", head_channels)
|
||||||
|
return {"tail_channels": tail_channels, "body_channels": body_channels, "head_channels": head_channels}
|
||||||
|
|
||||||
|
|
||||||
|
def fish99(num_cls=1000):
|
||||||
|
"""fish99"""
|
||||||
|
start_c = 64
|
||||||
|
# tail
|
||||||
|
tail_num_blk = [2, 2, 6]
|
||||||
|
bridge_num_blk = 2
|
||||||
|
# body
|
||||||
|
body_num_blk = [1, 1, 1]
|
||||||
|
body_num_trans = [1, 1, 1]
|
||||||
|
# head
|
||||||
|
head_num_blk = [1, 2, 2]
|
||||||
|
head_num_trans = [1, 1, 4]
|
||||||
|
|
||||||
|
net_channel = _calc_channel(start_c, len(tail_num_blk))
|
||||||
|
|
||||||
|
return Fishnet(start_c, num_cls,
|
||||||
|
tail_num_blk, bridge_num_blk,
|
||||||
|
body_num_blk, body_num_trans,
|
||||||
|
head_num_blk, head_num_trans,
|
||||||
|
**net_channel)
|
|
@ -0,0 +1,214 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""train"""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.communication.management import init, get_rank
|
||||||
|
from mindspore.nn.optim.momentum import Momentum
|
||||||
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.nn.loss.loss import _Loss
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
from src.config import imagenet_cfg
|
||||||
|
from src.dataset import create_dataset_imagenet
|
||||||
|
import src.fishnet as net_ms
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
def lr_steps_imagenet(_cfg, steps_per_epoch):
|
||||||
|
"""lr step for imagenet"""
|
||||||
|
if _cfg.lr_scheduler == 'cosine_annealing':
|
||||||
|
_lr = warmup_cosine_annealing_lr(_cfg.lr_init,
|
||||||
|
steps_per_epoch,
|
||||||
|
_cfg.warmup_epochs,
|
||||||
|
_cfg.epoch_size,
|
||||||
|
_cfg.T_max,
|
||||||
|
_cfg.eta_min)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(_cfg.lr_scheduler)
|
||||||
|
|
||||||
|
return _lr
|
||||||
|
|
||||||
|
|
||||||
|
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
|
||||||
|
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||||
|
lr1 = float(init_lr) + lr_inc * current_step
|
||||||
|
return lr1
|
||||||
|
|
||||||
|
|
||||||
|
def warmup_cosine_annealing_lr(lr5, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
|
||||||
|
""" warmup cosine annealing lr."""
|
||||||
|
base_lr = lr5
|
||||||
|
warmup_init_lr = 0
|
||||||
|
total_steps = int(max_epoch * steps_per_epoch)
|
||||||
|
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||||
|
|
||||||
|
lr_each_step = []
|
||||||
|
for i in range(total_steps):
|
||||||
|
last_epoch = i // steps_per_epoch
|
||||||
|
if i < warmup_steps:
|
||||||
|
lr5 = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||||
|
else:
|
||||||
|
lr5 = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
|
||||||
|
lr_each_step.append(lr5)
|
||||||
|
|
||||||
|
return np.array(lr_each_step).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossEntropySmooth(_Loss):
|
||||||
|
"""CrossEntropy"""
|
||||||
|
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
|
||||||
|
super(CrossEntropySmooth, self).__init__()
|
||||||
|
self.onehot = P.OneHot()
|
||||||
|
self.sparse = sparse
|
||||||
|
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||||
|
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||||
|
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
|
||||||
|
|
||||||
|
def construct(self, logit, label):
|
||||||
|
if self.sparse:
|
||||||
|
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||||
|
loss2 = self.ce(logit, label)
|
||||||
|
return loss2
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Classification')
|
||||||
|
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
|
||||||
|
parser.add_argument('--device_type', type=str, default=None, help='GPU or Ascend. (Default: None)')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
cfg = imagenet_cfg
|
||||||
|
|
||||||
|
# set context
|
||||||
|
if not args_opt.device_type:
|
||||||
|
device_target = args_opt.device_type
|
||||||
|
else:
|
||||||
|
device_target = cfg.device_target
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||||
|
context.set_context(enable_graph_kernel=True)
|
||||||
|
device_num = int(os.getenv('RANK_SIZE', '1'))
|
||||||
|
|
||||||
|
if device_target == "Ascend":
|
||||||
|
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||||
|
if args_opt.device_id is not None:
|
||||||
|
context.set_context(device_id=args_opt.device_id)
|
||||||
|
else:
|
||||||
|
context.set_context(device_id=cfg.device_id)
|
||||||
|
|
||||||
|
if device_num > 1:
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
|
gradients_mean=True)
|
||||||
|
init()
|
||||||
|
elif device_target == "GPU":
|
||||||
|
if device_num > 1:
|
||||||
|
init()
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
|
gradients_mean=True)
|
||||||
|
device_id = get_rank()
|
||||||
|
else:
|
||||||
|
if args_opt.device_id is not None:
|
||||||
|
context.set_context(device_id=args_opt.device_id)
|
||||||
|
else:
|
||||||
|
context.set_context(device_id=cfg.device_id)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported platform.")
|
||||||
|
|
||||||
|
dataset = create_dataset_imagenet(cfg.data_path, 1)
|
||||||
|
|
||||||
|
batch_num = dataset.get_dataset_size()
|
||||||
|
|
||||||
|
net = net_ms.fish99()
|
||||||
|
|
||||||
|
# Continue training if set pre_trained to be True
|
||||||
|
if cfg.pre_trained:
|
||||||
|
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
|
||||||
|
loss_scale_manager = None
|
||||||
|
|
||||||
|
lr = lr_steps_imagenet(cfg, batch_num)
|
||||||
|
|
||||||
|
|
||||||
|
def get_param_groups(network):
|
||||||
|
""" get param groups. """
|
||||||
|
decay_params = []
|
||||||
|
no_decay_params = []
|
||||||
|
for x in network.trainable_params():
|
||||||
|
parameter_name = x.name
|
||||||
|
if parameter_name.endswith('.bias'):
|
||||||
|
# all bias not using weight decay
|
||||||
|
no_decay_params.append(x)
|
||||||
|
elif parameter_name.endswith('.gamma'):
|
||||||
|
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||||
|
no_decay_params.append(x)
|
||||||
|
elif parameter_name.endswith('.beta'):
|
||||||
|
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||||
|
no_decay_params.append(x)
|
||||||
|
else:
|
||||||
|
decay_params.append(x)
|
||||||
|
|
||||||
|
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
|
||||||
|
|
||||||
|
|
||||||
|
if cfg.is_dynamic_loss_scale:
|
||||||
|
cfg.loss_scale = 1
|
||||||
|
|
||||||
|
opt = Momentum(params=get_param_groups(net),
|
||||||
|
learning_rate=Tensor(lr),
|
||||||
|
momentum=cfg.momentum,
|
||||||
|
weight_decay=cfg.weight_decay,
|
||||||
|
loss_scale=cfg.loss_scale)
|
||||||
|
if not cfg.use_label_smooth:
|
||||||
|
cfg.label_smooth_factor = 0.0
|
||||||
|
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||||
|
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
|
||||||
|
|
||||||
|
if cfg.is_dynamic_loss_scale == 1:
|
||||||
|
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
|
||||||
|
else:
|
||||||
|
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
|
||||||
|
|
||||||
|
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
|
||||||
|
amp_level="O3", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
|
||||||
|
|
||||||
|
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 2, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||||
|
time_cb = TimeMonitor(data_size=batch_num)
|
||||||
|
ckpt_save_dir = "./ckpt/"
|
||||||
|
ckpoint_cb = ModelCheckpoint(prefix="train_fishnet99_imagenet", directory=ckpt_save_dir,
|
||||||
|
config=config_ck)
|
||||||
|
loss_cb = LossMonitor()
|
||||||
|
cbs = [time_cb, ckpoint_cb, loss_cb]
|
||||||
|
if device_num > 1 and device_id != 0:
|
||||||
|
cbs = [time_cb, loss_cb]
|
||||||
|
model.train(cfg.epoch_size, dataset, callbacks=cbs)
|
||||||
|
print("train success")
|
Loading…
Reference in New Issue