!17633 [线上贡献]黄金赛段FishNet99网络精度性能调优提交PR+GPU网络模型征集活动

Merge pull request !17633 from huicui/FishNet99_bold
This commit is contained in:
i-robot 2021-09-08 02:46:55 +00:00 committed by Gitee
commit a656f40f41
18 changed files with 2154 additions and 0 deletions

View File

@ -0,0 +1,313 @@
# 目录
<!-- TOC -->
- [目录](#目录)
- [FishNet99描述](#FishNet99描述)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [特性](#特性)
- [混合精度](#混合精度)
- [环境要求](#环境要求)
- [快速入门](#快速入门)
- [脚本说明](#脚本说明)
- [脚本及样例代码](#脚本及样例代码)
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [训练](#训练)
- [分布式训练](#分布式训练)
- [评估过程](#评估过程)
- [评估](#评估)
- [导出过程](#导出过程)
- [导出](#导出)
- [推理过程](#推理过程)
- [推理](#推理)
- [模型描述](#模型描述)
- [性能](#性能)
- [评估性能](#评估性能)
- [ImageNet-1k上的FishNet99](#ImageNet-1k上的FishNet99)
- [推理性能](#推理性能)
- [ImageNet-1k上的FishNet99](#ImageNet-1k上的FishNet99)
- [ModelZoo主页](#modelzoo主页)
<!-- /TOC -->
# FishNet99描述
这是第一个统一为像素级、区域级和图像级任务设计的骨干网络;可以将梯度从非常深的层直接传播到较浅的层;可以保留并互相细化不同深度的特征。
[论文](http://papers.nips.cc/paper/7356-fishnet-a-versatile-backbone-for-image-region-and-pixel-level-prediction.pdf) :FishNet: a versatile backbone for image, region, and pixel level prediction. In Proceedings of the 32nd International Conference on Neural Information Processing Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 762772.
# 模型架构
整个网络分为tail、body和head三个部分其中tail是现有的如ResNet等CNN随着网络的深入特征分辨率会逐渐减小body部分有多个上采样和细化块的结构主要用来细化来自tail和body的特征head则是有着数个下采样和细化块的结构用来保留和细化来自tail和body的特征最后一个卷积层的细化特征被用来处理图像分类等最终任务。
# 数据集
使用的数据集:[ImageNet2012](http://www.image-net.org/)
- 数据集大小125G共1000个类、125万张彩色图像
- 训练集120G共1,281,167张图像
- 测试集5G共50,000张图像
- 数据格式RGB
- 注数据将在src/dataset.py中处理
- 下载数据集,目录结构如下:
```text
└─dataset
├─ILSVRC2012_train # 训练数据集
└─ILSVRC2012_val # 评估数据集
```
# 特性
## 混合精度
采用[混合精度](https://www.mindspore.cn/docs/programming_guide/zh-CN/master/enable_mixed_precision.html) 的训练方法,使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
# 环境要求
- 硬件Ascend
- 使用Ascend来搭建硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install/en)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# 快速入门
通过官方网站安装MindSpore后您可以按照如下步骤进行训练和评估
- Ascend处理器环境运行
```bash
# 运行训练示例
python train.py --device_id=0 --device_type='Ascend' > train.log 2>&1 &
# 运行分布式训练示例
bash ./scripts/run_train_ascend.sh [RANK_TABLE_FILE]
# 运行评估示例
python eval.py --checkpoint_path ./ckpt_0 --device_type='Ascend' > ./eval.log 2>&1 &
# 运行推理示例
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
对于分布式训练需要提前创建JSON格式的hccl配置文件。
请遵循以下链接中的说明:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
- GPU处理器环境运行
为了在GPU处理器环境运行请将配置文件src/config.py中的device_target从Ascend改为GPU。
```bash
# 运行训练示例
python train.py --device_id=0 --device_type='GPU' > train_gpu.log 2>&1 &
# 运行分布式训练示例
bash ./scripts/run_train_gpu.sh 8 0,1,2,3,4,5,6,7
# 运行评估示例
python eval.py --checkpoint_path ./ckpt_0 --device_type='GPU' > ./eval_gpu.log 2>&1 &
```
# 脚本说明
## 脚本及样例代码
```bash
├── model_zoo
├── README.md // 所有模型相关说明
├── fishnet99
├── README_CN.md // FishNet99相关说明
├── ascend310_infer // 实现310推理源代码
├── scripts
│ ├──run_eval_gpu.sh // GPU评估的shell脚本
│ ├──run_infer_310.sh // Ascend推理的shell脚本
│ ├──run_train_ascend.sh // 分布式到Ascend的shell脚本
│ ├──run_train_gpu.sh // 分布式到GPU处理器的shell脚本
├── src
│ ├──config.py // 参数配置
│ ├──dataset.py // 创建数据集
│ ├──fishnet.py // FishNet99架构
├── eval.py // 评估脚本
├── export.py // 将checkpoint文件导出到air/mindir
├── postprocess.py // 310推理后处理脚本
├── train.py // 训练脚本
```
## 脚本参数
在config.py中可以同时配置训练参数和评估参数。
- 配置FishNet99和ImageNet-1k数据集。
```python
'name':'imagenet' # 数据集
'pre_trained':'False' # 是否基于预训练模型训练
'num_classes':1000 # 数据集类数
'lr_init':0.05 # 初始学习率Ascned单卡训练时设置为0.05Ascned八卡并行训练时设置为0.4GPU单卡训练时设置为0.05GPU双卡并行训练时设置为0.1
'batch_size':128 # 训练批次大小
'epoch_size':160 # 总计训练epoch数其中GPU双卡并行训练时设置为110
'T_max':150 # 学习率衰减相关参数其中GPU双卡并行训练时设置为100
'momentum':0.9 # 动量
'weight_decay':1e-4 # 权重衰减值
'image_height':224 # 输入到模型的图像高度
'image_width':224 # 输入到模型的图像宽度
'data_path':'/data/ILSVRC2012_train/' # 训练数据集的绝对全路径
'val_data_path':'/data/ILSVRC2012_val/' # 评估数据集的绝对全路径
'device_target':'Ascend' # 运行设备
'device_id':0 # 用于训练或评估数据集的设备ID使用run_train.sh进行分布式训练时可以忽略。
'keep_checkpoint_max':25 # 最多保存25个ckpt模型文件
'checkpoint_path':'./ckpt/train_fishnet99_imagenet-146_10009.ckpt' # checkpoint文件保存的绝对全路径
```
更多配置细节请参考脚本`config.py`。
## 训练过程
### 训练
- Ascend处理器环境运行
```bash
python train.py --device_id=0 --device_type='Ascend' > train.log 2>&1 &
```
上述python命令将在后台运行可以通过生成的train.log文件查看结果。
训练结束后,可以在默认脚本文件夹下找到检查点文件,采用以下方式得到损失值:
```bash
# grep "loss is " train.log
...
epoch: 8 step: 10009, loss is 3.0276418
epoch: 9 step: 10009, loss is 3.0397775
...
```
- GPU处理器环境运行
为了在GPU处理器环境运行请将配置文件src/config.py中的device_target从Ascend改为GPU。
```bash
python train.py --device_id=0 --device_type='GPU' > train_gpu.log 2>&1 &
```
上述python命令将在后台运行可以通过生成的train_gpu.log文件查看结果。
训练结束后,可以在默认./ckpt_0/脚本文件夹下找到检查点文件。
### 分布式训练
- Ascend处理器环境运行
```bash
bash ./scripts/run_train_ascend.sh [RANK_TABLE_FILE]
```
上述shell脚本将在后台运行分布训练。
- GPU处理器环境运行
为了在GPU处理器环境运行请将配置文件src/config.py中的device_target从Ascend改为GPU。
```bash
bash ./scripts/run_train_gpu.sh 2 0,1
```
上述shell脚本将在后台运行分布训练。可以在生成的train文件夹中查看结果。
## 评估过程
### 评估
- 在Ascend环境运行时评估ImageNet-1k数据集
“./ckpt_0”是保存了训练好的.ckpt模型文件的目录。
```bash
python eval.py --checkpoint_path ./ckpt_0 --device_type='Ascend' > ./eval.log 2>&1 &
```
- 在GPU处理器环境运行时评估ImageNet-1k数据集
“./ckpt_0”是保存了训练好的.ckpt模型文件的目录。
```bash
python eval.py --checkpoint_path ./ckpt_0 --device_type='GPU' > ./eval_gpu.log 2>&1 &
OR
bash ./scripts/run_eval.sh
```
## 导出过程
### 导出
将checkpoint文件导出成mindir格式模型。
```shell
python export.py --ckpt_file [CKPT_FILE]
```
## 推理过程
### 推理
在进行推理之前我们需要先导出模型。mindir可以在任意环境上导出air模型只能在昇腾910环境上导出。以下展示了使用mindir模型执行推理的示例。
- 在昇腾310上使用ImageNet-1k数据集进行推理
推理的结果保存在scripts目录下在acc.log日志文件中可以找到类似以下的结果。
```shell
# Ascend310 inference
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
Total data: 50000, top1 accuracy: 0.78242, top5 accuracy: 0.94042.
```
# 模型描述
## 性能
### 评估性能
#### ImageNet-1k上的FishNet99
| 参数 | Ascend | GPU |
| -------------------------- | ----------------------------------------------------------- | -------------------------- |
| 模型版本 | FishNet99 | FishNet99 |
| 资源 | Ascend 910 | Tesla V100-32G |
| 上传日期 | 2021-06-02 | 2021-07-24 |
| MindSpore版本 | 1.2.0 | 1.2.0 |
| 数据集 | ImageNet2012 | ImageNet2012 |
| 训练参数 | epoch=160, batch_size=128, lr_init=0.05单卡为0.05八卡为0.4 | epoch=160单卡160双卡110, T_max=150单卡150双卡100, batch_size=128, lr_init=0.05单卡0.05双卡0.1 |
| 优化器 | Momentum | Momentum |
| 损失函数 | Softmax交叉熵 | Softmax交叉熵 |
| 输出 | 概率 | 概率 |
| 分类准确率 | 单卡top1:78.24%, top5:94.03%八卡top1:78.33%, top5:93.96% | 单卡top1:78.12%, top5:94.13%双卡top1:77.97%, top5:93.98% |
| 速度 | 单卡132毫秒/步八卡135毫秒/步 | 单卡227毫秒/步双卡450毫秒/步 |
| 总时长 | 单卡58.5小时/160轮八卡7.7小时/160轮 | 单卡109.6小时/160轮双卡69.1小时/110轮 |
### 推理性能
#### ImageNet-1k上的FishNet99
| 参数 | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| 模型版本 | FishNet99 |
| 资源 | Ascend 310 |
| 上传日期 | 2021-06-16 |
| MindSpore版本 | 1.2.0 |
| 数据集 | ImageNet2012 |
| 分类准确率 | top1:78.24%,top5:94.04% |
| 速度 | Average time 5.17187 ms of infer_count 50000 |
# ModelZoo主页
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。

View File

@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
std::vector<std::string> GetAllFiles(std::string dir_name);
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name);
#endif

View File

@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.14.1)
project(MindSporeCxxTestcase[CXX])
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT}/../)
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main main.cc utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)

View File

@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
cmake . -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make

View File

@ -0,0 +1,146 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "minddata/dataset/include/vision_ascend.h"
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/vision.h"
#include "inc/utils.h"
using mindspore::dataset::vision::Decode;
using mindspore::dataset::vision::Resize;
using mindspore::dataset::vision::CenterCrop;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::TensorTransform;
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
auto all_files = GetAllInputData(FLAGS_dataset_path);
if (all_files.empty()) {
std::cout << "ERROR: no input data." << std::endl;
return 1;
}
std::map<double, double> costTime_map;
size_t size = all_files.size();
// Define transform
std::vector<int32_t> crop_paras = {224};
std::vector<int32_t> resize_paras = {256};
std::vector<float> mean = {0.485 * 255, 0.456 * 255, 0.406 * 255};
std::vector<float> std = {0.229 * 255, 0.224 * 255, 0.225 * 255};
auto decode = Decode();
auto resize = Resize(resize_paras);
auto centercrop = CenterCrop(crop_paras);
auto normalize = Normalize(mean, std);
auto hwc2chw = HWC2CHW();
mindspore::dataset::Execute SingleOp({decode, resize, centercrop, normalize, hwc2chw});
for (size_t i = 0; i < size; ++i) {
for (size_t j = 0; j < all_files[i].size(); ++j) {
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs;
double endTimeMs;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << all_files[i][j] <<std::endl;
auto imgDvpp = std::make_shared<MSTensor>();
SingleOp(ReadFileToTensor(all_files[i][j]), imgDvpp.get());
inputs.emplace_back(imgDvpp->Name(), imgDvpp->DataType(), imgDvpp->Shape(),
imgDvpp->Data().get(), imgDvpp->DataSize());
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << all_files[i][j] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(all_files[i][j], outputs);
}
}
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}

View File

@ -0,0 +1,185 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <algorithm>
#include <iostream>
#include "inc/utils.h"
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name) {
std::vector<std::vector<std::string>> ret;
DIR *dir = OpenDir(dir_name);
if (dir == nullptr) {
return {};
}
struct dirent *filename;
/* read all the files in the dir ~ */
std::vector<std::string> sub_dirs;
while ((filename = readdir(dir)) != nullptr) {
std::string d_name = std::string(filename->d_name);
// get rid of "." and ".."
if (d_name == "." || d_name == ".." || d_name.empty()) {
continue;
}
std::string dir_path = RealPath(std::string(dir_name) + "/" + filename->d_name);
struct stat s;
lstat(dir_path.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
continue;
}
sub_dirs.emplace_back(dir_path);
}
std::sort(sub_dirs.begin(), sub_dirs.end());
(void)std::transform(sub_dirs.begin(), sub_dirs.end(), std::back_inserter(ret),
[](const std::string &d) { return GetAllFiles(d); });
return ret;
}
std::vector<std::string> GetAllFiles(std::string dir_name) {
struct dirent *filename;
DIR *dir = OpenDir(dir_name);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string d_name = std::string(filename->d_name);
if (d_name == "." || d_name == ".." || d_name.size() <= 3) {
continue;
}
res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
return res;
}
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}

View File

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create_imagenet2012_label"""
import os
import json
import argparse
parser = argparse.ArgumentParser(description="resnet imagenet2012 label")
parser.add_argument("--img_path", type=str, required=True, help="imagenet2012 file path.")
args = parser.parse_args()
def create_label(file_path):
"""
create label
"""
print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!")
dirs = os.listdir(file_path)
file_list = []
for file in dirs:
file_list.append(file)
file_list = sorted(file_list)
total = 0
img_label = {}
for i, file_dir in enumerate(file_list):
files = os.listdir(os.path.join(file_path, file_dir))
for f in files:
img_label[f] = i
total += len(files)
with open("imagenet_label.json", "w+") as label:
json.dump(img_label, label)
print("[INFO] Completed! Total {} data.".format(total))
if __name__ == '__main__':
create_label(args.img_path)

View File

@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python eval.py
"""
import argparse
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from src.config import imagenet_cfg
from src.dataset import create_dataset_imagenet
import src.fishnet as net_ms
set_seed(1)
parser = argparse.ArgumentParser(description='fishnet99')
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--device_type', type=str, default=None, help='GPU or Ascend. (Default: None)')
parser.add_argument('--checkpoint_path', type=str, default='./ckpt_0', help='Checkpoint file path')
args_opt = parser.parse_args()
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss_ = self.ce(logit, label)
return loss_
if __name__ == '__main__':
if args_opt.dataset_name == "imagenet":
cfg = imagenet_cfg
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
net = net_ms.fish99()
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
else:
raise ValueError("dataset is not support.")
if not args_opt.device_type:
device_target = args_opt.device_type
else:
device_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_context(device_id=cfg.device_id)
file_list = os.listdir(args_opt.checkpoint_path)
for filename in file_list:
de_path = os.path.join(args_opt.checkpoint_path, filename)
if de_path.endswith('.ckpt'):
param_dict = load_checkpoint(de_path)
load_param_into_net(net, param_dict)
net.set_train(False)
acc = model.eval(dataset)
print(f"model {de_path}'s accuracy is {acc}")

View File

@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx or mindir model#################
python export.py
"""
import argparse
import numpy as np
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
import src.fishnet as net_ms
parser = argparse.ArgumentParser(description='FishNet99 export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="FishNet99", help="output file name.")
parser.add_argument('--width', type=int, default=224, help='input width')
parser.add_argument('--height', type=int, default=224, help='input height')
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
net = net_ms.fish99()
assert args.ckpt_file is not None, "checkpoint_path is None."
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)

View File

@ -0,0 +1,54 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""post process for 310 inference"""
import os
import json
import argparse
import numpy as np
from src.config import imagenet_cfg
batch_size = 1
parser = argparse.ArgumentParser(description="resnet inference")
parser.add_argument("--result_path", type=str, required=True, help="result files path.")
parser.add_argument("--label_path", type=str, required=True, help="image file path.")
args = parser.parse_args()
def get_result(result_path, label_path):
"""
get result.
"""
files = os.listdir(result_path)
with open(label_path, "r") as label:
labels = json.load(label)
top1 = 0
top5 = 0
total_data = len(files)
for file in files:
img_ids_name = file.split('_0.')[0]
data_path = os.path.join(result_path, img_ids_name + "_0.bin")
result = np.fromfile(data_path, dtype=np.float32).reshape(batch_size, imagenet_cfg.num_classes)
for batch in range(batch_size):
predict = np.argsort(-result[batch], axis=-1)
if labels[img_ids_name+".JPEG"] == predict[0]:
top1 += 1
if labels[img_ids_name+".JPEG"] in predict[:5]:
top5 += 1
print(f"Total data: {total_data}, top1 accuracy: {top1/total_data}, top5 accuracy: {top5/total_data}.")
if __name__ == '__main__':
get_result(args.result_path, args.label_path)

View File

@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: bash run_eval.sh [CKPT_LOCATION]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: CKPT_LOCATION=$1 is not files' location"
exit 1
fi
python eval.py --checkpoint_path=$1 --device_type='GPU' > ./eval.log 2>&1 &

View File

@ -0,0 +1,99 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 2 || $# -gt 3 ]]; then
echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
device_id=0
if [ $# == 3 ]; then
device_id=$3
fi
echo "mindir name: "$model
echo "dataset path: "$data_path
echo "device id: "$device_id
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function compile_app()
{
cd ../ascend310_infer/src/ || exit
if [ -f "Makefile" ]; then
make clean
fi
sh build.sh &> build.log
}
function infer()
{
cd - || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
../ascend310_infer/src/main --mindir_path=$model --dataset_path=$data_path --device_id=$device_id &> infer.log
}
function cal_acc()
{
python3.7 ../create_imagenet2012_label.py --img_path=$data_path
python3.7 ../postprocess.py --result_path=./result_Files --label_path=./imagenet_label.json &> acc.log &
}
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo " execute inference failed"
exit 1
fi
cal_acc
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi

View File

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_train.sh [RANK_TABLE_FILE]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
rank_start=$(DEVICE_NUM)
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i
cp ./train.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env.log
python train.py --device_id=$i --device_type='Ascend' > log 2>&1 &
cd ..
done

View File

@ -0,0 +1,52 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 2 ]
then
echo "Usage: \
bash run_train_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)]\
"
exit 1
fi
if [ $1 -lt 1 ] && [ $1 -gt 8 ]
then
echo "error: DEVICE_NUM=$1 is not in (1-8)"
exit 1
fi
export DEVICE_NUM=$1
export RANK_SIZE=$1
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "./train" ];
then
rm -rf ./train
fi
mkdir ./train
cd ./train || exit
export CUDA_VISIBLE_DEVICES="$2"
if [ $1 -gt 1 ]
then
mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python3 ${BASEPATH}/../train.py --device_type='GPU' > train_gpu.log 2>&1 &
else
python3 ${BASEPATH}/../train.py --device_type='GPU' > train_gpu.log 2>&1 &
fi

View File

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""from googlenet"""
from easydict import EasyDict as edict
imagenet_cfg = edict({
'name': 'imagenet',
'pre_trained': False,
'num_classes': 1000,
'lr_init': 0.05, # Ascend_1P: 0.05, Ascend_8P: 0.4, GPU_1P: 0.05, GPU_2P: 0.1
'batch_size': 128,
'epoch_size': 160, # GPU_2P: 110
'momentum': 0.9,
'weight_decay': 1e-4,
'image_height': 224,
'image_width': 224,
'data_path': '/data/ILSVRC2012_train/',
'val_data_path': '/data/ILSVRC2012_val/',
'device_target': 'Ascend',
'device_id': 0,
'keep_checkpoint_max': 30,
'checkpoint_path': None,
'onnx_filename': 'fishnet99',
'air_filename': 'fishnet99',
# optimizer and lr related
'lr_scheduler': 'cosine_annealing',
'lr_epochs': [30, 60, 90, 120],
'lr_gamma': 0.3,
'eta_min': 0.0,
'T_max': 150, # GPU_2P: 100
'warmup_epochs': 0,
# loss related
'is_dynamic_loss_scale': 0,
'loss_scale': 1024,
'label_smooth_factor': 0.1,
'use_label_smooth': True,
})

View File

@ -0,0 +1,100 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
from src.config import imagenet_cfg
def create_dataset_imagenet(dataset_path, repeat_num=1, training=True, num_parallel_workers=16, shuffle=None):
"""
create a train or eval imagenet2012 dataset for resnet50
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
device_num, rank_id = _get_rank_info()
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
num_shards=device_num, shard_id=rank_id)
assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
image_size = imagenet_cfg.image_height
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
if training:
transform_img = [
vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
vision.RandomHorizontalFlip(prob=0.5),
vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
else:
transform_img = [
vision.Decode(),
vision.Resize(256),
vision.CenterCrop(image_size),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
transform_label = [C.TypeCast(mstype.int32)]
data_set = data_set.map(input_columns="image", num_parallel_workers=12, operations=transform_img)
data_set = data_set.map(input_columns="label", num_parallel_workers=4, operations=transform_label)
# apply batch operations
data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
from mindspore.communication.management import get_rank, get_group_size
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = rank_id = None
return rank_size, rank_id

View File

@ -0,0 +1,599 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
FishNet model of MindSpore-1.2.0.
"""
import mindspore.nn as nn
import mindspore.ops.operations as P
conv_weight_init = 'HeUniform'
class adaptiveavgpool2d_ms(nn.Cell):
"""adaptiveavgpool2d_ms"""
def __init__(self):
super(adaptiveavgpool2d_ms, self).__init__()
self.ada_pool = P.ReduceMean(keep_dims=True)
def construct(self, x):
"""construct"""
x = self.ada_pool(x, (2, 3))
return x
def _bn_relu_conv(in_c, out_c, **conv_kwargs):
return nn.SequentialCell([nn.BatchNorm2d(num_features=in_c, momentum=0.9),
nn.ReLU(),
nn.Conv2d(in_channels=in_c, out_channels=out_c, pad_mode='pad',
weight_init=conv_weight_init, **conv_kwargs),
])
class ResBlock_with_shortcut(nn.Cell):
"""
Construct Basic Bottle-necked Residual Block module.
Args:
in_c : Number of channels in the input image
out_c : Number of channels in the output image
shortcut : Specific function for skip-connection
Examples)
'bn_relu_conv' for DownRefinementBlock
'bn_relu_conv with channel reduction' for UpRefinementBlock
'identity mapping' for regular connection
stride : Stride of middle conv layer
dilation : Dilation rate of middle conv layer
Forwarding Path:
(shortcut)
input image - (BN-ReLU-Conv) * 3 - (add) -output
"""
def __init__(self, in_c, out_c, shortcut, stride=1, dilation=1):
super(ResBlock_with_shortcut, self).__init__()
mid_c = out_c // 4
self.layers = nn.SequentialCell([
_bn_relu_conv(in_c, mid_c, kernel_size=1, has_bias=False),
_bn_relu_conv(mid_c, mid_c, kernel_size=3, stride=stride, padding=dilation, dilation=dilation,
has_bias=False),
_bn_relu_conv(mid_c, out_c, kernel_size=1, has_bias=False),
])
self.shortcut = shortcut
self.add_1 = P.Add()
def construct(self, x):
"""construct"""
return self.add_1(self.layers(x), self.shortcut(x))
class ResBlock_without_shortcut(nn.Cell):
"""
Construct Basic Bottle-necked Residual Block module.
Args:
in_c : Number of channels in the input image
out_c : Number of channels in the output image
stride : Stride of middle conv layer
dilation : Dilation rate of middle conv layer
Forwarding Path:
(shortcut)
input image - (BN-ReLU-Conv) * 3 - (add) -output
"""
def __init__(self, in_c, out_c, stride=1, dilation=1):
super(ResBlock_without_shortcut, self).__init__()
mid_c = out_c // 4
self.layers_ = nn.SequentialCell([
_bn_relu_conv(in_c, mid_c, kernel_size=1, has_bias=False),
_bn_relu_conv(mid_c, mid_c, kernel_size=3, stride=stride, padding=dilation, dilation=dilation,
has_bias=False),
_bn_relu_conv(mid_c, out_c, kernel_size=1, has_bias=False),
])
self.add_2 = P.Add()
def construct(self, x):
"""construct"""
return self.add_2(self.layers_(x), x)
class TransferBlock(nn.Cell):
"""
Construct Transfer Block module.
Args:
ch : Number of channels in the input and output image
num_blk : Number of Residual Blocks
Forwarding Path:
input image - (ResBlock_without_shortcut) * num_blk - output
"""
def __init__(self, ch, num_blk):
super(TransferBlock, self).__init__()
self.layers_TransferBlock = nn.SequentialCell([*[ResBlock_without_shortcut(ch, ch)
for _ in range(0, num_blk)]])
def construct(self, x):
"""construct"""
return self.layers_TransferBlock(x)
class DownStage(nn.Cell):
"""
Construct a stage for each resolution.
A DownStage is consisted of one DownRefinementBlock and several residual regular connection blocks.
(Note: In fact, DownRefinementBlock is not used in FishHead according to original implementation.
However, it seems needed to be used according to original paper.
In this version, we followed original implementation, not original paper.)
Args:in_c : Number of channels in the input image
out_c : Number of channels in the output image
num_blk : Number of Residual Blocks
stride : Stride of shortcut conv layer
Forwarding Path:input image - (ResBlock with Shortcut) - (ResBlock) * num_blk - (MaxPool) - output
"""
def __init__(self, in_c, out_c, num_blk, stride=1):
super(DownStage, self).__init__()
shortcut = _bn_relu_conv(in_c, out_c, kernel_size=1, stride=stride, has_bias=False)
self.layer_DownStage = nn.SequentialCell([
ResBlock_with_shortcut(in_c, out_c, shortcut),
*[ResBlock_without_shortcut(out_c, out_c) for _ in range(1, num_blk)],
nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
])
def construct(self, x):
"""construct"""
return self.layer_DownStage(x)
class UpStage(nn.Cell):
"""
Construct a stage for each resolution.
A DownStage is consisted of one DownRefinementBlock and several residual regular connection blocks.
Not like DownStage, this module reduces the number of channels of concatenated feature maps in the shortcut path.
Args:in_c : Number of channels in the input image
out_c : Number of channels in the output image
num_blk : Number of Residual Blocks
stride : Stride of shortcut conv layer
Forwarding Path:input image - (ResBlock with Channel Reduction) - (ResBlock) * num_blk - (UpSample) - output
"""
def __init__(self, in_c, out_c, num_blk, dilation=1):
super(UpStage, self).__init__()
self.k = in_c // out_c
self.redece_sum = P.ReduceSum(keep_dims=False)
self.shape_ = P.Shape()
self.reshape_ = P.Reshape()
self.layer_UpStage = nn.SequentialCell([
ResBlock_with_shortcut(in_c, out_c, channel_reduction_ms(in_c // out_c), dilation=dilation),
*[ResBlock_without_shortcut(out_c, out_c, dilation=dilation) for _ in range(1, num_blk)],
])
def construct(self, x):
"""construct"""
return self.layer_UpStage(x)
class channel_reduction_ms(nn.Cell):
"""channel_reduction_ms"""
def __init__(self, kk):
super(channel_reduction_ms, self).__init__()
self.shape_ = P.Shape()
self.kk = kk
self.reshape_ = P.Reshape()
self.redece_sum = P.ReduceSum(keep_dims=False)
def construct(self, x):
"""construct"""
n, c, h_, w_ = self.shape_(x)
x = self.redece_sum(self.reshape_(x, (n, c // self.kk, self.kk, h_, w_)), 2)
return x
class FishTail(nn.Cell):
"""
Construct FishTail module.
Each instances corresponds to each stages.
Args:
in_c : Number of channels in the input image
out_c : Number of channels in the output image
num_blk : Number of Residual Blocks
Forwarding Path:
input image - (DownStage) - output
"""
def __init__(self, in_c, out_c, num_blk):
super(FishTail, self).__init__()
self.layer_FishTail = DownStage(in_c, out_c, num_blk)
def construct(self, x):
"""construct"""
x = self.layer_FishTail(x)
return x
class Bridge(nn.Cell):
"""
Construct Bridge module.
This module bridges the last FishTail stage and first FishBody stage.
Args:ch : Number of channels in the input and output image
num_blk : Number of Residual Blocks
Forwarding Path:
r (SEBlock)
input image - (stem) - (ResBlock with Shortcut) - (ResBlock) * num_blk - (mul & sum) - output
"""
def __init__(self, ch, num_blk):
super(Bridge, self).__init__()
self.stem = nn.SequentialCell([
nn.BatchNorm2d(num_features=ch, momentum=0.9),
nn.ReLU(),
nn.Conv2d(in_channels=ch, out_channels=ch // 2, kernel_size=1, pad_mode='pad', has_bias=False,
weight_init=conv_weight_init),
nn.BatchNorm2d(num_features=ch // 2, momentum=0.9),
nn.ReLU(),
nn.Conv2d(in_channels=ch // 2, out_channels=ch * 2, kernel_size=1, pad_mode='pad', has_bias=True,
weight_init=conv_weight_init)
])
shortcut = _bn_relu_conv(ch * 2, ch, kernel_size=1, has_bias=False)
self.layers_Bridge = nn.SequentialCell([
ResBlock_with_shortcut(ch * 2, ch, shortcut),
*[ResBlock_without_shortcut(ch, ch) for _ in range(1, num_blk)],
])
self.se_block = nn.SequentialCell([
nn.BatchNorm2d(num_features=ch * 2, momentum=0.9),
nn.ReLU(),
adaptiveavgpool2d_ms(),
nn.Conv2d(in_channels=ch * 2, out_channels=ch // 16, kernel_size=1, pad_mode='pad', has_bias=True
, weight_init=conv_weight_init),
nn.ReLU(),
nn.Conv2d(in_channels=ch // 16, out_channels=ch, kernel_size=1, pad_mode='pad', has_bias=True
, weight_init=conv_weight_init),
nn.Sigmoid()
])
self.add_3 = P.Add()
self.mul_ = P.Mul()
def construct(self, x):
"""construct"""
x = self.stem(x)
att = self.se_block(x)
out = self.layers_Bridge(x)
return self.add_3(self.mul_(out, att), att)
class FishBody_0(nn.Cell):
"""Construct FishBody module.
Each instances corresponds to each stages.
Args:in_c : Number of channels in the input image
out_c : Number of channels in the output image
num_blk : Number of Residual Blocks
trans_in_c : Number of channels in the transferred image
num_trans : Number of Transfer Blocks
dilation : Dilation rate of Conv in UpRefinementBlock
Forwarding Path:
input image - (UpStage)
trans image - (transfer) --(concat)-- output
"""
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans, dilation=1):
super(FishBody_0, self).__init__()
self.layer_FishBody = UpStage(in_c, out_c, num_blk, dilation=dilation)
self.add_up = par_ms_0()
self.transfer = TransferBlock(trans_in_c, num_trans)
self.concat = P.Concat(1)
def construct(self, x, trans_x):
"""construct"""
x = self.layer_FishBody(x)
x = self.add_up(x)
trans_x = self.transfer(trans_x)
return self.concat((x, trans_x))
class par_ms_0(nn.Cell):
"""par_ms_0"""
def __init__(self):
super(par_ms_0, self).__init__()
self.P_up = P.ResizeNearestNeighbor((14, 14))
def construct(self, x):
"""construct"""
x = self.P_up(x)
return x
class FishBody_1(nn.Cell):
"""Construct FishBody module.
Each instances corresponds to each stages.
Args:in_c : Number of channels in the input image
out_c : Number of channels in the output image
num_blk : Number of Residual Blocks
trans_in_c : Number of channels in the transferred image
num_trans : Number of Transfer Blocks
dilation : Dilation rate of Conv in UpRefinementBlock
Forwarding Path:
input image - (UpStage)
trans image - (transfer) --(concat)-- output
"""
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans, dilation=1):
super(FishBody_1, self).__init__()
self.layer_FishBody = UpStage(in_c, out_c, num_blk, dilation=dilation)
self.add_up = par_ms_1()
self.transfer = TransferBlock(trans_in_c, num_trans)
self.concat = P.Concat(1)
def construct(self, x, trans_x):
"""construct"""
x = self.layer_FishBody(x)
x = self.add_up(x)
trans_x = self.transfer(trans_x)
return self.concat((x, trans_x))
class par_ms_1(nn.Cell):
"""par_ms_1"""
def __init__(self):
super(par_ms_1, self).__init__()
self.P_up = P.ResizeNearestNeighbor((28, 28))
def construct(self, x):
"""construct"""
x = self.P_up(x)
return x
class FishBody_2(nn.Cell):
"""Construct FishBody module.
Each instances corresponds to each stages.
Args:in_c : Number of channels in the input image
out_c : Number of channels in the output image
num_blk : Number of Residual Blocks
trans_in_c : Number of channels in the transferred image
num_trans : Number of Transfer Blocks
dilation : Dilation rate of Conv in UpRefinementBlock
Forwarding Path:
input image - (UpStage)
trans image - (transfer) --(concat)-- output
"""
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans, dilation=1):
super(FishBody_2, self).__init__()
self.layer_FishBody = UpStage(in_c, out_c, num_blk, dilation=dilation)
self.add_up = par_ms_2()
self.transfer = TransferBlock(trans_in_c, num_trans)
self.concat = P.Concat(1)
def construct(self, x, trans_x):
"""construct"""
x = self.layer_FishBody(x)
x = self.add_up(x)
trans_x = self.transfer(trans_x)
return self.concat((x, trans_x))
class par_ms_2(nn.Cell):
"""par_ms_2"""
def __init__(self):
super(par_ms_2, self).__init__()
self.P_up = P.ResizeNearestNeighbor((56, 56))
def construct(self, x):
"""construct"""
x = self.P_up(x)
return x
class FishHead(nn.Cell):
"""Construct FishHead module.
Each instances corresponds to each stages.
Different with Official Code : we used shortcut layer in this Module. (shortcut layer is used according to the
original paper)
Args:in_c : Number of channels in the input image
out_c : Number of channels in the output image
num_blk : Number of Residual Blocks
trans_in_c : Number of channels in the transferred image
num_trans : Number of Transfer Blocks
Forwarding Path:
input image - (ResBlock) * num_blk - pool
trans image - (transfer) --(concat)-- output
"""
def __init__(self, in_c, out_c, num_blk, trans_in_c, num_trans):
super(FishHead, self).__init__()
self.layer_FishHead = nn.SequentialCell([
ResBlock_without_shortcut(in_c, out_c),
*[ResBlock_without_shortcut(out_c, out_c) for _ in range(1, num_blk)],
nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
])
self.transfer = TransferBlock(trans_in_c, num_trans)
self.concat_ = P.Concat(1)
def construct(self, x, trans_x):
"""construct"""
x = self.layer_FishHead(x)
trans_x = self.transfer(trans_x)
return self.concat_((x, trans_x))
def _conv_bn_relu(in_ch, out_ch, stride=1):
return nn.SequentialCell([nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=stride,
pad_mode='pad', padding=1, has_bias=False, weight_init=conv_weight_init),
nn.BatchNorm2d(num_features=out_ch, momentum=0.9),
nn.ReLU()])
class Fishnet(nn.Cell):
"""
Construct entire networks
Args:
start_c : Number of channels of input image.
Note that it is NOT the number of channels in initial input image, and it IS the number of output
channel of stem.
num_cls : Number of classes
tail_num_blk : list of the numbers of Conv blocks in each FishTail stages
body_num_blk : list of the numbers of Conv blocks in each FishBody stages
head_num_blk : list of the numbers of Conv blocks in each FishHead stages
(Note : `*_num_blk` includes 1 Residual blocks in the start of each stages)
body_num_trans : list of the numbers of Conv blocks in transfer paths in each FishTail stages
head_num_trans : list of the numbers of Conv blocks in transfer paths in each FishHead stages
tail_channels : list of the number of in, out channel of each stages
body_channels : list of the number of in, out channel of each stages
head_channels : list of the number of in, out channel of each stages
"""
def __init__(self, start_c=64, num_cls=1000,
tail_num_blk=1, bridge_num_blk=2,
body_num_blk=1, body_num_trans=1,
head_num_blk=1, head_num_trans=1,
tail_channels=1, body_channels=1, head_channels=1):
super(Fishnet, self).__init__()
self.stem = nn.SequentialCell([
_conv_bn_relu(3, start_c // 2, stride=2),
_conv_bn_relu(start_c // 2, start_c // 2),
_conv_bn_relu(start_c // 2, start_c),
nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
])
print("FishNet Initialization Start")
self.tail_layer = []
for i, num_blk in enumerate(tail_num_blk):
layer = FishTail(tail_channels[i], tail_channels[i + 1], num_blk)
self.tail_layer.append(layer)
self.tail_layer = nn.CellList(self.tail_layer)
self.bridge = Bridge(tail_channels[-1], bridge_num_blk)
self.body_layer = []
for i, (num_blk, num_trans) in enumerate(zip(body_num_blk, body_num_trans)):
if i == 0:
layer = FishBody_0(body_channels[i][0], body_channels[i][1], num_blk,
tail_channels[-i - 2], num_trans, dilation=2 ** i)
elif i == 1:
layer = FishBody_1(body_channels[i][0], body_channels[i][1], num_blk,
tail_channels[-i - 2], num_trans, dilation=2 ** i)
else:
layer = FishBody_2(body_channels[i][0], body_channels[i][1], num_blk,
tail_channels[-i - 2], num_trans, dilation=2 ** i)
self.body_layer.append(layer)
self.body_layer = nn.CellList(self.body_layer)
self.head_layer = []
for i, (num_blk, num_trans) in enumerate(zip(head_num_blk, head_num_trans)):
layer = FishHead(head_channels[i][0], head_channels[i][1], num_blk,
body_channels[-i - 1][0], num_trans)
self.head_layer.append(layer)
self.head_layer = nn.CellList(self.head_layer)
last_c = head_channels[-1][1]
self.classifier = nn.SequentialCell([
nn.BatchNorm2d(num_features=last_c, momentum=0.9),
nn.ReLU(),
nn.Conv2d(in_channels=last_c, out_channels=last_c // 2, kernel_size=1, pad_mode='pad', has_bias=False
, weight_init=conv_weight_init),
nn.BatchNorm2d(num_features=last_c // 2, momentum=0.9),
nn.ReLU(),
adaptiveavgpool2d_ms(),
nn.Conv2d(in_channels=last_c // 2, out_channels=num_cls, kernel_size=1, pad_mode='pad', has_bias=True
, weight_init=conv_weight_init)
])
self.squeeze_ = P.Squeeze(2)
def construct(self, x):
"""construct"""
stem = self.stem(x)
tail_features = [stem]
for t in self.tail_layer:
last_feature = tail_features[-1]
tail_features.append(t(last_feature))
bridge = self.bridge(tail_features[-1])
body_features = [bridge]
for b, tail in zip(self.body_layer, [tail_features[2], tail_features[1], tail_features[0]]):
last_feature = body_features[-1]
body_features.append(b(last_feature, tail))
head_features = [body_features[-1]]
for h, body in zip(self.head_layer, [body_features[2], body_features[1], body_features[0]]):
last_feature = head_features[-1]
head_features.append(h(last_feature, body))
out = self.classifier(head_features[-1])
out = self.squeeze_(out)
out = self.squeeze_(out)
return out
def _calc_channel(start_c, num_blk):
"""
Calculate the number of in and out channels of each stages in FishNet.
Example:
fish150 : start channel=64, num_blk=3,
tail channels : Grow double in each stages,
[64, 128, 256 ...] = [start channel ** (2**num_blk) ....]
body channels : In first stage, in_channel and out_channel is the same,
but the other layers, the number of output channels is half of the number of input channel
Add the number of transfer channels to the number of output channels
The numbers of transfer channels are reverse of the tail channel[:-2]
[(512, 512), + 256
(768, 384), + 128
(512, 256)] + 64
head channels : The number of input channels and output channels is the same.
Add the number of transfer channels to the number of output channels
The numbers of transfer channels are reverse of the tail channel[:-2]
[(320, 320), + 512
(832, 832), + 768
(1600, 1600)] + 512
"""
# tail channels
tail_channels = [start_c]
for i in range(num_blk):
tail_channels.append(tail_channels[-1] * 2)
print("Tail Channels : ", tail_channels)
# body channels
in_c, transfer_c = tail_channels[-1], tail_channels[-2]
body_channels = [(in_c, in_c), (in_c + transfer_c, (in_c + transfer_c) // 2)]
for i in range(1, num_blk - 1):
transfer_c = tail_channels[-i - 2]
in_c = body_channels[-1][1] + transfer_c
body_channels.append((in_c, in_c // 2))
print("Body Channels : ", body_channels)
# head channels
in_c = body_channels[-1][1] + tail_channels[0]
head_channels = [(in_c, in_c)]
for i in range(num_blk):
transfer_c = body_channels[-i - 1][0]
in_c = head_channels[-1][1] + transfer_c
head_channels.append((in_c, in_c))
print("Head Channels : ", head_channels)
return {"tail_channels": tail_channels, "body_channels": body_channels, "head_channels": head_channels}
def fish99(num_cls=1000):
"""fish99"""
start_c = 64
# tail
tail_num_blk = [2, 2, 6]
bridge_num_blk = 2
# body
body_num_blk = [1, 1, 1]
body_num_trans = [1, 1, 1]
# head
head_num_blk = [1, 2, 2]
head_num_trans = [1, 1, 4]
net_channel = _calc_channel(start_c, len(tail_num_blk))
return Fishnet(start_c, num_cls,
tail_num_blk, bridge_num_blk,
body_num_blk, body_num_trans,
head_num_blk, head_num_trans,
**net_channel)

View File

@ -0,0 +1,214 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train"""
import argparse
import os
import math
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from src.config import imagenet_cfg
from src.dataset import create_dataset_imagenet
import src.fishnet as net_ms
set_seed(1)
def lr_steps_imagenet(_cfg, steps_per_epoch):
"""lr step for imagenet"""
if _cfg.lr_scheduler == 'cosine_annealing':
_lr = warmup_cosine_annealing_lr(_cfg.lr_init,
steps_per_epoch,
_cfg.warmup_epochs,
_cfg.epoch_size,
_cfg.T_max,
_cfg.eta_min)
else:
raise NotImplementedError(_cfg.lr_scheduler)
return _lr
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr1 = float(init_lr) + lr_inc * current_step
return lr1
def warmup_cosine_annealing_lr(lr5, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
""" warmup cosine annealing lr."""
base_lr = lr5
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr5 = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr5 = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
lr_each_step.append(lr5)
return np.array(lr_each_step).astype(np.float32)
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss2 = self.ce(logit, label)
return loss2
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Classification')
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
parser.add_argument('--device_type', type=str, default=None, help='GPU or Ascend. (Default: None)')
args_opt = parser.parse_args()
cfg = imagenet_cfg
# set context
if not args_opt.device_type:
device_target = args_opt.device_type
else:
device_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_context(enable_graph_kernel=True)
device_num = int(os.getenv('RANK_SIZE', '1'))
if device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID', '0'))
if args_opt.device_id is not None:
context.set_context(device_id=args_opt.device_id)
else:
context.set_context(device_id=cfg.device_id)
if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
elif device_target == "GPU":
if device_num > 1:
init()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
device_id = get_rank()
else:
if args_opt.device_id is not None:
context.set_context(device_id=args_opt.device_id)
else:
context.set_context(device_id=cfg.device_id)
else:
raise ValueError("Unsupported platform.")
dataset = create_dataset_imagenet(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()
net = net_ms.fish99()
# Continue training if set pre_trained to be True
if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
loss_scale_manager = None
lr = lr_steps_imagenet(cfg, batch_num)
def get_param_groups(network):
""" get param groups. """
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
if cfg.is_dynamic_loss_scale:
cfg.loss_scale = 1
opt = Momentum(params=get_param_groups(net),
learning_rate=Tensor(lr),
momentum=cfg.momentum,
weight_decay=cfg.weight_decay,
loss_scale=cfg.loss_scale)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
if cfg.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O3", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 2, keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpt_save_dir = "./ckpt/"
ckpoint_cb = ModelCheckpoint(prefix="train_fishnet99_imagenet", directory=ckpt_save_dir,
config=config_ck)
loss_cb = LossMonitor()
cbs = [time_cb, ckpoint_cb, loss_cb]
if device_num > 1 and device_id != 0:
cbs = [time_cb, loss_cb]
model.train(cfg.epoch_size, dataset, callbacks=cbs)
print("train success")