!17633 [Online contribution] Golden stage: FishNet99 network accuracy and performance tuning PR + GPU network model collection campaign
Merge pull request !17633 from huicui/FishNet99_bold
commit
a656f40f41
|
@ -0,0 +1,313 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [FishNet99 Description](#fishnet99-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
    - [Export Process](#export-process)
        - [Export](#export)
    - [Inference Process](#inference-process)
        - [Inference](#inference)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
            - [FishNet99 on ImageNet-1k](#fishnet99-on-imagenet-1k)
        - [Inference Performance](#inference-performance)
            - [FishNet99 on ImageNet-1k](#fishnet99-on-imagenet-1k)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

# FishNet99 Description

FishNet is the first backbone network designed to unify pixel-level, region-level, and image-level tasks. It propagates gradients directly from very deep layers to shallower ones, and it preserves features of different depths and lets them refine one another.

[Paper](http://papers.nips.cc/paper/7356-fishnet-a-versatile-backbone-for-image-region-and-pixel-level-prediction.pdf): FishNet: a versatile backbone for image, region, and pixel level prediction. In Proceedings of the 32nd International Conference on Neural Information Processing Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 762–772.

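# Model Architecture

The network consists of three parts: tail, body, and head. The tail is an existing CNN such as ResNet, whose feature resolution gradually decreases as the network gets deeper. The body contains several upsampling and refinement blocks and mainly refines the features coming from the tail and the body. The head contains several downsampling and refinement blocks and preserves and refines the features coming from the tail and the body. The refined features of the last convolutional layer are used for the final task, such as image classification.
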
# Dataset

Dataset used: [ImageNet2012](http://www.image-net.org/)

- Dataset size: 125 GB, 1.25 million color images in 1,000 classes
    - Training set: 120 GB, 1,281,167 images
    - Test set: 5 GB, 50,000 images
- Data format: RGB
    - Note: the data is processed in src/dataset.py
- Download the dataset. The directory structure is as follows:

```text
└─dataset
    ├─ILSVRC2012_train   # training dataset
    └─ILSVRC2012_val     # evaluation dataset
```

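# Features

## Mixed Precision

Training uses [mixed precision](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/enable_mixed_precision.html), which combines single-precision and half-precision data to speed up the training of deep neural networks while keeping the accuracy that single-precision training achieves. Besides increasing computation speed and reducing memory usage, mixed precision training makes it possible to train larger models or use larger batch sizes on specific hardware.
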
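To make the precision trade-off concrete, here is a minimal sketch in plain Python (standard library only, not MindSpore code) that rounds values to IEEE 754 half precision with `struct`. The helper name `to_half`, the sample values, and the scale factor of 1024 are assumptions chosen for illustration. Loss scaling is a common companion of mixed precision training in general; this README does not state which settings FishNet99 uses.

```python
import struct


def to_half(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Half precision keeps roughly three significant decimal digits.
weight = 0.1234567
print(to_half(weight))

# A tiny gradient underflows to zero when stored in half precision.
tiny_grad = 1e-8
print(to_half(tiny_grad))  # 0.0

# Scaling before the cast and unscaling afterwards (in higher precision)
# preserves the value up to half-precision rounding error.
scale = 1024.0  # illustrative scale factor, not taken from this repository
recovered = to_half(tiny_grad * scale) / scale
print(recovered)
```

The last value is close to, but not exactly, `1e-8`: the scaled number is still rounded to the half-precision grid before it is unscaled.
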
# Environment Requirements

- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)

# Quick Start

After installing MindSpore from the official website, you can train and evaluate the model as follows:

- Running on Ascend

```bash
# run the training example
python train.py --device_id=0 --device_type='Ascend' > train.log 2>&1 &

# run the distributed training example
bash ./scripts/run_train_ascend.sh [RANK_TABLE_FILE]

# run the evaluation example
python eval.py --checkpoint_path ./ckpt_0 --device_type='Ascend' > ./eval.log 2>&1 &

# run the inference example
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```

Distributed training requires an hccl configuration file in JSON format, which must be created in advance.

Follow the instructions at the link below:

<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>

- Running on GPU

To run on GPU, change `device_target` in the configuration file src/config.py from Ascend to GPU.

```bash
# run the training example
python train.py --device_id=0 --device_type='GPU' > train_gpu.log 2>&1 &

# run the distributed training example
bash ./scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7

# run the evaluation example
python eval.py --checkpoint_path ./ckpt_0 --device_type='GPU' > ./eval_gpu.log 2>&1 &
```

# Script Description

## Scripts and Sample Code

```bash
├── model_zoo
    ├── README.md                    // descriptions of all models
    ├── fishnet99
        ├── README_CN.md             // description of FishNet99
        ├── ascend310_infer          // source code for Ascend 310 inference
        ├── scripts
        │   ├──run_eval_gpu.sh       // shell script for evaluation on GPU
        │   ├──run_infer_310.sh      // shell script for inference on Ascend
        │   ├──run_train_ascend.sh   // shell script for distributed training on Ascend
        │   ├──run_train_gpu.sh      // shell script for distributed training on GPU
        ├── src
        │   ├──config.py             // parameter configuration
        │   ├──dataset.py            // dataset creation
        │   ├──fishnet.py            // FishNet99 architecture
        ├── eval.py                  // evaluation script
        ├── export.py                // exports a checkpoint file to air/mindir
        ├── postprocess.py           // post-processing script for Ascend 310 inference
        ├── train.py                 // training script
```

## Script Parameters

Both training and evaluation parameters can be configured in config.py.

- Configuration for FishNet99 and the ImageNet-1k dataset.

```python
'name':'imagenet'        # dataset
'pre_trained':'False'    # whether to train from a pre-trained model
'num_classes':1000       # number of classes in the dataset
'lr_init':0.05           # initial learning rate: 0.05 for single-device Ascend training, 0.4 for 8-device parallel Ascend training, 0.05 for single-GPU training, 0.1 for 2-GPU parallel training
'batch_size':128         # training batch size
'epoch_size':160         # total number of training epochs; set to 110 for 2-GPU parallel training
'T_max':150              # learning-rate decay parameter; set to 100 for 2-GPU parallel training
'momentum':0.9           # momentum
'weight_decay':1e-4      # weight decay
'image_height':224       # height of the images fed to the model
'image_width':224        # width of the images fed to the model
'data_path':'/data/ILSVRC2012_train/'  # absolute path of the training dataset
'val_data_path':'/data/ILSVRC2012_val/'  # absolute path of the evaluation dataset
'device_target':'Ascend' # device to run on
'device_id':0            # ID of the device used for training or evaluation; can be ignored when run_train.sh is used for distributed training
'keep_checkpoint_max':25 # keep at most 25 ckpt model files
'checkpoint_path':'./ckpt/train_fishnet99_imagenet-146_10009.ckpt'  # absolute path where the checkpoint file is saved
```

For more configuration details, see the script `config.py`.

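How `lr_init`, `T_max`, and the device count relate can be sketched in a few lines. This is only an illustration under stated assumptions: it assumes that `T_max` is the period of a cosine-annealing schedule with a minimum learning rate of 0 (the README only calls it a decay-related parameter), and that the per-device values listed above follow linear scaling from the single-device rate. The function names are hypothetical and do not come from `src/config.py` or `train.py`.

```python
import math


def scaled_lr(base_lr: float, device_num: int) -> float:
    """Linear scaling: base learning rate multiplied by the device count."""
    return base_lr * device_num


def cosine_annealing_lr(lr_init: float, t_max: int, epoch: int,
                        eta_min: float = 0.0) -> float:
    """Cosine annealing from lr_init to eta_min over t_max epochs (assumed schedule)."""
    epoch = min(epoch, t_max)  # stay at the minimum once t_max is reached
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (lr_init - eta_min) * (1.0 + cosine)


# Values taken from the configuration above.
print(scaled_lr(0.05, 8))  # 8 Ascend devices
print(scaled_lr(0.05, 2))  # 2 GPUs
for epoch in (0, 75, 150):
    print(epoch, cosine_annealing_lr(0.05, 150, epoch))
```

Under these assumptions the scaled rates reproduce the 0.4 and 0.1 listed for the 8-device and 2-GPU runs, and the rate reaches its minimum at epoch `T_max`, a few epochs before `epoch_size`.
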
## Training Process

### Training

- Running on Ascend

```bash
python train.py --device_id=0 --device_type='Ascend' > train.log 2>&1 &
```

The python command above runs in the background. You can view the results in the generated train.log file.

After training, the checkpoint files can be found in the default script folder. The loss values can be obtained as follows:

```bash
# grep "loss is " train.log
...
epoch: 8 step: 10009, loss is 3.0276418
epoch: 9 step: 10009, loss is 3.0397775
...
```

- Running on GPU

To run on GPU, change `device_target` in the configuration file src/config.py from Ascend to GPU.

```bash
python train.py --device_id=0 --device_type='GPU' > train_gpu.log 2>&1 &
```

The python command above runs in the background. You can view the results in the generated train_gpu.log file.

After training, the checkpoint files can be found in the default ./ckpt_0/ script folder.

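If the loss values are needed as numbers rather than as `grep` output, the log lines can be parsed with a short script. The sketch below assumes that every relevant line has the `epoch: N step: M, loss is X` shape shown for the Ascend run; the function name `parse_losses` and the sample list are illustrative and not part of the repository.

```python
import re

# Matches lines such as "epoch: 8 step: 10009, loss is 3.0276418".
LOSS_LINE = re.compile(
    r"epoch:\s*(\d+)\s+step:\s*(\d+),\s*loss is\s*([0-9.eE+-]+)"
)


def parse_losses(lines):
    """Return (epoch, step, loss) tuples for every matching log line."""
    records = []
    for line in lines:
        match = LOSS_LINE.search(line)
        if match:
            epoch, step, loss = match.groups()
            records.append((int(epoch), int(step), float(loss)))
    return records


sample_log = [
    "epoch: 8 step: 10009, loss is 3.0276418",
    "epoch: 9 step: 10009, loss is 3.0397775",
    "a line without a loss value",
]
records = parse_losses(sample_log)
print(records)
```

To process a real log, pass an open file instead of the sample list, for example `parse_losses(open("train.log"))`.
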
### Distributed Training

- Running on Ascend

```bash
bash ./scripts/run_train_ascend.sh [RANK_TABLE_FILE]
```

The shell script above runs distributed training in the background.

- Running on GPU

To run on GPU, change `device_target` in the configuration file src/config.py from Ascend to GPU.

```bash
bash ./scripts/run_train_gpu.sh 2 0,1
```

The shell script above runs distributed training in the background. You can view the results in the generated train folder.

## Evaluation Process

### Evaluation

- Evaluating on the ImageNet-1k dataset on Ascend

"./ckpt_0" is the directory that holds the trained .ckpt model files.

```bash
python eval.py --checkpoint_path ./ckpt_0 --device_type='Ascend' > ./eval.log 2>&1 &
```

- Evaluating on the ImageNet-1k dataset on GPU

"./ckpt_0" is the directory that holds the trained .ckpt model files.

```bash
python eval.py --checkpoint_path ./ckpt_0 --device_type='GPU' > ./eval_gpu.log 2>&1 &
OR
bash ./scripts/run_eval.sh
```

## Export Process

### Export

Export the checkpoint file to a model in mindir format.

```shell
python export.py --ckpt_file [CKPT_FILE]
```

## Inference Process

### Inference

The model must be exported before running inference. A mindir model can be exported in any environment, whereas an air model can only be exported in an Ascend 910 environment. The following shows an example of inference with a mindir model.

- Inference on Ascend 310 with the ImageNet-1k dataset

The inference results are saved in the scripts directory. Results similar to the following can be found in the acc.log file.

```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
Total data: 50000, top1 accuracy: 0.78242, top5 accuracy: 0.94042.
```

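The top-1 and top-5 figures reported in `acc.log` are the fraction of images whose true label is among the 1 or 5 highest-scoring classes. A minimal sketch of that metric in plain Python follows. It is not the repository's `postprocess.py`, whose body is not shown here, and the function name and toy scores are made up for illustration.

```python
def topk_accuracy(scores, labels, k):
    """Fraction of samples whose label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores in this row.
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label in top:
            hits += 1
    return hits / len(labels)


# Three toy samples over four classes (made-up numbers).
scores = [
    [0.1, 0.6, 0.2, 0.1],
    [0.5, 0.1, 0.3, 0.1],
    [0.3, 0.1, 0.1, 0.5],
]
labels = [1, 2, 0]
top1 = topk_accuracy(scores, labels, 1)
top2 = topk_accuracy(scores, labels, 2)
print(top1, top2)
```

In the toy data only the first sample is right at top-1, while all three labels appear within the top two scores, which is why top-k accuracy never decreases as `k` grows.
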
# Model Description

## Performance

### Evaluation Performance

#### FishNet99 on ImageNet-1k

| Parameter                  | Ascend                                                      | GPU                        |
| -------------------------- | ----------------------------------------------------------- | -------------------------- |
| Model version              | FishNet99                                                   | FishNet99                  |
| Resource                   | Ascend 910                                                  | Tesla V100-32G             |
| Upload date                | 2021-06-02                                                  | 2021-07-24                 |
| MindSpore version          | 1.2.0                                                       | 1.2.0                      |
| Dataset                    | ImageNet2012                                                | ImageNet2012               |
| Training parameters        | epoch=160, batch_size=128, lr_init=0.05 (0.05 on 1 device, 0.4 on 8 devices) | epoch=160 (160 on 1 GPU, 110 on 2 GPUs), T_max=150 (150 on 1 GPU, 100 on 2 GPUs), batch_size=128, lr_init=0.05 (0.05 on 1 GPU, 0.1 on 2 GPUs) |
| Optimizer                  | Momentum                                                    | Momentum                   |
| Loss function              | Softmax cross entropy                                       | Softmax cross entropy      |
| Output                     | Probability                                                 | Probability                |
| Classification accuracy    | 1 device: top1 78.24%, top5 94.03%; 8 devices: top1 78.33%, top5 93.96% | 1 GPU: top1 78.12%, top5 94.13%; 2 GPUs: top1 77.97%, top5 93.98% |
| Speed                      | 1 device: 132 ms/step; 8 devices: 135 ms/step               | 1 GPU: 227 ms/step; 2 GPUs: 450 ms/step |
| Total time                 | 1 device: 58.5 hours/160 epochs; 8 devices: 7.7 hours/160 epochs | 1 GPU: 109.6 hours/160 epochs; 2 GPUs: 69.1 hours/110 epochs |

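As a rough cross-check of the speed and total-time rows, the duration of a run can be estimated as steps per epoch × step time × epochs. The sketch below assumes that `batch_size=128` is the per-device batch size and that the incomplete last batch is dropped, which gives the 10009 steps per epoch seen in the single-device training log. It ignores evaluation, checkpointing, and data-loading overhead, so it only approximates the reported totals. The function name is illustrative.

```python
TRAIN_IMAGES = 1_281_167  # ImageNet2012 training images
BATCH_SIZE = 128          # assumed to be the per-device batch size


def estimated_hours(devices: int, ms_per_step: float, epochs: int) -> float:
    """Estimate wall-clock training hours from the per-step time."""
    # Floor division drops the incomplete last batch (assumption).
    steps_per_epoch = TRAIN_IMAGES // (BATCH_SIZE * devices)
    seconds = steps_per_epoch * ms_per_step / 1000.0 * epochs
    return seconds / 3600.0


runs = {
    "Ascend, 1 device": (1, 132, 160),
    "Ascend, 8 devices": (8, 135, 160),
    "GPU, 1 device": (1, 227, 160),
    "GPU, 2 devices": (2, 450, 110),
}
for name, (devices, ms_per_step, epochs) in runs.items():
    hours = estimated_hours(devices, ms_per_step, epochs)
    print(f"{name}: about {hours:.1f} hours")
```

Under these assumptions each estimate lands within roughly 10% of the total time reported in the table.
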
### Inference Performance

#### FishNet99 on ImageNet-1k

| Parameter                  | Ascend                                                      |
| -------------------------- | ----------------------------------------------------------- |
| Model version              | FishNet99                                                   |
| Resource                   | Ascend 310                                                  |
| Upload date                | 2021-06-16                                                  |
| MindSpore version          | 1.2.0                                                       |
| Dataset                    | ImageNet2012                                                |
| Classification accuracy    | top1: 78.24%, top5: 94.04%                                  |
| Speed                      | Average time 5.17187 ms of infer_count 50000                |

# ModelZoo Homepage

Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,35 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_

#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <string_view>
#include <memory>
#include "include/api/types.h"

// Lists the regular files directly under dirName, sorted by path.
std::vector<std::string> GetAllFiles(std::string_view dirName);
// Opens dirName after resolving it; returns nullptr on failure.
DIR *OpenDir(std::string_view dirName);
// Resolves path to an absolute path; returns "" if it does not exist.
std::string RealPath(std::string_view path);
// Reads a whole file into a one-dimensional uint8 tensor.
mindspore::MSTensor ReadFileToTensor(const std::string &file);
// Writes every output tensor to ./result_Files/<image name>_<index>.bin.
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
// Lists the entries under dir_name whose names are longer than three characters.
std::vector<std::string> GetAllFiles(std::string dir_name);
// Lists the files of every sub-directory of dir_name, one vector per sub-directory.
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name);

#endif
@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(MindSporeCxxTestcase CXX)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
# MINDSPORE_PATH is a path, not a boolean switch, so declare it as a cache variable.
set(MINDSPORE_PATH "" CACHE PATH "mindspore install path")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT}/../)
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)

add_executable(main main.cc utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Locate the installed mindspore-ascend package and pass its path to CMake.
cmake . -DMINDSPORE_PATH="$(pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath)"
make
@ -0,0 +1,146 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <map>
#include <memory>
#include <vector>
#include <fstream>
#include <sstream>

#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "minddata/dataset/include/vision_ascend.h"
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/vision.h"
#include "inc/utils.h"

using mindspore::dataset::vision::Decode;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::CenterCrop;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::TensorTransform;
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;


DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_int32(device_id, 0, "device id");

int main(int argc, char **argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (RealPath(FLAGS_mindir_path).empty()) {
    std::cout << "Invalid mindir" << std::endl;
    return 1;
  }

  // Configure the Ascend 310 device and load the MindIR graph.
  auto context = std::make_shared<Context>();
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(FLAGS_device_id);
  context->MutableDeviceInfo().push_back(ascend310);
  mindspore::Graph graph;
  Status ret = Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
  if (ret != kSuccess) {
    std::cout << "ERROR: Load failed." << std::endl;
    return 1;
  }
  Model model;
  ret = model.Build(GraphCell(graph), context);
  if (ret != kSuccess) {
    std::cout << "ERROR: Build failed." << std::endl;
    return 1;
  }

  auto all_files = GetAllInputData(FLAGS_dataset_path);
  if (all_files.empty()) {
    std::cout << "ERROR: no input data." << std::endl;
    return 1;
  }

  // Maps the start time of each prediction to its end time, in milliseconds.
  std::map<double, double> costTime_map;
  size_t size = all_files.size();
  // Define transform
  std::vector<int32_t> crop_paras = {224};
  std::vector<int32_t> resize_paras = {256};
  std::vector<float> mean = {0.485 * 255, 0.456 * 255, 0.406 * 255};
  std::vector<float> std = {0.229 * 255, 0.224 * 255, 0.225 * 255};

  auto decode = Decode();
  auto resize = Resize(resize_paras);
  auto centercrop = CenterCrop(crop_paras);
  auto normalize = Normalize(mean, std);
  auto hwc2chw = HWC2CHW();

  mindspore::dataset::Execute SingleOp({decode, resize, centercrop, normalize, hwc2chw});

  for (size_t i = 0; i < size; ++i) {
    for (size_t j = 0; j < all_files[i].size(); ++j) {
      struct timeval start = {0};
      struct timeval end = {0};
      double startTimeMs;
      double endTimeMs;
      std::vector<MSTensor> inputs;
      std::vector<MSTensor> outputs;
      std::cout << "Start predict input files:" << all_files[i][j] << std::endl;
      auto imgDvpp = std::make_shared<MSTensor>();
      SingleOp(ReadFileToTensor(all_files[i][j]), imgDvpp.get());

      inputs.emplace_back(imgDvpp->Name(), imgDvpp->DataType(), imgDvpp->Shape(),
                          imgDvpp->Data().get(), imgDvpp->DataSize());
      // Only the Predict call is timed; preprocessing is excluded.
      gettimeofday(&start, nullptr);
      ret = model.Predict(inputs, &outputs);
      gettimeofday(&end, nullptr);
      if (ret != kSuccess) {
        std::cout << "Predict " << all_files[i][j] << " failed." << std::endl;
        return 1;
      }
      startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
      endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
      costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
      WriteResult(all_files[i][j], outputs);
    }
  }
  double average = 0.0;
  int inferCount = 0;

  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
    double diff = 0.0;
    diff = iter->second - iter->first;
    average += diff;
    inferCount++;
  }
  // Avoid dividing by zero when every sub-directory was empty.
  if (inferCount > 0) {
    average = average / inferCount;
  }
  std::stringstream timeCost;
  timeCost << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::cout << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
  std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
  std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
  fileStream << timeCost.str();
  fileStream.close();
  costTime_map.clear();
  return 0;
}
@ -0,0 +1,185 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <climits>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <algorithm>
#include <iostream>
#include "inc/utils.h"

using mindspore::MSTensor;
using mindspore::DataType;


std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name) {
  std::vector<std::vector<std::string>> ret;

  DIR *dir = OpenDir(dir_name);
  if (dir == nullptr) {
    return {};
  }
  struct dirent *filename;
  /* read all the files in the dir ~ */
  std::vector<std::string> sub_dirs;
  while ((filename = readdir(dir)) != nullptr) {
    std::string d_name = std::string(filename->d_name);
    // get rid of "." and ".."
    if (d_name == "." || d_name == ".." || d_name.empty()) {
      continue;
    }
    std::string dir_path = RealPath(std::string(dir_name) + "/" + filename->d_name);
    struct stat s;
    // Skip entries that cannot be inspected or are not directories.
    if (lstat(dir_path.c_str(), &s) != 0 || !S_ISDIR(s.st_mode)) {
      continue;
    }

    sub_dirs.emplace_back(dir_path);
  }
  closedir(dir);
  std::sort(sub_dirs.begin(), sub_dirs.end());

  (void)std::transform(sub_dirs.begin(), sub_dirs.end(), std::back_inserter(ret),
                       [](const std::string &d) { return GetAllFiles(d); });

  return ret;
}


std::vector<std::string> GetAllFiles(std::string dir_name) {
  struct dirent *filename;
  DIR *dir = OpenDir(dir_name);
  if (dir == nullptr) {
    return {};
  }

  std::vector<std::string> res;
  while ((filename = readdir(dir)) != nullptr) {
    std::string d_name = std::string(filename->d_name);
    if (d_name == "." || d_name == ".." || d_name.size() <= 3) {
      continue;
    }
    res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
  }
  closedir(dir);
  std::sort(res.begin(), res.end());

  return res;
}


std::vector<std::string> GetAllFiles(std::string_view dirName) {
  struct dirent *filename;
  DIR *dir = OpenDir(dirName);
  if (dir == nullptr) {
    return {};
  }
  std::vector<std::string> res;
  while ((filename = readdir(dir)) != nullptr) {
    std::string dName = std::string(filename->d_name);
    if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
      continue;
    }
    res.emplace_back(std::string(dirName) + "/" + filename->d_name);
  }
  closedir(dir);
  std::sort(res.begin(), res.end());
  for (auto &f : res) {
    std::cout << "image file: " << f << std::endl;
  }
  return res;
}


int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
  std::string homePath = "./result_Files";
  for (size_t i = 0; i < outputs.size(); ++i) {
    size_t outputSize;
    std::shared_ptr<const void> netOutput;
    netOutput = outputs[i].Data();
    outputSize = outputs[i].DataSize();
    int pos = imageFile.rfind('/');
    std::string fileName(imageFile, pos + 1);
    // Replace the extension (or append, if there is none) with "_<index>.bin".
    size_t dot = fileName.find('.');
    if (dot == std::string::npos) {
      dot = fileName.size();
    }
    fileName.replace(dot, fileName.size() - dot, '_' + std::to_string(i) + ".bin");
    std::string outFileName = homePath + "/" + fileName;
    FILE *outputFile = fopen(outFileName.c_str(), "wb");
    if (outputFile == nullptr) {
      std::cout << "Can not open file " << outFileName << std::endl;
      return 1;
    }
    fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
    fclose(outputFile);
    outputFile = nullptr;
  }
  return 0;
}

mindspore::MSTensor ReadFileToTensor(const std::string &file) {
  if (file.empty()) {
    std::cout << "File path is empty" << std::endl;
    return mindspore::MSTensor();
  }

  std::ifstream ifs(file, std::ios::binary);
  if (!ifs.good()) {
    std::cout << "File: " << file << " does not exist" << std::endl;
    return mindspore::MSTensor();
  }

  if (!ifs.is_open()) {
    std::cout << "File: " << file << " open failed" << std::endl;
    return mindspore::MSTensor();
  }

  ifs.seekg(0, std::ios::end);
  size_t size = ifs.tellg();
  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);

  ifs.seekg(0, std::ios::beg);
  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
  ifs.close();

  return buffer;
}


DIR *OpenDir(std::string_view dirName) {
  if (dirName.empty()) {
    std::cout << "dirName is empty!" << std::endl;
    return nullptr;
  }
  std::string realPath = RealPath(dirName);
  struct stat s;
  if (lstat(realPath.c_str(), &s) != 0 || !S_ISDIR(s.st_mode)) {
    std::cout << "dirName is not a valid directory!" << std::endl;
    return nullptr;
  }
  DIR *dir;
  dir = opendir(realPath.c_str());
  if (dir == nullptr) {
    std::cout << "Can not open dir " << dirName << std::endl;
    return nullptr;
  }
  std::cout << "Successfully opened the dir " << dirName << std::endl;
  return dir;
}

std::string RealPath(std::string_view path) {
  char realPathMem[PATH_MAX] = {0};
  char *realPathRet = nullptr;
  // string_view is not guaranteed to be null-terminated, so copy it first.
  realPathRet = realpath(std::string(path).c_str(), realPathMem);
  if (realPathRet == nullptr) {
    std::cout << "File: " << path << " does not exist." << std::endl;
    return "";
  }

  std::string realPath(realPathMem);
  std::cout << path << " realpath is: " << realPath << std::endl;
  return realPath;
}
@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create_imagenet2012_label"""
import os
import json
import argparse

parser = argparse.ArgumentParser(description="imagenet2012 label")
parser.add_argument("--img_path", type=str, required=True, help="imagenet2012 file path.")
args = parser.parse_args()


def create_label(file_path):
    """
    create label
    """
    print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!")
    # Class directories, sorted by name, define the label indices.
    file_list = sorted(os.listdir(file_path))

    total = 0
    img_label = {}
    for i, file_dir in enumerate(file_list):
        files = os.listdir(os.path.join(file_path, file_dir))
        for f in files:
            img_label[f] = i
        total += len(files)

    with open("imagenet_label.json", "w+") as label:
        json.dump(img_label, label)

    print("[INFO] Completed! Total {} data.".format(total))


if __name__ == '__main__':
    create_label(args.img_path)
@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python eval.py
"""
import argparse
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P

from src.config import imagenet_cfg
from src.dataset import create_dataset_imagenet

import src.fishnet as net_ms

set_seed(1)

parser = argparse.ArgumentParser(description='fishnet99')
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'],
                    help='dataset name.')
parser.add_argument('--device_type', type=str, default=None, help='GPU or Ascend. (Default: None)')
parser.add_argument('--checkpoint_path', type=str, default='./ckpt_0', help='Checkpoint file path')
args_opt = parser.parse_args()


class CrossEntropySmooth(_Loss):
    """CrossEntropy"""
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = P.OneHot()
        self.sparse = sparse
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss_ = self.ce(logit, label)
        return loss_


if __name__ == '__main__':

    if args_opt.dataset_name == "imagenet":
        cfg = imagenet_cfg
        dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
        if not cfg.use_label_smooth:
            cfg.label_smooth_factor = 0.0
        loss = CrossEntropySmooth(sparse=True, reduction="mean",
                                  smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
        net = net_ms.fish99()
        model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

    else:
        raise ValueError("dataset is not supported.")

    # Prefer the device type given on the command line; fall back to the config value.
    if args_opt.device_type:
        device_target = args_opt.device_type
    else:
        device_target = cfg.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    context.set_context(device_id=cfg.device_id)

    # Evaluate every .ckpt file found in the checkpoint directory.
    file_list = os.listdir(args_opt.checkpoint_path)
    for filename in file_list:
        de_path = os.path.join(args_opt.checkpoint_path, filename)
        if de_path.endswith('.ckpt'):
            param_dict = load_checkpoint(de_path)
            load_param_into_net(net, param_dict)
            net.set_train(False)

            acc = model.eval(dataset)
            print(f"model {de_path}'s accuracy is {acc}")
@ -0,0 +1,49 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx or mindir model#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
import src.fishnet as net_ms
|
||||
|
||||
parser = argparse.ArgumentParser(description='FishNet99 export')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="FishNet99", help="output file name.")
|
||||
parser.add_argument('--width', type=int, default=224, help='input width')
|
||||
parser.add_argument('--height', type=int, default=224, help='input height')
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
net = net_ms.fish99()
|
||||
|
||||
assert args.ckpt_file is not None, "ckpt_file is None."
|
||||
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
|
||||
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
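export.py traces the network with an all-zero NCHW tensor whose shape comes from `--batch_size`, `--height` and `--width`. The sketch below recreates only the command-line flags and the shape computation with the standard library, so it runs without MindSpore; `build_export_parser` and `dummy_input_shape` are hypothetical helper names, and `fishnet99.ckpt` is a placeholder path:

```python
import argparse


def build_export_parser():
    """Recreate the export.py command-line flags shown above."""
    parser = argparse.ArgumentParser(description='FishNet99 export')
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--ckpt_file', type=str, required=True)
    parser.add_argument('--file_name', type=str, default='FishNet99')
    parser.add_argument('--width', type=int, default=224)
    parser.add_argument('--height', type=int, default=224)
    parser.add_argument('--file_format', type=str,
                        choices=['AIR', 'ONNX', 'MINDIR'], default='MINDIR')
    return parser


def dummy_input_shape(args):
    """Shape of the all-zero tensor the network is traced with (NCHW)."""
    return (args.batch_size, 3, args.height, args.width)


if __name__ == '__main__':
    # 'fishnet99.ckpt' is a placeholder, not a file shipped with the repo.
    args = build_export_parser().parse_args(['--ckpt_file', 'fishnet99.ckpt'])
    print(args.file_format, dummy_input_shape(args))
```

Because the exported graph is traced with this fixed shape, the batch size chosen at export time is presumably the one the inference side has to feed.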
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""post process for 310 inference"""
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import numpy as np
|
||||
from src.config import imagenet_cfg
|
||||
|
||||
batch_size = 1
|
||||
parser = argparse.ArgumentParser(description="FishNet99 inference")
|
||||
parser.add_argument("--result_path", type=str, required=True, help="result files path.")
|
||||
parser.add_argument("--label_path", type=str, required=True, help="label file path.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def get_result(result_path, label_path):
|
||||
"""
|
||||
get result.
|
||||
"""
|
||||
files = os.listdir(result_path)
|
||||
with open(label_path, "r") as label:
|
||||
labels = json.load(label)
|
||||
|
||||
top1 = 0
|
||||
top5 = 0
|
||||
total_data = len(files)
|
||||
for file in files:
|
||||
img_ids_name = file.split('_0.')[0]
|
||||
data_path = os.path.join(result_path, img_ids_name + "_0.bin")
|
||||
result = np.fromfile(data_path, dtype=np.float32).reshape(batch_size, imagenet_cfg.num_classes)
|
||||
for batch in range(batch_size):
|
||||
predict = np.argsort(-result[batch], axis=-1)
|
||||
if labels[img_ids_name+".JPEG"] == predict[0]:
|
||||
top1 += 1
|
||||
if labels[img_ids_name+".JPEG"] in predict[:5]:
|
||||
top5 += 1
|
||||
print(f"Total data: {total_data}, top1 accuracy: {top1/total_data}, top5 accuracy: {top5/total_data}.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
get_result(args.result_path, args.label_path)
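`get_result` ranks the class scores of each image and counts a top-1 hit when the label is ranked first and a top-5 hit when it is among the first five. A minimal pure-Python sketch of that counting logic on made-up scores; the function names are illustrative, and `sorted` stands in for the `np.argsort(-result)` call used in postprocess.py:

```python
def topk_hits(scores, label, k=5):
    """Return (top1_hit, topk_hit) for one sample.

    scores: list of per-class scores; label: ground-truth class index.
    """
    # Class indices ordered from highest to lowest score.
    ranked = sorted(range(len(scores)), key=lambda idx: -scores[idx])
    return ranked[0] == label, label in ranked[:k]


def accuracy(all_scores, all_labels, k=5):
    """Accumulate top-1 and top-k accuracy over a list of samples."""
    top1 = topk = 0
    for scores, label in zip(all_scores, all_labels):
        hit1, hitk = topk_hits(scores, label, k)
        top1 += hit1
        topk += hitk
    total = len(all_labels)
    return top1 / total, topk / total


if __name__ == '__main__':
    # Two made-up samples over six classes.
    scores = [[0.1, 0.7, 0.05, 0.05, 0.05, 0.05],
              [0.30, 0.25, 0.20, 0.15, 0.08, 0.02]]
    labels = [1, 5]
    print(accuracy(scores, labels, k=5))
```

In the second sample the true class has the lowest of six scores, so it misses both the top-1 and the top-5 count.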
|
|
@ -0,0 +1,29 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
then
|
||||
echo "Usage: bash run_eval.sh [CKPT_LOCATION]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d "$1" ]
|
||||
then
|
||||
echo "error: CKPT_LOCATION=$1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
python eval.py --checkpoint_path=$1 --device_type='GPU' > ./eval.log 2>&1 &
|
|
@ -0,0 +1,99 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [[ $# -lt 2 || $# -gt 3 ]]; then
|
||||
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
|
||||
DEVICE_ID is optional; it defaults to 0 when omitted"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
model=$(get_real_path $1)
|
||||
data_path=$(get_real_path $2)
|
||||
|
||||
device_id=0
|
||||
if [ $# == 3 ]; then
|
||||
device_id=$3
|
||||
fi
|
||||
|
||||
echo "mindir name: "$model
|
||||
echo "dataset path: "$data_path
|
||||
echo "device id: "$device_id
|
||||
|
||||
export ASCEND_HOME=/usr/local/Ascend/
|
||||
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
|
||||
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
|
||||
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
|
||||
else
|
||||
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
|
||||
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
|
||||
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
|
||||
fi
|
||||
|
||||
function compile_app()
|
||||
{
|
||||
cd ../ascend310_infer/src/ || exit
|
||||
if [ -f "Makefile" ]; then
|
||||
make clean
|
||||
fi
|
||||
sh build.sh &> build.log
|
||||
}
|
||||
|
||||
function infer()
|
||||
{
|
||||
cd - || exit
|
||||
if [ -d result_Files ]; then
|
||||
rm -rf ./result_Files
|
||||
fi
|
||||
if [ -d time_Result ]; then
|
||||
rm -rf ./time_Result
|
||||
fi
|
||||
mkdir result_Files
|
||||
mkdir time_Result
|
||||
../ascend310_infer/src/main --mindir_path=$model --dataset_path=$data_path --device_id=$device_id &> infer.log
|
||||
}
|
||||
|
||||
function cal_acc()
|
||||
{
|
||||
python3.7 ../create_imagenet2012_label.py --img_path=$data_path
|
||||
python3.7 ../postprocess.py --result_path=./result_Files --label_path=./imagenet_label.json &> acc.log
|
||||
}
|
||||
|
||||
compile_app
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "compile app code failed"
|
||||
exit 1
|
||||
fi
|
||||
infer
|
||||
if [ $? -ne 0 ]; then
|
||||
echo " execute inference failed"
|
||||
exit 1
|
||||
fi
|
||||
cal_acc
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "calculate accuracy failed"
|
||||
exit 1
|
||||
fi
|
|
@ -0,0 +1,51 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
then
|
||||
echo "Usage: bash run_train.sh [RANK_TABLE_FILE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
RANK_TABLE_FILE=$(realpath $1)
|
||||
export RANK_TABLE_FILE
|
||||
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
|
||||
|
||||
rank_start=0
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ./train.py ./train_parallel$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
cd ./train_parallel$i ||exit
|
||||
env > env.log
|
||||
python train.py --device_id=$i --device_type='Ascend' > log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,52 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -lt 2 ]
|
||||
then
|
||||
echo "Usage: \
|
||||
bash run_train_gpu.sh [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)]\
|
||||
"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $1 -lt 1 ] || [ $1 -gt 8 ]
|
||||
then
|
||||
echo "error: DEVICE_NUM=$1 is not in (1-8)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=$1
|
||||
export RANK_SIZE=$1
|
||||
|
||||
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
|
||||
if [ -d "./train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cd ./train || exit
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="$2"
|
||||
|
||||
|
||||
if [ $1 -gt 1 ]
|
||||
then
|
||||
mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
|
||||
python3 ${BASEPATH}/../train.py --device_type='GPU' > train_gpu.log 2>&1 &
|
||||
else
|
||||
python3 ${BASEPATH}/../train.py --device_type='GPU' > train_gpu.log 2>&1 &
|
||||
fi
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""from googlenet"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
imagenet_cfg = edict({
|
||||
'name': 'imagenet',
|
||||
'pre_trained': False,
|
||||
'num_classes': 1000,
|
||||
'lr_init': 0.05, # Ascend_1P: 0.05, Ascend_8P: 0.4, GPU_1P: 0.05, GPU_2P: 0.1
|
||||
'batch_size': 128,
|
||||
'epoch_size': 160, # GPU_2P: 110
|
||||
'momentum': 0.9,
|
||||
'weight_decay': 1e-4,
|
||||
'image_height': 224,
|
||||
'image_width': 224,
|
||||
'data_path': '/data/ILSVRC2012_train/',
|
||||
'val_data_path': '/data/ILSVRC2012_val/',
|
||||
'device_target': 'Ascend',
|
||||
'device_id': 0,
|
||||
'keep_checkpoint_max': 30,
|
||||
'checkpoint_path': None,
|
||||
'onnx_filename': 'fishnet99',
|
||||
'air_filename': 'fishnet99',
|
||||
|
||||
# optimizer and lr related
|
||||
'lr_scheduler': 'cosine_annealing',
|
||||
'lr_epochs': [30, 60, 90, 120],
|
||||
'lr_gamma': 0.3,
|
||||
'eta_min': 0.0,
|
||||
'T_max': 150, # GPU_2P: 100
|
||||
'warmup_epochs': 0,
|
||||
|
||||
# loss related
|
||||
'is_dynamic_loss_scale': 0,
|
||||
'loss_scale': 1024,
|
||||
'label_smooth_factor': 0.1,
|
||||
'use_label_smooth': True,
|
||||
})
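The config selects `cosine_annealing` with `lr_init`, `eta_min` and `T_max`. The scheduler code itself is not part of this section, so the sketch below uses the textbook per-epoch cosine-annealing formula as an assumption; the repository may update the rate per step, add warmup, or treat the epochs after `T_max` differently (the config trains for 160 epochs with `T_max` of 150). The function name `cosine_annealing_lr` is hypothetical:

```python
import math


def cosine_annealing_lr(epoch, lr_init=0.05, eta_min=0.0, t_max=150):
    """Textbook cosine annealing from lr_init (epoch 0) to eta_min (epoch t_max)."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + (lr_init - eta_min) * (1.0 + cosine) / 2.0


if __name__ == '__main__':
    # Defaults mirror the single-device values in imagenet_cfg.
    for epoch in (0, 50, 75, 100, 150):
        print(epoch, round(cosine_annealing_lr(epoch), 6))
```

With these defaults the rate starts at 0.05, passes through half of that at epoch 75, and reaches `eta_min` at epoch 150.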
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in train.py and eval.py
|
||||
"""
|
||||
import os
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.dataset.vision.c_transforms as vision
|
||||
from src.config import imagenet_cfg
|
||||
|
||||
|
||||
def create_dataset_imagenet(dataset_path, repeat_num=1, training=True, num_parallel_workers=16, shuffle=None):
|
||||
"""
|
||||
Create a train or eval ImageNet2012 dataset for FishNet99.
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
training(bool): whether dataset is used for train or eval. Default: True
|
||||
num_parallel_workers(int): number of workers used to read the dataset. Default: 16
|
||||
shuffle(bool): whether to shuffle the dataset. Default: None
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
|
||||
device_num, rank_id = _get_rank_info()
|
||||
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
|
||||
image_size = imagenet_cfg.image_height
|
||||
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
||||
|
||||
# define map operations
|
||||
if training:
|
||||
transform_img = [
|
||||
vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||
vision.RandomHorizontalFlip(prob=0.5),
|
||||
vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1),
|
||||
vision.Normalize(mean=mean, std=std),
|
||||
vision.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
transform_img = [
|
||||
vision.Decode(),
|
||||
vision.Resize(256),
|
||||
vision.CenterCrop(image_size),
|
||||
vision.Normalize(mean=mean, std=std),
|
||||
vision.HWC2CHW()
|
||||
]
|
||||
|
||||
transform_label = [C.TypeCast(mstype.int32)]
|
||||
|
||||
data_set = data_set.map(input_columns="image", num_parallel_workers=12, operations=transform_img)
|
||||
data_set = data_set.map(input_columns="label", num_parallel_workers=4, operations=transform_label)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True)
|
||||
|
||||
# apply dataset repeat operation
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
def _get_rank_info():
|
||||
"""
|
||||
get rank size and rank id
|
||||
"""
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
|
||||
if rank_size > 1:
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
rank_size = get_group_size()
|
||||
rank_id = get_rank()
|
||||
else:
|
||||
rank_size, rank_id = 1, 0
|
||||
|
||||
return rank_size, rank_id
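The evaluation pipeline above resizes to 256, center-crops to `image_size` (224) and normalizes with ImageNet mean and std scaled to the 0..255 pixel range. The arithmetic can be sketched in plain Python as follows; the helper names are illustrative, and the exact rounding of the crop offsets inside `vision.CenterCrop` is an assumption:

```python
MEAN = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STD = [0.229 * 255, 0.224 * 255, 0.225 * 255]


def normalize_pixel(rgb):
    """Apply per-channel (value - mean) / std to one RGB pixel in the 0..255 range."""
    return [(value - mean) / std for value, mean, std in zip(rgb, MEAN, STD)]


def center_crop_box(height, width, crop_size=224):
    """Return (top, left, bottom, right) of a centered crop_size x crop_size window."""
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return top, left, top + crop_size, left + crop_size


if __name__ == '__main__':
    # A pixel equal to the channel means maps to zero in every channel.
    print(normalize_pixel(MEAN))
    # Crop window for an image whose shorter side was already resized to 256.
    print(center_crop_box(256, 341))
```

Scaling the mean and std by 255 is what lets `Normalize` run directly on decoded 8-bit pixel values without first dividing the image by 255.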
|
|
@ -0,0 +1,599 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
FishNet model of MindSpore-1.2.0.
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
conv_weight_init = 'HeUniform'
|
||||
|
||||
|
||||
class adaptiveavgpool2d_ms(nn.Cell):
|
||||
"""adaptiveavgpool2d_ms"""
|
||||
def __init__(self):
|
||||
super(adaptiveavgpool2d_ms, self).__init__()
|
||||
self.ada_pool = P.ReduceMean(keep_dims=True)
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
x = self.ada_pool(x, (2, 3))
|
||||
return x
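`adaptiveavgpool2d_ms` stands in for adaptive average pooling to a 1 x 1 output by taking `ReduceMean` over axes 2 and 3 with `keep_dims=True`, so an (N, C, H, W) input becomes (N, C, 1, 1). A minimal sketch of the same reduction on nested Python lists, without MindSpore; the function names are illustrative only:

```python
def global_avg_pool(feature_map):
    """Average one channel's H x W values down to a single number."""
    total = sum(sum(row) for row in feature_map)
    count = sum(len(row) for row in feature_map)
    return total / count


def global_avg_pool_nchw(batch):
    """Mean over H and W for every sample and channel, keeping 1 x 1 spatial dims."""
    return [[[[global_avg_pool(channel)]] for channel in sample] for sample in batch]


if __name__ == '__main__':
    # One sample with two 2 x 2 channels.
    batch = [[[[1, 2], [3, 4]],
              [[0, 0], [0, 8]]]]
    print(global_avg_pool_nchw(batch))
```

Keeping the two singleton spatial dimensions lets the result feed straight into the 1 x 1 convolutions of the squeeze-and-excitation branch in `Bridge`.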
|
||||
|
||||
|
||||
def _bn_relu_conv(in_c, out_c, **conv_kwargs):
|
||||
return nn.SequentialCell([nn.BatchNorm2d(num_features=in_c, momentum=0.9),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=in_c, out_channels=out_c, pad_mode='pad',
|
||||
weight_init=conv_weight_init, **conv_kwargs),
|
||||
])
|
||||
|
||||
|
||||
class ResBlock_with_shortcut(nn.Cell):
|
||||
"""
|
||||
Construct Basic Bottle-necked Residual Block module.
|
||||
Args:
|
||||
in_c : Number of channels in the input image
|
||||
out_c : Number of channels in the output image
|
||||
shortcut : Specific function for skip-connection
|
||||
Examples)
|
||||
'bn_relu_conv' for DownRefinementBlock
|
||||
'bn_relu_conv with channel reduction' for UpRefinementBlock
|
||||
'identity mapping' for regular connection
|
||||
stride : Stride of middle conv layer
|
||||
dilation : Dilation rate of middle conv layer
|
||||
Forwarding Path:
|
||||
⎡ (shortcut) ⎤
|
||||
input image - (BN-ReLU-Conv) * 3 - (add) -output
|
||||
"""
|
||||
def __init__(self, in_c, out_c, shortcut, stride=1, dilation=1):
|
||||
super(ResBlock_with_shortcut, self).__init__()
|
||||
|
||||
mid_c = out_c // 4
|
||||
self.layers = nn.SequentialCell([
|
||||
_bn_relu_conv(in_c, mid_c, kernel_size=1, has_bias=False),
|
||||
_bn_relu_conv(mid_c, mid_c, kernel_size=3, stride=stride, padding=dilation, dilation=dilation,
|
||||
has_bias=False),
|
||||
_bn_relu_conv(mid_c, out_c, kernel_size=1, has_bias=False),
|
||||
])
|
||||
self.shortcut = shortcut
|
||||
self.add_1 = P.Add()
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
return self.add_1(self.layers(x), self.shortcut(x))
|
||||
|
||||
|
||||
class ResBlock_without_shortcut(nn.Cell):
|
||||
"""
|
||||
Construct Basic Bottle-necked Residual Block module.
|
||||
Args:
|
||||
in_c : Number of channels in the input image
|
||||
out_c : Number of channels in the output image
|
||||
stride : Stride of middle conv layer
|
||||
dilation : Dilation rate of middle conv layer
|
||||
Forwarding Path:
|
||||
⎡ (identity) ⎤
|
||||
input image - (BN-ReLU-Conv) * 3 - (add) -output
|
||||
"""
|
||||
def __init__(self, in_c, out_c, stride=1, dilation=1):
|
||||
super(ResBlock_without_shortcut, self).__init__()
|
||||
|
||||
mid_c = out_c // 4
|
||||
self.layers_ = nn.SequentialCell([
|
||||
_bn_relu_conv(in_c, mid_c, kernel_size=1, has_bias=False),
|
||||
_bn_relu_conv(mid_c, mid_c, kernel_size=3, stride=stride, padding=dilation, dilation=dilation,
|
||||
has_bias=False),
|
||||
_bn_relu_conv(mid_c, out_c, kernel_size=1, has_bias=False),
|
||||
])
|
||||
self.add_2 = P.Add()
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
return self.add_2(self.layers_(x), x)
|
||||
|
||||
|
||||
class TransferBlock(nn.Cell):
|
||||
"""
|
||||
Construct Transfer Block module.
|
||||
Args:
|
||||
ch : Number of channels in the input and output image
|
||||
num_blk : Number of Residual Blocks
|
||||
Forwarding Path:
|
||||
input image - (ResBlock_without_shortcut) * num_blk - output
|
||||
"""
|
||||
def __init__(self, ch, num_blk):
|
||||
super(TransferBlock, self).__init__()
|
||||
|
||||
self.layers_TransferBlock = nn.SequentialCell([*[ResBlock_without_shortcut(ch, ch)
|
||||
for _ in range(0, num_blk)]])
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
return self.layers_TransferBlock(x)
|
||||
|
||||
|
||||
class DownStage(nn.Cell):
|
||||
"""
|
||||
Construct a stage for each resolution.
|
||||
A DownStage consists of one DownRefinementBlock and several residual regular connection blocks.
|
||||
(Note: the original implementation does not use DownRefinementBlock in FishHead,
|
||||
although the original paper suggests that it is needed.
|
||||
This version follows the original implementation, not the paper.)
|
||||
Args:in_c : Number of channels in the input image
|
||||
out_c : Number of channels in the output image
|
||||
num_blk : Number of Residual Blocks
|
||||
stride : Stride of shortcut conv layer
|
||||
Forwarding Path:input image - (ResBlock with Shortcut) - (ResBlock) * num_blk - (MaxPool) - output
|
||||
"""
|
||||
def __init__(self, in_c, out_c, num_blk, stride=1):
|
||||
super(DownStage, self).__init__()
|
||||
|
||||
shortcut = _bn_relu_conv(in_c, out_c, kernel_size=1, stride=stride, has_bias=False)
|
||||
self.layer_DownStage = nn.SequentialCell([
|
||||
ResBlock_with_shortcut(in_c, out_c, shortcut),
|
||||
*[ResBlock_without_shortcut(out_c, out_c) for _ in range(1, num_blk)],
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
|
||||
])
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
return self.layer_DownStage(x)
|
||||
|
||||
|
||||
class UpStage(nn.Cell):
|
||||
"""
|
||||
Construct a stage for each resolution.
|
||||
An UpStage consists of one UpRefinementBlock and several residual regular connection blocks.
|
||||
Unlike DownStage, this module reduces the number of channels of concatenated feature maps in the shortcut path.
|
||||
Args:in_c : Number of channels in the input image
|
||||
out_c : Number of channels in the output image
|
||||
num_blk : Number of Residual Blocks
|
||||
dilation : Dilation rate of the middle conv layer in each residual block
|
||||
Forwarding Path:input image - (ResBlock with Channel Reduction) - (ResBlock) * num_blk - (UpSample) - output
|
||||
"""
|
||||
def __init__(self, in_c, out_c, num_blk, dilation=1):
|
||||
super(UpStage, self).__init__()
|
||||
|
||||
self.k = in_c // out_c
|
||||
self.redece_sum = P.ReduceSum(keep_dims=False)
|
||||
self.shape_ = P.Shape()
|
||||
self.reshape_ = P.Reshape()
|
||||
self.layer_UpStage = nn.SequentialCell([
|
||||
ResBlock_with_shortcut(in_c, out_c, channel_reduction_ms(in_c // out_c), dilation=dilation),
|
||||
*[ResBlock_without_shortcut(out_c, out_c, dilation=dilation) for _ in range(1, num_blk)],
|
||||
])
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
return self.layer_UpStage(x)
|
||||
|
||||
|
||||
class channel_reduction_ms(nn.Cell):
|
||||
"""channel_reduction_ms"""
|
||||
def __init__(self, kk):
|
||||
super(channel_reduction_ms, self).__init__()
|
||||
self.shape_ = P.Shape()
|
||||
self.kk = kk
|
||||
self.reshape_ = P.Reshape()
|
||||
self.redece_sum = P.ReduceSum(keep_dims=False)
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
n, c, h_, w_ = self.shape_(x)
|
||||
x = self.redece_sum(self.reshape_(x, (n, c // self.kk, self.kk, h_, w_)), 2)
|
||||
return x
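`channel_reduction_ms` reshapes (N, C, H, W) to (N, C // k, k, H, W) and sums over the axis of length k, which adds up every k consecutive channels. A minimal sketch of that grouping for the channel vector at a single spatial position, in plain Python; the function name and the divisibility check are illustrative additions, not part of the module:

```python
def channel_reduction(pixel_channels, k):
    """Sum every k consecutive channels of one spatial position.

    pixel_channels: list of c channel values at a single (h, w) location.
    Returns c // k values, matching reshape to (c // k, k) plus a sum over k.
    """
    c = len(pixel_channels)
    if c % k != 0:
        raise ValueError("number of channels must be divisible by k")
    return [sum(pixel_channels[i:i + k]) for i in range(0, c, k)]


if __name__ == '__main__':
    # Six channels reduced with k = 2: pairs (1, 2), (3, 4), (5, 6) are summed.
    print(channel_reduction([1, 2, 3, 4, 5, 6], k=2))
```

Because the reduction has no learnable parameters, the shortcut path of `UpStage` can shrink from `in_c` to `out_c` channels without adding weights.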
|
||||
|
||||
|
||||
class FishTail(nn.Cell):
|
||||
"""
|
||||
Construct FishTail module.
|
||||
Each instance corresponds to one stage.
|
||||
Args:
|
||||
in_c : Number of channels in the input image
|
||||
out_c : Number of channels in the output image
|
||||
num_blk : Number of Residual Blocks
|
||||
Forwarding Path:
|
||||
input image - (DownStage) - output
|
||||
"""
|
||||
def __init__(self, in_c, out_c, num_blk):
|
||||
super(FishTail, self).__init__()
|
||||
|
||||
self.layer_FishTail = DownStage(in_c, out_c, num_blk)
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
x = self.layer_FishTail(x)
|
||||
return x
|
||||
|
||||
|
||||
class Bridge(nn.Cell):
|
||||
"""
|
||||
Construct Bridge module.
|
||||
This module bridges the last FishTail stage and first FishBody stage.
|
||||
Args:ch : Number of channels in the input and output image
|
||||
num_blk : Number of Residual Blocks
|
||||
Forwarding Path:
|
||||
r (SEBlock) ㄱ
|
||||
input image - (stem) - (ResBlock with Shortcut) - (ResBlock) * num_blk - (mul & sum) - output
|
||||
"""
|
||||
def __init__(self, ch, num_blk):
|
||||
super(Bridge, self).__init__()
|
||||
|
||||
self.stem = nn.SequentialCell([
|
||||
nn.BatchNorm2d(num_features=ch, momentum=0.9),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=ch, out_channels=ch // 2, kernel_size=1, pad_mode='pad', has_bias=False,
|
||||
weight_init=conv_weight_init),
|
||||
nn.BatchNorm2d(num_features=ch // 2, momentum=0.9),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=ch // 2, out_channels=ch * 2, kernel_size=1, pad_mode='pad', has_bias=True,
|
||||
weight_init=conv_weight_init)
|
||||
])
|
||||
shortcut = _bn_relu_conv(ch * 2, ch, kernel_size=1, has_bias=False)
|
||||
self.layers_Bridge = nn.SequentialCell([
|
||||
ResBlock_with_shortcut(ch * 2, ch, shortcut),
|
||||
*[ResBlock_without_shortcut(ch, ch) for _ in range(1, num_blk)],
|
||||
])
|
||||
|
||||
self.se_block = nn.SequentialCell([
|
||||
nn.BatchNorm2d(num_features=ch * 2, momentum=0.9),
|
||||
nn.ReLU(),
|
||||
adaptiveavgpool2d_ms(),
|
||||
nn.Conv2d(in_channels=ch * 2, out_channels=ch // 16, kernel_size=1, pad_mode='pad', has_bias=True
|
||||
, weight_init=conv_weight_init),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=ch // 16, out_channels=ch, kernel_size=1, pad_mode='pad', has_bias=True
|
||||
, weight_init=conv_weight_init),
|
||||
nn.Sigmoid()
|
||||
])
|
||||
self.add_3 = P.Add()
|
||||
self.mul_ = P.Mul()
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
x = self.stem(x)
|
||||
att = self.se_block(x)
|
||||
out = self.layers_Bridge(x)
|
||||
return self.add_3(self.mul_(out, att), att)
|
||||
|
||||
|
||||
class FishBody_0(nn.Cell):
|
||||
"""Construct FishBody module.
|
||||
Each instance corresponds to one stage.
|
||||
Args:in_c : Number of channels in the input image
|
||||
out_c : Number of channels in the output image
|
||||
num_blk : Number of Residual Blocks
|
||||
trans_in_c : Number of channels in the transferred image
|
||||
num_trans : Number of Transfer Blocks
|
||||
dilation : Dilation rate of Conv in UpRefinementBlock
|
||||
Forwarding Path:
|
||||
input image - (UpStage) ㄱ
|
||||
trans image - (transfer) --(concat)-- output
|
||||
"""
|
||||
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans, dilation=1):
|
||||
super(FishBody_0, self).__init__()
|
||||
self.layer_FishBody = UpStage(in_c, out_c, num_blk, dilation=dilation)
|
||||
self.add_up = par_ms_0()
|
||||
self.transfer = TransferBlock(trans_in_c, num_trans)
|
||||
self.concat = P.Concat(1)
|
||||
|
||||
def construct(self, x, trans_x):
|
||||
"""construct"""
|
||||
x = self.layer_FishBody(x)
|
||||
x = self.add_up(x)
|
||||
trans_x = self.transfer(trans_x)
|
||||
return self.concat((x, trans_x))
|
||||
|
||||
|
||||
class par_ms_0(nn.Cell):
|
||||
"""par_ms_0"""
|
||||
def __init__(self):
|
||||
super(par_ms_0, self).__init__()
|
||||
self.P_up = P.ResizeNearestNeighbor((14, 14))
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
x = self.P_up(x)
|
||||
return x
|
||||
|
||||
|
||||
class FishBody_1(nn.Cell):
|
||||
"""Construct FishBody module.
|
||||
Each instance corresponds to one stage.
|
||||
Args:in_c : Number of channels in the input image
|
||||
out_c : Number of channels in the output image
|
||||
num_blk : Number of Residual Blocks
|
||||
trans_in_c : Number of channels in the transferred image
|
||||
num_trans : Number of Transfer Blocks
|
||||
dilation : Dilation rate of Conv in UpRefinementBlock
|
||||
Forwarding Path:
|
||||
input image - (UpStage) ㄱ
|
||||
trans image - (transfer) --(concat)-- output
|
||||
"""
|
||||
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans, dilation=1):
|
||||
super(FishBody_1, self).__init__()
|
||||
self.layer_FishBody = UpStage(in_c, out_c, num_blk, dilation=dilation)
|
||||
self.add_up = par_ms_1()
|
||||
self.transfer = TransferBlock(trans_in_c, num_trans)
|
||||
self.concat = P.Concat(1)
|
||||
|
||||
def construct(self, x, trans_x):
|
||||
"""construct"""
|
||||
x = self.layer_FishBody(x)
|
||||
x = self.add_up(x)
|
||||
trans_x = self.transfer(trans_x)
|
||||
return self.concat((x, trans_x))
|
||||
|
||||
|
||||
class par_ms_1(nn.Cell):
|
||||
"""par_ms_1"""
|
||||
def __init__(self):
|
||||
super(par_ms_1, self).__init__()
|
||||
self.P_up = P.ResizeNearestNeighbor((28, 28))
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
x = self.P_up(x)
|
||||
return x
|
||||
|
||||
|
||||
class FishBody_2(nn.Cell):
|
||||
"""Construct FishBody module.
|
||||
Each instance corresponds to one stage.
|
||||
Args:in_c : Number of channels in the input image
|
||||
out_c : Number of channels in the output image
|
||||
num_blk : Number of Residual Blocks
|
||||
trans_in_c : Number of channels in the transferred image
|
||||
num_trans : Number of Transfer Blocks
|
||||
dilation : Dilation rate of Conv in UpRefinementBlock
|
||||
Forwarding Path:
|
||||
input image - (UpStage) ㄱ
|
||||
trans image - (transfer) --(concat)-- output
|
||||
"""
|
||||
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans, dilation=1):
|
||||
super(FishBody_2, self).__init__()
|
||||
self.layer_FishBody = UpStage(in_c, out_c, num_blk, dilation=dilation)
|
||||
self.add_up = par_ms_2()
|
||||
self.transfer = TransferBlock(trans_in_c, num_trans)
|
||||
self.concat = P.Concat(1)
|
||||
|
||||
def construct(self, x, trans_x):
|
||||
"""construct"""
|
||||
x = self.layer_FishBody(x)
|
||||
x = self.add_up(x)
|
||||
trans_x = self.transfer(trans_x)
|
||||
return self.concat((x, trans_x))
|
||||
|
||||
|
||||
class par_ms_2(nn.Cell):
|
||||
"""Upsample the feature map to 56x56 with nearest-neighbor interpolation."""
|
||||
def __init__(self):
|
||||
super(par_ms_2, self).__init__()
|
||||
self.P_up = P.ResizeNearestNeighbor((56, 56))
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
x = self.P_up(x)
|
||||
return x
|
||||
|
||||
|
||||
class FishHead(nn.Cell):
|
||||
"""Construct FishHead module.
|
||||
Each instance corresponds to one stage.
|
||||
Differs from the official code: this module uses a shortcut layer (the shortcut layer follows the
|
||||
original paper)
|
||||
Args:in_c : Number of channels in the input image
|
||||
out_c : Number of channels in the output image
|
||||
num_blk : Number of Residual Blocks
|
||||
trans_in_c : Number of channels in the transferred image
|
||||
num_trans : Number of Transfer Blocks
|
||||
Forwarding Path:
|
||||
input image - (ResBlock) * num_blk - pool ㄱ
|
||||
trans image - (transfer) --(concat)-- output
|
||||
"""
|
||||
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans):
|
||||
super(FishHead, self).__init__()
|
||||
|
||||
self.layer_FishHead = nn.SequentialCell([
|
||||
ResBlock_without_shortcut(in_c, out_c),
|
||||
*[ResBlock_without_shortcut(out_c, out_c) for _ in range(1, num_blk)],
|
||||
nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
|
||||
])
|
||||
self.transfer = TransferBlock(trans_in_c, num_trans)
|
||||
self.concat_ = P.Concat(1)
|
||||
|
||||
def construct(self, x, trans_x):
|
||||
"""construct"""
|
||||
x = self.layer_FishHead(x)
|
||||
trans_x = self.transfer(trans_x)
|
||||
return self.concat_((x, trans_x))
|
||||
|
||||
|
||||
def _conv_bn_relu(in_ch, out_ch, stride=1):
|
||||
return nn.SequentialCell([nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=stride,
|
||||
pad_mode='pad', padding=1, has_bias=False, weight_init=conv_weight_init),
|
||||
nn.BatchNorm2d(num_features=out_ch, momentum=0.9),
|
||||
nn.ReLU()])
|
||||
|
||||
|
||||
class Fishnet(nn.Cell):
|
||||
"""
|
||||
Construct the entire network.
|
||||
Args:
|
||||
start_c : Number of channels fed into the first tail stage.
|
||||
Note that it is NOT the number of channels in initial input image, and it IS the number of output
|
||||
channel of stem.
|
||||
num_cls : Number of classes
|
||||
tail_num_blk : list of the number of Conv blocks in each FishTail stage
|
||||
body_num_blk : list of the number of Conv blocks in each FishBody stage
|
||||
head_num_blk : list of the number of Conv blocks in each FishHead stage
|
||||
(Note : `*_num_blk` includes the one Residual block at the start of each stage)
|
||||
body_num_trans : list of the number of Conv blocks in the transfer path of each FishBody stage
|
||||
head_num_trans : list of the number of Conv blocks in the transfer path of each FishHead stage
|
||||
tail_channels : list of channel counts at the tail stage boundaries (stage i maps tail_channels[i] to tail_channels[i + 1])
|
||||
body_channels : list of (in, out) channel pairs of the body stages
|
||||
head_channels : list of (in, out) channel pairs of the head stages
|
||||
"""
|
||||
def __init__(self, start_c=64, num_cls=1000,
|
||||
tail_num_blk=1, bridge_num_blk=2,
|
||||
body_num_blk=1, body_num_trans=1,
|
||||
head_num_blk=1, head_num_trans=1,
|
||||
tail_channels=1, body_channels=1, head_channels=1):
|
||||
super(Fishnet, self).__init__()
|
||||
|
||||
self.stem = nn.SequentialCell([
|
||||
_conv_bn_relu(3, start_c // 2, stride=2),
|
||||
_conv_bn_relu(start_c // 2, start_c // 2),
|
||||
_conv_bn_relu(start_c // 2, start_c),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
|
||||
])
|
||||
print("FishNet Initialization Start")
|
||||
|
||||
self.tail_layer = []
|
||||
for i, num_blk in enumerate(tail_num_blk):
|
||||
layer = FishTail(tail_channels[i], tail_channels[i + 1], num_blk)
|
||||
self.tail_layer.append(layer)
|
||||
self.tail_layer = nn.CellList(self.tail_layer)
|
||||
self.bridge = Bridge(tail_channels[-1], bridge_num_blk)
|
||||
|
||||
self.body_layer = []
|
||||
for i, (num_blk, num_trans) in enumerate(zip(body_num_blk, body_num_trans)):
|
||||
if i == 0:
|
||||
layer = FishBody_0(body_channels[i][0], body_channels[i][1], num_blk,
|
||||
tail_channels[-i - 2], num_trans, dilation=2 ** i)
|
||||
elif i == 1:
|
||||
layer = FishBody_1(body_channels[i][0], body_channels[i][1], num_blk,
|
||||
tail_channels[-i - 2], num_trans, dilation=2 ** i)
|
||||
else:
|
||||
layer = FishBody_2(body_channels[i][0], body_channels[i][1], num_blk,
|
||||
tail_channels[-i - 2], num_trans, dilation=2 ** i)
|
||||
self.body_layer.append(layer)
|
||||
self.body_layer = nn.CellList(self.body_layer)
|
||||
|
||||
self.head_layer = []
|
||||
for i, (num_blk, num_trans) in enumerate(zip(head_num_blk, head_num_trans)):
|
||||
layer = FishHead(head_channels[i][0], head_channels[i][1], num_blk,
|
||||
body_channels[-i - 1][0], num_trans)
|
||||
self.head_layer.append(layer)
|
||||
self.head_layer = nn.CellList(self.head_layer)
|
||||
|
||||
last_c = head_channels[-1][1]
|
||||
self.classifier = nn.SequentialCell([
|
||||
nn.BatchNorm2d(num_features=last_c, momentum=0.9),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_channels=last_c, out_channels=last_c // 2, kernel_size=1, pad_mode='pad', has_bias=False
|
||||
, weight_init=conv_weight_init),
|
||||
nn.BatchNorm2d(num_features=last_c // 2, momentum=0.9),
|
||||
nn.ReLU(),
|
||||
adaptiveavgpool2d_ms(),
|
||||
nn.Conv2d(in_channels=last_c // 2, out_channels=num_cls, kernel_size=1, pad_mode='pad', has_bias=True
|
||||
, weight_init=conv_weight_init)
|
||||
])
|
||||
self.squeeze_ = P.Squeeze(2)
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
stem = self.stem(x)
|
||||
tail_features = [stem]
|
||||
for t in self.tail_layer:
|
||||
last_feature = tail_features[-1]
|
||||
tail_features.append(t(last_feature))
|
||||
|
||||
bridge = self.bridge(tail_features[-1])
|
||||
|
||||
body_features = [bridge]
|
||||
for b, tail in zip(self.body_layer, [tail_features[2], tail_features[1], tail_features[0]]):
|
||||
last_feature = body_features[-1]
|
||||
body_features.append(b(last_feature, tail))
|
||||
|
||||
head_features = [body_features[-1]]
|
||||
for h, body in zip(self.head_layer, [body_features[2], body_features[1], body_features[0]]):
|
||||
last_feature = head_features[-1]
|
||||
head_features.append(h(last_feature, body))
|
||||
|
||||
out = self.classifier(head_features[-1])
|
||||
out = self.squeeze_(out)
|
||||
out = self.squeeze_(out)
|
||||
return out
|
||||
|
||||
|
||||
def _calc_channel(start_c, num_blk):
|
||||
"""
|
||||
Calculate the number of in and out channels of each stages in FishNet.
|
||||
Example:
|
||||
fish150 : start channel=64, num_blk=3,
|
||||
tail channels : Grow double in each stages,
|
||||
[64, 128, 256 ...] = [start channel * (2 ** i) for i in range(num_blk + 1)]
|
||||
body channels : In the first stage, in_channel and out_channel are the same;
|
||||
in the other stages, the number of output channels is half the number of input channels.
|
||||
The transfer channels are then added to the output channels to give the next stage's input.
|
||||
The transfer channels are the tail channels[:-1] in reverse order
|
||||
[(512, 512), + 256
|
||||
(768, 384), + 128
|
||||
(512, 256)] + 64
|
||||
head channels : The number of input channels and output channels is the same.
|
||||
The transfer channels are then added to the output channels to give the next stage's input.
|
||||
The transfer channels are the input channels of the body stages in reverse order
|
||||
[(320, 320), + 512
|
||||
(832, 832), + 768
|
||||
(1600, 1600)] + 512
|
||||
"""
|
||||
# tail channels
|
||||
tail_channels = [start_c]
|
||||
for i in range(num_blk):
|
||||
tail_channels.append(tail_channels[-1] * 2)
|
||||
print("Tail Channels : ", tail_channels)
|
||||
|
||||
# body channels
|
||||
in_c, transfer_c = tail_channels[-1], tail_channels[-2]
|
||||
body_channels = [(in_c, in_c), (in_c + transfer_c, (in_c + transfer_c) // 2)]
|
||||
for i in range(1, num_blk - 1):
|
||||
transfer_c = tail_channels[-i - 2]
|
||||
in_c = body_channels[-1][1] + transfer_c
|
||||
body_channels.append((in_c, in_c // 2))
|
||||
print("Body Channels : ", body_channels)
|
||||
|
||||
# head channels
|
||||
in_c = body_channels[-1][1] + tail_channels[0]
|
||||
head_channels = [(in_c, in_c)]
|
||||
for i in range(num_blk):
|
||||
transfer_c = body_channels[-i - 1][0]
|
||||
in_c = head_channels[-1][1] + transfer_c
|
||||
head_channels.append((in_c, in_c))
|
||||
print("Head Channels : ", head_channels)
|
||||
return {"tail_channels": tail_channels, "body_channels": body_channels, "head_channels": head_channels}
|
||||
|
||||
|
||||
def fish99(num_cls=1000):
|
||||
"""Build the FishNet99 network."""
|
||||
start_c = 64
|
||||
# tail
|
||||
tail_num_blk = [2, 2, 6]
|
||||
bridge_num_blk = 2
|
||||
# body
|
||||
body_num_blk = [1, 1, 1]
|
||||
body_num_trans = [1, 1, 1]
|
||||
# head
|
||||
head_num_blk = [1, 2, 2]
|
||||
head_num_trans = [1, 1, 4]
|
||||
|
||||
net_channel = _calc_channel(start_c, len(tail_num_blk))
|
||||
|
||||
return Fishnet(start_c, num_cls,
|
||||
tail_num_blk, bridge_num_blk,
|
||||
body_num_blk, body_num_trans,
|
||||
head_num_blk, head_num_trans,
|
||||
**net_channel)
|
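The channel bookkeeping described in the `_calc_channel` docstring can be checked by hand. The sketch below repeats the same arithmetic without the `print` calls, assuming `start_c=64` and three stages as in `fish99`, and assuming at least two stages. The name `calc_channels_sketch` is hypothetical and is not part of this PR.

```python
def calc_channels_sketch(start_c, num_stages):
    """Recompute the tail, body and head channel lists (hypothetical helper)."""
    # Tail: the channel count doubles at every stage.
    tail = [start_c * 2 ** i for i in range(num_stages + 1)]

    # Body: the first stage keeps its width, later stages halve it.
    # Each later stage takes the previous output plus a transfer from the tail.
    body = [(tail[-1], tail[-1])]
    for i in range(num_stages - 1):
        in_c = body[-1][1] + tail[-i - 2]
        body.append((in_c, in_c // 2))

    # Head: input and output widths match; each stage adds a transfer
    # equal to the input width of a body stage, taken in reverse order.
    in_c = body[-1][1] + tail[0]
    head = [(in_c, in_c)]
    for i in range(num_stages):
        in_c = head[-1][1] + body[-i - 1][0]
        head.append((in_c, in_c))

    return {"tail_channels": tail, "body_channels": body, "head_channels": head}


channels = calc_channels_sketch(64, 3)
print(channels["tail_channels"])  # [64, 128, 256, 512]
print(channels["body_channels"])  # [(512, 512), (768, 384), (512, 256)]
print(channels["head_channels"])  # [(320, 320), (832, 832), (1600, 1600), (2112, 2112)]
```

The last head entry (2112) is the width the classifier receives, which matches `last_c = head_channels[-1][1]` in `Fishnet.__init__`.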
|
@ -0,0 +1,214 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
from src.config import imagenet_cfg
|
||||
from src.dataset import create_dataset_imagenet
|
||||
import src.fishnet as net_ms
|
||||
|
||||
set_seed(1)
|
||||
|
||||
def lr_steps_imagenet(_cfg, steps_per_epoch):
|
||||
"""lr step for imagenet"""
|
||||
if _cfg.lr_scheduler == 'cosine_annealing':
|
||||
_lr = warmup_cosine_annealing_lr(_cfg.lr_init,
|
||||
steps_per_epoch,
|
||||
_cfg.warmup_epochs,
|
||||
_cfg.epoch_size,
|
||||
_cfg.T_max,
|
||||
_cfg.eta_min)
|
||||
else:
|
||||
raise NotImplementedError(_cfg.lr_scheduler)
|
||||
|
||||
return _lr
|
||||
|
||||
|
||||
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
lr1 = float(init_lr) + lr_inc * current_step
|
||||
return lr1
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr(lr5, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
|
||||
""" warmup cosine annealing lr."""
|
||||
base_lr = lr5
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
last_epoch = i // steps_per_epoch
|
||||
if i < warmup_steps:
|
||||
lr5 = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr5 = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
|
||||
lr_each_step.append(lr5)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
class CrossEntropySmooth(_Loss):
|
||||
"""Cross entropy loss with optional label smoothing."""
|
||||
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
|
||||
super(CrossEntropySmooth, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.sparse = sparse
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
|
||||
|
||||
def construct(self, logit, label):
|
||||
if self.sparse:
|
||||
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss2 = self.ce(logit, label)
|
||||
return loss2
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Classification')
|
||||
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
|
||||
parser.add_argument('--device_type', type=str, default=None, help='GPU or Ascend. (Default: None)')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
cfg = imagenet_cfg
|
||||
|
||||
# set context
|
||||
if args_opt.device_type:
|
||||
device_target = args_opt.device_type
|
||||
else:
|
||||
device_target = cfg.device_target
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
device_num = int(os.getenv('RANK_SIZE', '1'))
|
||||
|
||||
if device_target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
if args_opt.device_id is not None:
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
else:
|
||||
context.set_context(device_id=cfg.device_id)
|
||||
|
||||
if device_num > 1:
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
elif device_target == "GPU":
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
device_id = get_rank()
|
||||
else:
|
||||
if args_opt.device_id is not None:
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
else:
|
||||
context.set_context(device_id=cfg.device_id)
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
dataset = create_dataset_imagenet(cfg.data_path, 1)
|
||||
|
||||
batch_num = dataset.get_dataset_size()
|
||||
|
||||
net = net_ms.fish99()
|
||||
|
||||
# Resume training from cfg.checkpoint_path if cfg.pre_trained is True
|
||||
if cfg.pre_trained:
|
||||
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
loss_scale_manager = None
|
||||
|
||||
lr = lr_steps_imagenet(cfg, batch_num)
|
||||
|
||||
|
||||
def get_param_groups(network):
|
||||
""" get param groups. """
|
||||
decay_params = []
|
||||
no_decay_params = []
|
||||
for x in network.trainable_params():
|
||||
parameter_name = x.name
|
||||
if parameter_name.endswith('.bias'):
|
||||
# biases are excluded from weight decay
|
||||
no_decay_params.append(x)
|
||||
elif parameter_name.endswith('.gamma'):
|
||||
# BN scale (gamma) is excluded from weight decay
|
||||
no_decay_params.append(x)
|
||||
elif parameter_name.endswith('.beta'):
|
||||
# BN shift (beta) is excluded from weight decay
|
||||
no_decay_params.append(x)
|
||||
else:
|
||||
decay_params.append(x)
|
||||
|
||||
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
|
||||
|
||||
|
||||
if cfg.is_dynamic_loss_scale:
|
||||
cfg.loss_scale = 1
|
||||
|
||||
opt = Momentum(params=get_param_groups(net),
|
||||
learning_rate=Tensor(lr),
|
||||
momentum=cfg.momentum,
|
||||
weight_decay=cfg.weight_decay,
|
||||
loss_scale=cfg.loss_scale)
|
||||
if not cfg.use_label_smooth:
|
||||
cfg.label_smooth_factor = 0.0
|
||||
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
|
||||
|
||||
if cfg.is_dynamic_loss_scale == 1:
|
||||
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
|
||||
else:
|
||||
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
|
||||
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
|
||||
amp_level="O3", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
|
||||
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 2, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
time_cb = TimeMonitor(data_size=batch_num)
|
||||
ckpt_save_dir = "./ckpt/"
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_fishnet99_imagenet", directory=ckpt_save_dir,
|
||||
config=config_ck)
|
||||
loss_cb = LossMonitor()
|
||||
cbs = [time_cb, ckpoint_cb, loss_cb]
|
||||
if device_num > 1 and device_id != 0:
|
||||
cbs = [time_cb, loss_cb]
|
||||
model.train(cfg.epoch_size, dataset, callbacks=cbs)
|
||||
print("train success")
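The schedule built by `warmup_cosine_annealing_lr` is easier to see on a toy configuration. The following is a minimal sketch in plain Python, assuming integer epoch counts and a warmup that starts from zero as in the script above. The function name `warmup_cosine_lr_sketch` and the toy values are illustrative only.

```python
import math


def warmup_cosine_lr_sketch(base_lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0.0):
    """Per-step learning rates: linear warmup, then cosine annealing per epoch."""
    warmup_steps = warmup_epochs * steps_per_epoch
    schedule = []
    for step in range(max_epoch * steps_per_epoch):
        if step < warmup_steps:
            # Linear ramp from 0 up to base_lr over the warmup steps.
            lr = base_lr * (step + 1) / warmup_steps
        else:
            # The cosine term depends on the epoch index, so the rate
            # stays constant within an epoch.
            epoch = step // steps_per_epoch
            lr = eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max)) / 2
        schedule.append(lr)
    return schedule


# Toy run: 3 epochs of 4 steps, 1 warmup epoch.
schedule = warmup_cosine_lr_sketch(0.1, 4, 1, 3, 3)
for step, lr in enumerate(schedule):
    print(step, round(lr, 4))
```

Because the cosine term uses the epoch index rather than the step index, the rate changes in steps at epoch boundaries once warmup ends.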
|
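`CrossEntropySmooth` turns each sparse label into a smoothed one-hot target using `on_value = 1 - smooth_factor` and `off_value = smooth_factor / (num_classes - 1)`. The sketch below reproduces that target in plain Python, without MindSpore, to show that the entries still sum to one. The helper name `smooth_one_hot_sketch` and the five-class example are assumptions made for illustration.

```python
def smooth_one_hot_sketch(label, num_classes, smooth_factor):
    """Build a label-smoothed one-hot target as a list of floats."""
    on_value = 1.0 - smooth_factor
    # The removed probability mass is spread evenly over the other classes.
    off_value = smooth_factor / (num_classes - 1)
    return [on_value if i == label else off_value for i in range(num_classes)]


target = smooth_one_hot_sketch(label=2, num_classes=5, smooth_factor=0.1)
print(target)
print(sum(target))
```

With `smooth_factor=0.0` the target reduces to an ordinary one-hot vector, which is what the training script falls back to when `cfg.use_label_smooth` is false.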