forked from mindspore-Ecosystem/mindspore
commit
815760cd35
|
@ -0,0 +1,256 @@
|
||||||
|
# 目录
|
||||||
|
|
||||||
|
<!-- TOC -->
|
||||||
|
|
||||||
|
- [目录](#目录)
|
||||||
|
- [概述](#概述)
|
||||||
|
- [论文](#论文)
|
||||||
|
- [模型架构](#模型架构)
|
||||||
|
- [数据集](#数据集)
|
||||||
|
- [环境要求](#环境要求)
|
||||||
|
- [快速入门](#快速入门)
|
||||||
|
- [脚本说明](#脚本说明)
|
||||||
|
- [脚本结构与说明](#脚本结构与说明)
|
||||||
|
- [脚本参数](#脚本参数)
|
||||||
|
- [训练过程](#训练过程)
|
||||||
|
- [用法](#用法)
|
||||||
|
- [Ascend处理器环境运行](#Ascend处理器环境运行)
|
||||||
|
- [结果](#结果)
|
||||||
|
- [评估过程](#评估过程)
|
||||||
|
- [用法](#用法-1)
|
||||||
|
- [Ascend处理器环境运行](#Ascend处理器环境运行-1)
|
||||||
|
- [结果](#结果-1)
|
||||||
|
- [推理过程](#推理过程)
|
||||||
|
- [导出MindIR](#导出MindIR)
|
||||||
|
- [结果](#结果)
|
||||||
|
- [模型描述](#模型描述)
|
||||||
|
- [性能](#性能)
|
||||||
|
- [评估性能](#评估性能)
|
||||||
|
- [随机情况说明](#随机情况说明)
|
||||||
|
- [ModelZoo主页](#modelzoo主页)
|
||||||
|
|
||||||
|
<!-- /TOC -->
|
||||||
|
|
||||||
|
# GhostNet描述
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
GhostNet由华为诺亚方舟实验室在2020年提出,此网络提供了一个全新的Ghost模块,旨在通过廉价操作生成更多的特征图。基于一组原始的特征图,作者应用一系列线性变换,以很小的代价生成许多能从原始特征发掘所需信息的“幻影”特征图(Ghost feature maps)。该Ghost模块即插即用,通过堆叠Ghost模块得出Ghost bottleneck,进而搭建轻量级神经网络——GhostNet。该架构可以在同样精度下,速度和计算量均少于SOTA算法。
|
||||||
|
|
||||||
|
如下为MindSpore使用ImageNet2012数据集对GhostNet进行训练的示例。
|
||||||
|
|
||||||
|
## 论文
|
||||||
|
|
||||||
|
1. [论文](https://arxiv.org/pdf/1911.11907.pdf): Kai Han, Yunhe Wang, Qi Tian."GhostNet: More Features From Cheap Operations"
|
||||||
|
|
||||||
|
# 模型架构
|
||||||
|
|
||||||
|
GhostNet的总体网络架构如下:[链接](https://arxiv.org/pdf/1911.11907.pdf)
|
||||||
|
|
||||||
|
# 数据集
|
||||||
|
|
||||||
|
使用的数据集:[ImageNet2012](http://www.image-net.org/)
|
||||||
|
|
||||||
|
- 数据集大小:共1000个类、224*224彩色图像
|
||||||
|
- 训练集:共1,281,167张图像
|
||||||
|
- 测试集:共50,000张图像
|
||||||
|
- 数据格式:JPEG
|
||||||
|
- 注:数据在dataset.py中处理。
|
||||||
|
- 下载数据集,目录结构如下:
|
||||||
|
|
||||||
|
```text
|
||||||
|
└─dataset
|
||||||
|
├─ilsvrc # 训练数据集
|
||||||
|
└─validation_preprocess # 评估数据集
|
||||||
|
```
|
||||||
|
|
||||||
|
# 环境要求
|
||||||
|
|
||||||
|
- 硬件
|
||||||
|
- 准备Ascend处理器搭建硬件环境。如需试用昇腾处理器,请发送[申请表](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx)至ascend@huawei.com,审核通过即可获得资源。
|
||||||
|
- 框架
|
||||||
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
|
- 如需查看详情,请参见如下资源:
|
||||||
|
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||||
|
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||||
|
|
||||||
|
# 快速入门
|
||||||
|
|
||||||
|
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||||
|
|
||||||
|
- Ascend处理器环境运行
|
||||||
|
|
||||||
|
```Shell
|
||||||
|
# 分布式训练
|
||||||
|
用法:sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||||
|
|
||||||
|
# 单机训练
|
||||||
|
用法:sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||||
|
|
||||||
|
# 运行评估示例
|
||||||
|
用法:sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||||
|
```
|
||||||
|
|
||||||
|
# 脚本说明
|
||||||
|
|
||||||
|
## 脚本结构与说明
|
||||||
|
|
||||||
|
```text
|
||||||
|
└──ghostnet
|
||||||
|
├── README.md
|
||||||
|
├── scripts
|
||||||
|
├── run_distribute_train.sh # 启动Ascend分布式训练(8卡)
|
||||||
|
├── run_eval.sh # 启动Ascend评估
|
||||||
|
└── run_standalone_train.sh # 启动Ascend单机训练(单卡)
|
||||||
|
├── src
|
||||||
|
├── config.py # 参数配置
|
||||||
|
├── dataset.py # 数据预处理
|
||||||
|
├── CrossEntropySmooth.py # ImageNet2012数据集的损失定义
|
||||||
|
├── lr_generator.py # 生成每个步骤的学习率
|
||||||
|
└── ghostnet.py # ghostnet网络
|
||||||
|
├── eval.py # 评估网络
|
||||||
|
└── train.py # 训练网络
|
||||||
|
```
|
||||||
|
|
||||||
|
# 脚本参数
|
||||||
|
|
||||||
|
在config.py中可以同时配置训练参数和评估参数。
|
||||||
|
|
||||||
|
- 配置GhostNet和ImageNet2012数据集。
|
||||||
|
|
||||||
|
```Python
|
||||||
|
"num_classes": 1000, # 数据集类数
|
||||||
|
"batch_size": 128, # 输入张量的批次大小
|
||||||
|
"epoch_size": 500, # 训练周期大小
|
||||||
|
"warmup_epochs": 20, # 热身周期数
|
||||||
|
"lr_init": 0.1, # 基础学习率
|
||||||
|
"lr_max": 0.4, # 最大学习率
|
||||||
|
'lr_end': 1e-6, # 最终学习率
|
||||||
|
'lr_decay_mode': 'cosine', # 用于生成学习率的衰减模式
|
||||||
|
"momentum": 0.9, # 动量优化器
|
||||||
|
"weight_decay": 4e-5, # 权重衰减
|
||||||
|
"label_smooth": 0.1, # 标签平滑因子
|
||||||
|
"loss_scale": 128, # 损失等级
|
||||||
|
"use_label_smooth": True, # 标签平滑
|
||||||
|
"label_smooth_factor": 0.1, # 标签平滑因子
|
||||||
|
"save_checkpoint": True, # 是否保存检查点
|
||||||
|
"save_checkpoint_epochs": 20, # 两个检查点之间的周期间隔;默认情况下,最后一个检查点将在最后一个周期完成后保存
|
||||||
|
"keep_checkpoint_max": 10, # 只保存最后一个keep_checkpoint_max检查点
|
||||||
|
"save_checkpoint_path": "./", # 检查点相对于执行路径的保存路径
|
||||||
|
```
|
||||||
|
|
||||||
|
# 训练过程
|
||||||
|
|
||||||
|
## 用法
|
||||||
|
|
||||||
|
### Ascend处理器环境运行
|
||||||
|
|
||||||
|
```Shell
|
||||||
|
# 分布式训练
|
||||||
|
用法:sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||||
|
|
||||||
|
# 单机训练
|
||||||
|
用法:sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
分布式训练需要提前创建JSON格式的HCCL配置文件。
|
||||||
|
|
||||||
|
具体操作,参见[hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)中的说明。
|
||||||
|
|
||||||
|
训练结果保存在示例路径中,文件夹名称以“train”或“train_parallel”开头。您可在此路径下的日志中找到检查点文件以及结果,如下所示。
|
||||||
|
|
||||||
|
## 结果
|
||||||
|
|
||||||
|
- 使用ImageNet2012数据集训练GhostNet
|
||||||
|
|
||||||
|
```text
|
||||||
|
# 分布式训练结果(8P)
|
||||||
|
epoch: 1 step: 1251, loss is 5.001419
|
||||||
|
epoch time: 457012.100 ms, per step time: 365.317 ms
|
||||||
|
epoch: 2 step: 1251, loss is 4.275552
|
||||||
|
epoch time: 280175.784 ms, per step time: 223.961 ms
|
||||||
|
epoch: 3 step: 1251, loss is 4.0788813
|
||||||
|
epoch time: 280134.943 ms, per step time: 223.929 ms
|
||||||
|
epoch: 4 step: 1251, loss is 4.0310946
|
||||||
|
epoch time: 280161.342 ms, per step time: 223.950 ms
|
||||||
|
epoch: 5 step: 1251, loss is 3.7326777
|
||||||
|
epoch time: 280178.602 ms, per step time: 223.964 ms
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
# 评估过程
|
||||||
|
|
||||||
|
## 用法
|
||||||
|
|
||||||
|
### Ascend处理器环境运行
|
||||||
|
|
||||||
|
```Shell
|
||||||
|
# 评估
|
||||||
|
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||||
|
```
|
||||||
|
|
||||||
|
```Shell
|
||||||
|
# 评估示例
|
||||||
|
sh run_eval.sh /data/dataset/ImageNet/imagenet_original ghostnet-500_1251.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
训练过程中可以生成检查点。
|
||||||
|
|
||||||
|
## 结果
|
||||||
|
|
||||||
|
评估结果保存在示例路径中,文件夹名为“eval”。您可在此路径下的日志找到如下结果:
|
||||||
|
|
||||||
|
- 使用ImageNet2012数据集评估GhostNet
|
||||||
|
|
||||||
|
```text
|
||||||
|
result: {'top_5_accuracy': 0.9162371134020618, 'top_1_accuracy': 0.739368556701031}
|
||||||
|
ckpt = /home/lzu/ghost_Mindspore/scripts/device0/ghostnet-500_1251.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
# 推理过程
|
||||||
|
|
||||||
|
## [导出MindIR](#contents)
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||||
|
```
|
||||||
|
|
||||||
|
参数ckpt_file为必填项,
|
||||||
|
`EXPORT_FORMAT` 必须在 ["AIR", "MINDIR"]中选择。
|
||||||
|
|
||||||
|
## 结果
|
||||||
|
|
||||||
|
导出“.mindir”文件可在当前目录查看
|
||||||
|
|
||||||
|
# 模型描述
|
||||||
|
|
||||||
|
## 性能
|
||||||
|
|
||||||
|
### 评估性能
|
||||||
|
|
||||||
|
| 参数 | Ascend 910 |
|
||||||
|
|---|---|
|
||||||
|
| 模型版本 | GhostNet |
|
||||||
|
| 资源 | Ascend 910;CPU:2.60GHz,192核;内存:755G |
|
||||||
|
| 上传日期 |2021-06-22 ; |
|
||||||
|
| MindSpore版本 | 1.0.1 |
|
||||||
|
| 数据集 | ImageNet2012 |
|
||||||
|
| 训练参数 | epoch=500, steps per epoch=1251, batch_size = 128 |
|
||||||
|
| 优化器 | Momentum |
|
||||||
|
| 损失函数 |Softmax交叉熵 |
|
||||||
|
| 输出 | 概率 |
|
||||||
|
| 损失 | 1.7887309 |
|
||||||
|
|速度|223.92毫秒/步(8卡) |
|
||||||
|
|总时长 | 39小时 |
|
||||||
|
|参数(M) | 5.18 |
|
||||||
|
| 微调检查点 | 42.05M(.ckpt文件) |
|
||||||
|
| 脚本 | [链接](https://gitee.com/alreadyhad/mindspore/tree/master/model_zoo/research/cv/ghostnet) |
|
||||||
|
|
||||||
|
# 随机情况说明
|
||||||
|
|
||||||
|
dataset.py中设置了“create_dataset”函数内的种子,同时还使用了train.py中的随机种子。
|
||||||
|
|
||||||
|
# ModelZoo主页
|
||||||
|
|
||||||
|
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -1,145 +0,0 @@
|
||||||
# Contents
|
|
||||||
|
|
||||||
- [GhostNet Description](#ghostnet-description)
|
|
||||||
- [Model Architecture](#model-architecture)
|
|
||||||
- [Dataset](#dataset)
|
|
||||||
- [Environment Requirements](#environment-requirements)
|
|
||||||
- [Script Description](#script-description)
|
|
||||||
- [Script and Sample Code](#script-and-sample-code)
|
|
||||||
- [Training Process](#training-process)
|
|
||||||
- [Evaluation Process](#evaluation-process)
|
|
||||||
- [Evaluation](#evaluation)
|
|
||||||
- [Model Description](#model-description)
|
|
||||||
- [Performance](#performance)
|
|
||||||
- [Training Performance](#evaluation-performance)
|
|
||||||
- [Inference Performance](#evaluation-performance)
|
|
||||||
- [Description of Random Situation](#description-of-random-situation)
|
|
||||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
|
||||||
|
|
||||||
## [GhostNet Description](#contents)
|
|
||||||
|
|
||||||
The GhostNet architecture is based on an Ghost module structure which generate more features from cheap operations. Based on a set of intrinsic feature maps, a series of cheap operations are applied to generate many ghost feature maps that could fully reveal information underlying intrinsic features.
|
|
||||||
|
|
||||||
[Paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Han_GhostNet_More_Features_From_Cheap_Operations_CVPR_2020_paper.pdf): Kai Han, Yunhe Wang, Qi Tian, Jianyuan Guo, Chunjing Xu, Chang Xu. GhostNet: More Features from Cheap Operations. CVPR 2020.
|
|
||||||
|
|
||||||
## [Model architecture](#contents)
|
|
||||||
|
|
||||||
The overall network architecture of GhostNet is show below:
|
|
||||||
|
|
||||||
[Link](https://openaccess.thecvf.com/content_CVPR_2020/papers/Han_GhostNet_More_Features_From_Cheap_Operations_CVPR_2020_paper.pdf)
|
|
||||||
|
|
||||||
## [Dataset](#contents)
|
|
||||||
|
|
||||||
Dataset used: [Oxford-IIIT Pet](https://www.robots.ox.ac.uk/~vgg/data/pets/)
|
|
||||||
|
|
||||||
- Dataset size: 7049 colorful images in 1000 classes
|
|
||||||
- Train: 3680 images
|
|
||||||
- Test: 3369 images
|
|
||||||
- Data format: RGB images.
|
|
||||||
- Note: Data will be processed in src/dataset.py
|
|
||||||
|
|
||||||
## [Environment Requirements](#contents)
|
|
||||||
|
|
||||||
- Hardware(Ascend/GPU)
|
|
||||||
- Prepare hardware environment with Ascend or GPU.
|
|
||||||
- Framework
|
|
||||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
|
||||||
- For more information, please check the resources below:
|
|
||||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
|
|
||||||
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
|
|
||||||
|
|
||||||
## [Script description](#contents)
|
|
||||||
|
|
||||||
### [Script and sample code](#contents)
|
|
||||||
|
|
||||||
```python
|
|
||||||
├── GhostNet
|
|
||||||
├── Readme.md # descriptions about ghostnet # shell script for evaluation with CPU, GPU or Ascend
|
|
||||||
├── src
|
|
||||||
│ ├──config.py # parameter configuration
|
|
||||||
│ ├──dataset.py # creating dataset
|
|
||||||
│ ├──launch.py # start python script
|
|
||||||
│ ├──lr_generator.py # learning rate config
|
|
||||||
│ ├──ghostnet.py # GhostNet architecture
|
|
||||||
│ ├──ghostnet600.py # GhostNet-600M architecture
|
|
||||||
├── eval.py # evaluation script
|
|
||||||
├── mindspore_hub_conf.py # export model for hub
|
|
||||||
```
|
|
||||||
|
|
||||||
## [Training process](#contents)
|
|
||||||
|
|
||||||
To Be Done
|
|
||||||
|
|
||||||
## [Eval process](#contents)
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
After installing MindSpore via the official website, you can start evaluation as follows:
|
|
||||||
|
|
||||||
### Launch
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# infer example
|
|
||||||
|
|
||||||
Ascend: python eval.py --model [ghostnet/ghostnet-600] --dataset_path ~/Pets/test.mindrecord --platform Ascend --checkpoint_path [CHECKPOINT_PATH]
|
|
||||||
GPU: python eval.py --model [ghostnet/ghostnet-600] --dataset_path ~/Pets/test.mindrecord --platform GPU --checkpoint_path [CHECKPOINT_PATH]
|
|
||||||
```
|
|
||||||
|
|
||||||
> checkpoint can be produced in training process.
|
|
||||||
|
|
||||||
### Result
|
|
||||||
|
|
||||||
```bash
|
|
||||||
result: {'acc': 0.8113927500681385} ckpt= ./ghostnet_nose_1x_pets.ckpt
|
|
||||||
result: {'acc': 0.824475333878441} ckpt= ./ghostnet_1x_pets.ckpt
|
|
||||||
result: {'acc': 0.8691741618969746} ckpt= ./ghostnet600M_pets.ckpt
|
|
||||||
```
|
|
||||||
|
|
||||||
## [Model Description](#contents)
|
|
||||||
|
|
||||||
### [Performance](#contents)
|
|
||||||
|
|
||||||
#### Evaluation Performance
|
|
||||||
|
|
||||||
##### GhostNet on ImageNet2012
|
|
||||||
|
|
||||||
| Parameters | | |
|
|
||||||
| -------------------------- | -------------------------------------- |---------------------------------- |
|
|
||||||
| Model Version | GhostNet |GhostNet-600|
|
|
||||||
| uploaded Date | 09/08/2020 (month/day/year) ; | 09/08/2020 (month/day/year) |
|
|
||||||
| MindSpore Version | 0.6.0-alpha |0.6.0-alpha |
|
|
||||||
| Dataset | ImageNet2012 | ImageNet2012|
|
|
||||||
| Parameters (M) | 5.2 | 11.9 |
|
|
||||||
| FLOPs (M) | 142 | 591 |
|
|
||||||
| Accuracy (Top1) | 73.9 |80.2 |
|
|
||||||
|
|
||||||
###### GhostNet on Oxford-IIIT Pet
|
|
||||||
|
|
||||||
| Parameters | | |
|
|
||||||
| -------------------------- | -------------------------------------- |---------------------------------- |
|
|
||||||
| Model Version | GhostNet |GhostNet-600|
|
|
||||||
| uploaded Date | 09/08/2020 (month/day/year) ; | 09/08/2020 (month/day/year) |
|
|
||||||
| MindSpore Version | 0.6.0-alpha |0.6.0-alpha |
|
|
||||||
| Dataset | Oxford-IIIT Pet | Oxford-IIIT Pet|
|
|
||||||
| Parameters (M) | 3.9 | 10.6 |
|
|
||||||
| FLOPs (M) | 140 | 590 |
|
|
||||||
| Accuracy (Top1) | 82.4 |86.9 |
|
|
||||||
|
|
||||||
###### Comparison with other methods on Oxford-IIIT Pet
|
|
||||||
|
|
||||||
|Model|FLOPs (M)|Latency (ms)*|Accuracy (Top1)|
|
|
||||||
|-|-|-|-|
|
|
||||||
|MobileNetV2-1x|300|28.2|78.5|
|
|
||||||
|Ghost-1x w\o SE|138|19.1|81.1|
|
|
||||||
|Ghost-1x|140|25.3|82.4|
|
|
||||||
|Ghost-600|590|-|86.9|
|
|
||||||
|
|
||||||
*The latency is measured on Huawei Kirin 990 chip under single-threaded mode with batch size 1.
|
|
||||||
|
|
||||||
## [Description of Random Situation](#contents)
|
|
||||||
|
|
||||||
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
|
|
||||||
|
|
||||||
## [ModelZoo Homepage](#contents)
|
|
||||||
|
|
||||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -21,56 +21,25 @@ from mindspore import context
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.common import dtype as mstype
|
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.config import config_ascend, config_gpu
|
from src.ghostnet import ghostnet_1x
|
||||||
from src.ghostnet import ghostnet_1x, ghostnet_nose_1x
|
|
||||||
from src.ghostnet600 import ghostnet_600m
|
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Image classification')
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||||
parser.add_argument('--platform', type=str, default=None, help='run platform')
|
|
||||||
parser.add_argument('--model', type=str, default=None, help='ghostnet')
|
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
config_platform = None
|
|
||||||
if args_opt.platform == "Ascend":
|
|
||||||
config_platform = config_ascend
|
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
|
||||||
device_id=device_id, save_graphs=False)
|
device_id=device_id, save_graphs=False)
|
||||||
elif args_opt.platform == "GPU":
|
|
||||||
config_platform = config_gpu
|
|
||||||
context.set_context(mode=context.GRAPH_MODE,
|
|
||||||
device_target="GPU", save_graphs=False)
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported platform.")
|
|
||||||
|
|
||||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||||
|
|
||||||
if args_opt.model == 'ghostnet':
|
net = ghostnet_1x()
|
||||||
net = ghostnet_1x(num_classes=config_platform.num_classes)
|
|
||||||
elif args_opt.model == 'ghostnet_nose':
|
|
||||||
net = ghostnet_nose_1x(num_classes=config_platform.num_classes)
|
|
||||||
elif args_opt.model == 'ghostnet-600':
|
|
||||||
net = ghostnet_600m(num_classes=config_platform.num_classes)
|
|
||||||
|
|
||||||
if args_opt.platform == "Ascend":
|
dataset = create_dataset(dataset_path=args_opt.data_url, do_train=False)
|
||||||
net.to_float(mstype.float16)
|
|
||||||
for _, cell in net.cells_and_names():
|
|
||||||
if isinstance(cell, nn.Dense):
|
|
||||||
cell.to_float(mstype.float32)
|
|
||||||
|
|
||||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
|
||||||
do_train=False,
|
|
||||||
config=config_platform,
|
|
||||||
platform=args_opt.platform,
|
|
||||||
batch_size=config_platform.batch_size,
|
|
||||||
model=args_opt.model)
|
|
||||||
step_size = dataset.get_dataset_size()
|
step_size = dataset.get_dataset_size()
|
||||||
|
|
||||||
if args_opt.checkpoint_path:
|
if args_opt.checkpoint_path:
|
||||||
|
@ -78,6 +47,6 @@ if __name__ == '__main__':
|
||||||
load_param_into_net(net, param_dict)
|
load_param_into_net(net, param_dict)
|
||||||
net.set_train(False)
|
net.set_train(False)
|
||||||
|
|
||||||
model = Model(net, loss_fn=loss, metrics={'acc'})
|
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||||
res = model.eval(dataset)
|
res = model.eval(dataset)
|
||||||
print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" export MINDIR """
|
||||||
|
import argparse as arg
|
||||||
|
import numpy as np
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import context, Tensor, export, load_checkpoint
|
||||||
|
from src.ghostnet import ghostnet_1x
|
||||||
|
from src.config import config
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = arg.ArgumentParser(description='SID export')
|
||||||
|
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
|
||||||
|
help='device where the code will be implemented')
|
||||||
|
parser.add_argument('--device_id', type=int, default=0, help='device id')
|
||||||
|
parser.add_argument('--file_format', type=str, choices=['AIR', 'MINDIR'], default='MINDIR',
|
||||||
|
help='file format')
|
||||||
|
parser.add_argument('--checkpoint_path', required=True, default=None, help='ckpt file path')
|
||||||
|
args = parser.parse_args()
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
if args.device_target == 'Ascend':
|
||||||
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
|
ckpt_dir = args.checkpoint_path
|
||||||
|
net = ghostnet_1x(num_classes=config.num_classes)
|
||||||
|
load_checkpoint(ckpt_dir, net=net)
|
||||||
|
net.set_train(False)
|
||||||
|
|
||||||
|
input_data = Tensor(np.zeros([1, 3, 224, 224]), ms.float32)
|
||||||
|
export(net, input_data, file_name='ghost', file_format=args.file_format)
|
|
@ -1,27 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""hub config."""
|
|
||||||
from src.ghostnet import ghostnet_1x, ghostnet_nose_1x
|
|
||||||
from src.ghostnet600 import ghostnet_600m
|
|
||||||
|
|
||||||
|
|
||||||
def create_network(name, *args, **kwargs):
|
|
||||||
if name == 'ghostnet':
|
|
||||||
return ghostnet_1x(*args, **kwargs)
|
|
||||||
if name == 'ghostnet_nose':
|
|
||||||
return ghostnet_nose_1x(*args, **kwargs)
|
|
||||||
if name == 'ghostnet-600':
|
|
||||||
return ghostnet_600m(*args, **kwargs)
|
|
||||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH PRETRAINED_CKPT_PATH](optional)"
|
||||||
|
echo "For example: bash run_distribute_train.sh hccl_8p_01234567_127.0.0.1.json /path/dataset"
|
||||||
|
echo "It is better to use the absolute path."
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
PATH1=$(get_real_path $1)
|
||||||
|
PATH2=$(get_real_path $2)
|
||||||
|
|
||||||
|
if [ $# == 3 ]
|
||||||
|
then
|
||||||
|
PATH3=$(get_real_path $3)
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $PATH1 ]
|
||||||
|
then
|
||||||
|
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -d $PATH2 ]
|
||||||
|
then
|
||||||
|
echo "error: DATA_PATH=$PATH2 is not a directory"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $# == 3 ] && [ ! -f $PATH3 ]
|
||||||
|
then
|
||||||
|
echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_NUM=8
|
||||||
|
export RANK_SIZE=8
|
||||||
|
export RANK_TABLE_FILE=$PATH1
|
||||||
|
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
|
||||||
|
|
||||||
|
DATA_PATH=$2
|
||||||
|
export DATA_PATH=${DATA_PATH}
|
||||||
|
|
||||||
|
for((i=0;i<${RANK_SIZE};i++))
|
||||||
|
do
|
||||||
|
rm -rf device$i
|
||||||
|
mkdir device$i
|
||||||
|
cp ../*.py ./device$i
|
||||||
|
cp *.sh ./device$i
|
||||||
|
cp -r ../src ./device$i
|
||||||
|
cd ./device$i
|
||||||
|
export DEVICE_ID=$i
|
||||||
|
export RANK_ID=$i
|
||||||
|
echo "start training for device $i"
|
||||||
|
env > env$i.log
|
||||||
|
|
||||||
|
if [ $# == 2 ]
|
||||||
|
then
|
||||||
|
python train.py --run_distribute=True --data_url=$PATH2 &> train.log &
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $# == 3 ]
|
||||||
|
then
|
||||||
|
python train.py --run_distribute=True --data_url=$PATH2 --pre_trained=$PATH3 &> train.log &
|
||||||
|
fi
|
||||||
|
|
||||||
|
cd ../
|
||||||
|
done
|
|
@ -0,0 +1,64 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "bash run_eval.sh DATA_PATH CHECKPOINT_PATH "
|
||||||
|
echo "For example: bash run.sh /path/dataset ghostnet-500_1251.ckpt"
|
||||||
|
echo "It is better to use the absolute path."
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
PATH1=$(get_real_path $1)
|
||||||
|
PATH2=$(get_real_path $2)
|
||||||
|
|
||||||
|
if [ ! -d $PATH1 ]
|
||||||
|
then
|
||||||
|
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $PATH2 ]
|
||||||
|
then
|
||||||
|
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export DEVICE_ID=0
|
||||||
|
export RANK_SIZE=$DEVICE_NUM
|
||||||
|
export RANK_ID=0
|
||||||
|
|
||||||
|
if [ -d "eval" ];
|
||||||
|
then
|
||||||
|
rm -rf ./eval
|
||||||
|
fi
|
||||||
|
mkdir ./eval
|
||||||
|
cp ../*.py ./eval
|
||||||
|
cp *.sh ./eval
|
||||||
|
cp -r ../src ./eval
|
||||||
|
cd ./eval
|
||||||
|
env > env.log
|
||||||
|
echo "start evaluation for device $DEVICE_ID"
|
||||||
|
python eval.py --data_url=$PATH1 --checkpoint_path=$PATH2 &> eval.log &
|
||||||
|
cd ..
|
|
@ -0,0 +1,77 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "bash run_standalone_train.sh DATA_PATH PRETRAINED_CKPT_PATH(optional)"
|
||||||
|
echo "For example: bash run_standalone_train.sh /path/dataset"
|
||||||
|
echo "It is better to use the absolute path."
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
PATH1=$(get_real_path $1)
|
||||||
|
if [ $# == 2 ]
|
||||||
|
then
|
||||||
|
PATH2=$(get_real_path $2)
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -d $PATH1 ]
|
||||||
|
then
|
||||||
|
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $# == 2 ] && [ ! -f $PATH2 ]
|
||||||
|
then
|
||||||
|
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export DEVICE_ID=0
|
||||||
|
export RANK_SIZE=$DEVICE_NUM
|
||||||
|
export RANK_ID=0
|
||||||
|
|
||||||
|
if [ -d "train" ];
|
||||||
|
then
|
||||||
|
rm -rf ./train
|
||||||
|
fi
|
||||||
|
mkdir ./train
|
||||||
|
cp ../*.py ./train
|
||||||
|
cp *.sh ./train
|
||||||
|
cp -r ../src ./train
|
||||||
|
cd ./train
|
||||||
|
echo "start training for device $DEVICE_ID"
|
||||||
|
env > env.log
|
||||||
|
if [ $# == 1 ]
|
||||||
|
then
|
||||||
|
python train.py --data_url=$PATH1 &> train.log &
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $# == 2 ]
|
||||||
|
then
|
||||||
|
python train.py --data_url=$PATH1 --pre_trained=$PATH2 &> train.log &
|
||||||
|
fi
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""define loss function for network"""
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.nn.loss.loss import _Loss
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
class CrossEntropySmooth(_Loss):
|
||||||
|
"""CrossEntropy"""
|
||||||
|
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
|
||||||
|
super(CrossEntropySmooth, self).__init__()
|
||||||
|
self.onehot = P.OneHot()
|
||||||
|
self.sparse = sparse
|
||||||
|
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||||
|
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||||
|
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
|
||||||
|
|
||||||
|
def construct(self, logit, label):
|
||||||
|
if self.sparse:
|
||||||
|
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||||
|
loss = self.ce(logit, label)
|
||||||
|
return loss
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -17,38 +17,23 @@ network config setting, will be used in train.py and eval.py
|
||||||
"""
|
"""
|
||||||
from easydict import EasyDict as ed
|
from easydict import EasyDict as ed
|
||||||
|
|
||||||
config_ascend = ed({
|
config = ed({
|
||||||
"num_classes": 37,
|
"num_classes": 1000,
|
||||||
"image_height": 224,
|
"batch_size": 128,
|
||||||
"image_width": 224,
|
"epoch_size": 500,
|
||||||
"batch_size": 256,
|
"warmup_epochs": 20,
|
||||||
"epoch_size": 200,
|
"lr_init": 0.1,
|
||||||
"warmup_epochs": 4,
|
"lr_max": 0.4,
|
||||||
"lr": 0.4,
|
'lr_end': 1e-6,
|
||||||
|
'lr_decay_mode': 'cosine',
|
||||||
"momentum": 0.9,
|
"momentum": 0.9,
|
||||||
"weight_decay": 4e-5,
|
"weight_decay": 4e-5,
|
||||||
"label_smooth": 0.1,
|
"label_smooth": 0.1,
|
||||||
"loss_scale": 1024,
|
"loss_scale": 128,
|
||||||
|
"use_label_smooth": True,
|
||||||
|
"label_smooth_factor": 0.1,
|
||||||
"save_checkpoint": True,
|
"save_checkpoint": True,
|
||||||
"save_checkpoint_epochs": 1,
|
"save_checkpoint_epochs": 20,
|
||||||
"keep_checkpoint_max": 200,
|
"keep_checkpoint_max": 10,
|
||||||
"save_checkpoint_path": "./checkpoint",
|
"save_checkpoint_path": "./",
|
||||||
})
|
|
||||||
|
|
||||||
config_gpu = ed({
|
|
||||||
"num_classes": 37,
|
|
||||||
"image_height": 224,
|
|
||||||
"image_width": 224,
|
|
||||||
"batch_size": 3,
|
|
||||||
"epoch_size": 370,
|
|
||||||
"warmup_epochs": 4,
|
|
||||||
"lr": 0.4,
|
|
||||||
"momentum": 0.9,
|
|
||||||
"weight_decay": 4e-5,
|
|
||||||
"label_smooth": 0.1,
|
|
||||||
"loss_scale": 1024,
|
|
||||||
"save_checkpoint": True,
|
|
||||||
"save_checkpoint_epochs": 1,
|
|
||||||
"keep_checkpoint_max": 500,
|
|
||||||
"save_checkpoint_path": "./checkpoint",
|
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -12,99 +12,83 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""
|
"""Data operations, will be used in train.py and eval.py"""
|
||||||
create train or eval dataset.
|
|
||||||
"""
|
|
||||||
import os
|
import os
|
||||||
|
from src.config import config
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset.engine as de
|
||||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
|
||||||
import mindspore.dataset.transforms.vision.py_transforms as P
|
|
||||||
import mindspore.dataset.transforms.c_transforms as C2
|
import mindspore.dataset.transforms.c_transforms as C2
|
||||||
from mindspore.dataset.transforms.vision import Inter
|
import mindspore.dataset.vision.c_transforms as C
|
||||||
|
from mindspore.communication.management import get_rank, get_group_size
|
||||||
|
|
||||||
|
|
||||||
def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=100, model='ghsotnet'):
|
def create_dataset(dataset_path, do_train, target="Ascend"):
|
||||||
"""
|
"""
|
||||||
create a train or eval dataset
|
create a train or eval dataset
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_path(string): the path of dataset.
|
dataset_path(string): the path of dataset.
|
||||||
do_train(bool): whether dataset is used for train or eval.
|
do_train(bool): whether dataset is used for train or eval.
|
||||||
repeat_num(int): the repeat times of dataset. Default: 1
|
rank (int): The shard ID within num_shards (default=None).
|
||||||
batch_size(int): the batch size of dataset. Default: 32
|
group_size (int): Number of shards that the dataset should be divided into (default=None).
|
||||||
|
repeat_num(int): the repeat times of dataset. Default: 1.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dataset
|
dataset
|
||||||
"""
|
"""
|
||||||
if platform == "Ascend":
|
if not do_train:
|
||||||
rank_size = int(os.getenv("RANK_SIZE"))
|
dataset_path = os.path.join(dataset_path, 'val')
|
||||||
rank_id = int(os.getenv("RANK_ID"))
|
|
||||||
if rank_size == 1:
|
|
||||||
data_set = ds.MindDataset(
|
|
||||||
dataset_path, num_parallel_workers=8, shuffle=True)
|
|
||||||
else:
|
else:
|
||||||
data_set = ds.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
dataset_path = os.path.join(dataset_path, 'train')
|
||||||
num_shards=rank_size, shard_id=rank_id)
|
if target == "Ascend":
|
||||||
elif platform == "GPU":
|
device_num, rank_id = _get_rank_info()
|
||||||
if do_train:
|
|
||||||
from mindspore.communication.management import get_rank, get_group_size
|
|
||||||
data_set = ds.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
|
||||||
num_shards=get_group_size(), shard_id=get_rank())
|
|
||||||
else:
|
|
||||||
data_set = ds.MindDataset(
|
|
||||||
dataset_path, num_parallel_workers=8, shuffle=True)
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported platform.")
|
|
||||||
|
|
||||||
resize_height = config.image_height
|
if device_num == 1:
|
||||||
buffer_size = 1000
|
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||||
|
else:
|
||||||
|
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||||
|
num_shards=device_num, shard_id=rank_id)
|
||||||
|
|
||||||
|
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||||
|
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
||||||
# define map operations
|
# define map operations
|
||||||
resize_crop_op = C.RandomCropDecodeResize(
|
|
||||||
resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
|
|
||||||
horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
|
|
||||||
|
|
||||||
color_op = C.RandomColorAdjust(
|
|
||||||
brightness=0.4, contrast=0.4, saturation=0.4)
|
|
||||||
rescale_op = C.Rescale(1 / 255.0, 0)
|
|
||||||
normalize_op = C.Normalize(
|
|
||||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
||||||
change_swap_op = C.HWC2CHW()
|
|
||||||
|
|
||||||
# define python operations
|
|
||||||
decode_p = P.Decode()
|
|
||||||
if model == 'ghostnet-600':
|
|
||||||
s = 274
|
|
||||||
c = 240
|
|
||||||
else:
|
|
||||||
s = 256
|
|
||||||
c = 224
|
|
||||||
resize_p = P.Resize(s, interpolation=Inter.BICUBIC)
|
|
||||||
center_crop_p = P.CenterCrop(c)
|
|
||||||
totensor = P.ToTensor()
|
|
||||||
normalize_p = P.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
|
||||||
composeop = P.ComposeOp(
|
|
||||||
[decode_p, resize_p, center_crop_p, totensor, normalize_p])
|
|
||||||
if do_train:
|
if do_train:
|
||||||
trans = [resize_crop_op, horizontal_flip_op, color_op,
|
trans = [
|
||||||
rescale_op, normalize_op, change_swap_op]
|
C.RandomCropDecodeResize(224),
|
||||||
|
C.RandomHorizontalFlip(prob=0.5),
|
||||||
|
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
trans = composeop()
|
trans = [
|
||||||
|
C.Decode(),
|
||||||
|
C.Resize(256),
|
||||||
|
C.CenterCrop(224),
|
||||||
|
]
|
||||||
|
trans += [
|
||||||
|
C.Normalize(mean=mean, std=std),
|
||||||
|
C.HWC2CHW(),
|
||||||
|
]
|
||||||
|
|
||||||
type_cast_op = C2.TypeCast(mstype.int32)
|
type_cast_op = C2.TypeCast(mstype.int32)
|
||||||
|
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
|
||||||
data_set = data_set.map(input_columns="image", operations=trans,
|
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
|
||||||
num_parallel_workers=8)
|
|
||||||
data_set = data_set.map(input_columns="label_list",
|
|
||||||
operations=type_cast_op, num_parallel_workers=8)
|
|
||||||
|
|
||||||
# apply shuffle operations
|
|
||||||
data_set = data_set.shuffle(buffer_size=buffer_size)
|
|
||||||
|
|
||||||
# apply batch operations
|
# apply batch operations
|
||||||
data_set = data_set.batch(batch_size, drop_remainder=True)
|
ds = ds.batch(config.batch_size, drop_remainder=True)
|
||||||
|
return ds
|
||||||
|
|
||||||
# apply dataset repeat operation
|
|
||||||
data_set = data_set.repeat(repeat_num)
|
|
||||||
|
|
||||||
return data_set
|
def _get_rank_info():
|
||||||
|
"""
|
||||||
|
get rank size and rank id
|
||||||
|
"""
|
||||||
|
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||||
|
|
||||||
|
if rank_size > 1:
|
||||||
|
rank_size = get_group_size()
|
||||||
|
rank_id = get_rank()
|
||||||
|
else:
|
||||||
|
rank_size = 1
|
||||||
|
rank_id = 0
|
||||||
|
|
||||||
|
return rank_size, rank_id
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -46,6 +46,7 @@ class MyHSigmoid(nn.Cell):
|
||||||
self.relu6 = nn.ReLU6()
|
self.relu6 = nn.ReLU6()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
|
""" construct """
|
||||||
return self.relu6(x + 3.) * 0.16666667
|
return self.relu6(x + 3.) * 0.16666667
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,6 +75,7 @@ class Activation(nn.Cell):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
|
""" construct """
|
||||||
return self.act(x)
|
return self.act(x)
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,6 +97,7 @@ class GlobalAvgPooling(nn.Cell):
|
||||||
self.mean = P.ReduceMean(keep_dims=keep_dims)
|
self.mean = P.ReduceMean(keep_dims=keep_dims)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
|
""" construct """
|
||||||
x = self.mean(x, (2, 3))
|
x = self.mean(x, (2, 3))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -127,6 +130,7 @@ class SE(nn.Cell):
|
||||||
self.mul = P.Mul()
|
self.mul = P.Mul()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
|
""" construct of SE module """
|
||||||
out = self.pool(x)
|
out = self.pool(x)
|
||||||
out = self.conv_reduce(out)
|
out = self.conv_reduce(out)
|
||||||
out = self.act1(out)
|
out = self.act1(out)
|
||||||
|
@ -173,6 +177,7 @@ class ConvUnit(nn.Cell):
|
||||||
self.act = Activation(act_type) if use_act else None
|
self.act = Activation(act_type) if use_act else None
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
|
""" construct of conv unit """
|
||||||
out = self.conv(x)
|
out = self.conv(x)
|
||||||
out = self.bn(out)
|
out = self.bn(out)
|
||||||
if self.use_act:
|
if self.use_act:
|
||||||
|
@ -209,12 +214,14 @@ class GhostModule(nn.Cell):
|
||||||
new_channels = init_channels * (ratio - 1)
|
new_channels = init_channels * (ratio - 1)
|
||||||
|
|
||||||
self.primary_conv = ConvUnit(num_in, init_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
self.primary_conv = ConvUnit(num_in, init_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
num_groups=1, use_act=use_act, act_type='relu')
|
num_groups=1, use_act=use_act, act_type=act_type)
|
||||||
self.cheap_operation = ConvUnit(init_channels, new_channels, kernel_size=dw_size, stride=1, padding=dw_size//2,
|
self.cheap_operation = ConvUnit(init_channels, new_channels, kernel_size=dw_size, stride=1,
|
||||||
num_groups=init_channels, use_act=use_act, act_type='relu')
|
padding=dw_size // 2, num_groups=init_channels,
|
||||||
|
use_act=use_act, act_type=act_type)
|
||||||
self.concat = P.Concat(axis=1)
|
self.concat = P.Concat(axis=1)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
|
""" ghost module construct """
|
||||||
x1 = self.primary_conv(x)
|
x1 = self.primary_conv(x)
|
||||||
x2 = self.cheap_operation(x1)
|
x2 = self.cheap_operation(x1)
|
||||||
return self.concat((x1, x2))
|
return self.concat((x1, x2))
|
||||||
|
@ -269,10 +276,10 @@ class GhostBottleneck(nn.Cell):
|
||||||
ConvUnit(num_in, num_out, kernel_size=1, stride=1,
|
ConvUnit(num_in, num_out, kernel_size=1, stride=1,
|
||||||
padding=0, num_groups=1, use_act=False),
|
padding=0, num_groups=1, use_act=False),
|
||||||
])
|
])
|
||||||
self.add = P.Add()
|
self.add = P.TensorAdd()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
r"""construct of ghostnet"""
|
""" construct of ghostnet """
|
||||||
shortcut = x
|
shortcut = x
|
||||||
out = self.ghost1(x)
|
out = self.ghost1(x)
|
||||||
if self.use_dw:
|
if self.use_dw:
|
||||||
|
@ -318,7 +325,7 @@ class GhostNet(nn.Cell):
|
||||||
>>> GhostNet(num_classes=1000)
|
>>> GhostNet(num_classes=1000)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8):
|
def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0.):
|
||||||
super(GhostNet, self).__init__()
|
super(GhostNet, self).__init__()
|
||||||
self.cfgs = model_cfgs['cfg']
|
self.cfgs = model_cfgs['cfg']
|
||||||
self.inplanes = 16
|
self.inplanes = 16
|
||||||
|
@ -365,7 +372,7 @@ class GhostNet(nn.Cell):
|
||||||
self._initialize_weights()
|
self._initialize_weights()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
r"""construct of GhostNet"""
|
""" construct of GhostNet """
|
||||||
x = self.conv_stem(x)
|
x = self.conv_stem(x)
|
||||||
x = self.bn1(x)
|
x = self.bn1(x)
|
||||||
x = self.act1(x)
|
x = self.act1(x)
|
||||||
|
@ -403,21 +410,21 @@ class GhostNet(nn.Cell):
|
||||||
for _, m in self.cells_and_names():
|
for _, m in self.cells_and_names():
|
||||||
if isinstance(m, (nn.Conv2d)):
|
if isinstance(m, (nn.Conv2d)):
|
||||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||||
m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
|
m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
|
||||||
m.weight.data.shape).astype("float32")))
|
m.weight.data.shape).astype("float32")))
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
m.bias.set_parameter_data(
|
m.bias.set_data(
|
||||||
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
|
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
|
||||||
elif isinstance(m, nn.BatchNorm2d):
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
m.gamma.set_parameter_data(
|
m.gamma.set_data(
|
||||||
Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
|
Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
|
||||||
m.beta.set_parameter_data(
|
m.beta.set_data(
|
||||||
Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
|
Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
|
||||||
elif isinstance(m, nn.Dense):
|
elif isinstance(m, nn.Dense):
|
||||||
m.weight.set_parameter_data(Tensor(np.random.normal(
|
m.weight.set_data(Tensor(np.random.normal(
|
||||||
0, 0.01, m.weight.data.shape).astype("float32")))
|
0, 0.01, m.weight.data.shape).astype("float32")))
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
m.bias.set_parameter_data(
|
m.bias.set_data(
|
||||||
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
|
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -16,8 +16,7 @@
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
|
||||||
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
|
|
||||||
"""
|
"""
|
||||||
generate learning rate array
|
generate learning rate array
|
||||||
|
|
||||||
|
@ -47,9 +46,6 @@ def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, st
|
||||||
if lr < 0.0:
|
if lr < 0.0:
|
||||||
lr = 0.0
|
lr = 0.0
|
||||||
lr_each_step.append(lr)
|
lr_each_step.append(lr)
|
||||||
|
|
||||||
current_step = global_step
|
|
||||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||||
learning_rate = lr_each_step[current_step:]
|
|
||||||
|
|
||||||
return learning_rate
|
return lr_each_step
|
||||||
|
|
|
@ -0,0 +1,141 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
train.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
|
import mindspore.common.initializer as weight_init
|
||||||
|
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
from mindspore.nn.optim.momentum import Momentum
|
||||||
|
from mindspore.communication.management import init, get_rank
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from src.lr_generator import get_lr
|
||||||
|
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||||
|
from src.dataset import create_dataset
|
||||||
|
from src.config import config
|
||||||
|
from src.ghostnet import ghostnet_1x
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Image classification--GhostNet')
|
||||||
|
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||||
|
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
|
||||||
|
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
|
||||||
|
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
||||||
|
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# init context
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
|
||||||
|
save_graphs=False)
|
||||||
|
|
||||||
|
if args_opt.run_distribute:
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||||
|
print(rank_size)
|
||||||
|
device_num = rank_size
|
||||||
|
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
|
||||||
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
|
gradients_mean=True)
|
||||||
|
init()
|
||||||
|
args_opt.rank = get_rank()
|
||||||
|
|
||||||
|
# select for master rank save ckpt or all rank save, compatible for model parallel
|
||||||
|
args_opt.rank_save_ckpt_flag = 0
|
||||||
|
if args_opt.is_save_on_master:
|
||||||
|
if args_opt.rank == 0:
|
||||||
|
args_opt.rank_save_ckpt_flag = 1
|
||||||
|
else:
|
||||||
|
args_opt.rank_save_ckpt_flag = 1
|
||||||
|
|
||||||
|
# define net
|
||||||
|
net = ghostnet_1x(num_classes=config.num_classes)
|
||||||
|
net.to_float(mstype.float16)
|
||||||
|
for _, cell in net.cells_and_names():
|
||||||
|
if isinstance(cell, nn.Dense):
|
||||||
|
cell.to_float(mstype.float32)
|
||||||
|
|
||||||
|
local_data_path = args_opt.data_url
|
||||||
|
print('Download data:')
|
||||||
|
dataset = create_dataset(dataset_path=local_data_path,
|
||||||
|
do_train=True,
|
||||||
|
target="Ascend")
|
||||||
|
|
||||||
|
step_size = dataset.get_dataset_size()
|
||||||
|
print('steps:', step_size)
|
||||||
|
|
||||||
|
# init weight
|
||||||
|
if args_opt.pre_trained:
|
||||||
|
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
else:
|
||||||
|
for _, cell in net.cells_and_names():
|
||||||
|
if isinstance(cell, nn.Conv2d):
|
||||||
|
cell.weight.set_data(weight_init.initializer(weight_init.HeUniform(),
|
||||||
|
cell.weight.shape,
|
||||||
|
cell.weight.dtype))
|
||||||
|
if isinstance(cell, nn.Dense):
|
||||||
|
cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(),
|
||||||
|
cell.weight.shape,
|
||||||
|
cell.weight.dtype))
|
||||||
|
|
||||||
|
# init lr
|
||||||
|
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end,
|
||||||
|
lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
|
||||||
|
total_epochs=config.epoch_size, steps_per_epoch=step_size)
|
||||||
|
lr = Tensor(lr)
|
||||||
|
|
||||||
|
if not config.use_label_smooth:
|
||||||
|
config.label_smooth_factor = 0.0
|
||||||
|
|
||||||
|
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||||
|
smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)
|
||||||
|
|
||||||
|
opt = Momentum(net.trainable_params(), lr, config.momentum, loss_scale=config.loss_scale,
|
||||||
|
weight_decay=config.weight_decay)
|
||||||
|
|
||||||
|
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||||
|
|
||||||
|
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
|
||||||
|
metrics={'top_1_accuracy', 'top_5_accuracy'},
|
||||||
|
amp_level="O3", keep_batchnorm_fp32=False)
|
||||||
|
|
||||||
|
# define callbacks
|
||||||
|
time_cb = TimeMonitor(data_size=step_size)
|
||||||
|
loss_cb = LossMonitor()
|
||||||
|
cb = [time_cb, loss_cb]
|
||||||
|
if config.save_checkpoint:
|
||||||
|
if args_opt.rank_save_ckpt_flag:
|
||||||
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
|
||||||
|
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||||
|
ckpt_cb = ModelCheckpoint(prefix="ghostnet", directory=config.save_checkpoint_path, config=config_ck)
|
||||||
|
cb += [ckpt_cb]
|
||||||
|
|
||||||
|
# train model
|
||||||
|
model.train(config.epoch_size, dataset, callbacks=cb,
|
||||||
|
sink_size=dataset.get_dataset_size())
|
Loading…
Reference in New Issue