forked from mindspore-Ecosystem/mindspore
add DPN implementation
This commit is contained in:
parent
ef0b483eb4
commit
d3bf1c3280
|
@ -0,0 +1,370 @@
|
|||
# Contents
|
||||
|
||||
- [Description](#description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Features](#features)
|
||||
- [Mixed Precision](#mixed-precision)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Dataset Preparation](#dataset-preparation)
|
||||
- [Running](#running)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Running on Ascend](#running-on-ascend)
|
||||
- [Distributed Training](#distributed-training)
|
||||
- [Running on Ascend](#running-on-ascend-1)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Running on Ascend](#running-on-ascend-2)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Accuracy](#accuracy)
|
||||
- [DPN92 (Pretrain)](#dpn92-pretrain)
|
||||
- [DPN98 (Pretrain)](#dpn98-pretrain)
|
||||
- [DPN131 (Pretrain)](#dpn131-pretrain)
|
||||
- [DPN92 (Fine tune)](#dpn92-fine-tune)
|
||||
- [DPN92 (Training)](#dpn92-training)
|
||||
- [Efficiency](#efficiency)
|
||||
- [DPN92](#dpn92)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [Description](#contents)
|
||||
|
||||
Dual Path Network (DPN) is a convolution-based neural network for the task of image classification. It combines the advantage of both ResNeXt and DenseNet to get higher accuracy. More detail about this model can be found in:
|
||||
|
||||
Yunpeng Chen, Jianan Li, Huaxin Xiao, Xiaojie Jin, Shuicheng Yan, Jiashi Feng. "Dual Path Networks" (NIPS17).
|
||||
|
||||
This repository contains a Mindspore implementation of DPNs based upon cypw's original MXNet implementation (<https://github.com/cypw/DPNs>). The training and validating scripts are also included, and the validation results with cypw’s pretrained weights are shown in the Results section.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
The overall network architecture of DPN is show below:
|
||||
|
||||
[Link](https://arxiv.org/pdf/1707.01629.pdf)
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
All the models in this repository are trained and validated on ImageNet-1K. The models can achieve the [results](#model-description) with the configurations of the dataset preprocessing as follow:
|
||||
|
||||
- For the training dataset:
|
||||
|
||||
- Range (min, max) of the respective size of the original size to be cropped is (0.08, 1.0)
|
||||
|
||||
- Range (min, max) of aspect ratio to be cropped is (0.75, 1.333)
|
||||
- The size of input images is reshaped to (width = 224, height = 224)
|
||||
- Probability of random horizontal flip is 50%
|
||||
- In normalization, the mean is (255\*0.485, 255\*0.456, 255\*0.406) and the standard deviation is (255\*0.229, 255\*0.224, 255\*0.225)
|
||||
|
||||
- For the evaluation dataset:
|
||||
- Input size of images is 224\*224 (Resize to 256\*256 then crops images at the center)
|
||||
- In normalization, the mean is (255\*0.485, 255\*0.456, 255\*0.406) and the standard deviation is (255\*0.229, 255\*0.224, 255\*0.225)
|
||||
|
||||
# [Features](#contents)
|
||||
|
||||
## [Mixed Precision](#contents)
|
||||
|
||||
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’.
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
To run the python scripts in the repository, you need to prepare the environment as follow:
|
||||
|
||||
- Hardware
|
||||
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to [ascend@huawei.com](mailto:ascend@huawei.com). Once approved, you can get the resources.
|
||||
- Python and dependencies
|
||||
- Python3.7
|
||||
- Mindspore 1.0.0
|
||||
- Easydict
|
||||
- MXNet 1.6.0 if running the script `param_convert.py`
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||
|
||||
# [Quick Start](#contents)
|
||||
|
||||
## [Dataset Preparation](#contents)
|
||||
|
||||
The DPN models use ImageNet-1K dataset to train and validate in this repository. Download the dataset from [ImageNet.org](http://image-net.org/download). You can place them anywhere and tell the scripts where they are when running.
|
||||
|
||||
## [Running](#contents)
|
||||
|
||||
To train the DPNs, run the shell script `scripts/train_standalone.sh` with the format below:
|
||||
|
||||
```shell
|
||||
sh scripts/train_standalone.sh [device_id] [dataset_dir] [ckpt_path_to_save] [eval_each_epoch] [pretrained_ckpt(optional)]
|
||||
```
|
||||
|
||||
To validate the DPNs, run the shell script `scripts/eval.sh` with the format below:
|
||||
|
||||
```shell
|
||||
sh scripts/eval.sh [device_id] [dataset_dir] [pretrained_ckpt]
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
The structure of the files in this repository is shown below.
|
||||
|
||||
```text
|
||||
└─ mindspore-dpns
|
||||
├─ scripts
|
||||
│ ├─ eval.sh // launch ascend standalone evaluation
|
||||
│ ├─ train_distributed.sh // launch ascend distributed training
|
||||
│ └─ train_standalone.sh // launch ascend standalone training
|
||||
├─ src
|
||||
│ ├─ config.py // network and running config
|
||||
│ ├─ crossentropy.py // loss function
|
||||
│ ├─ dpn.py // dpns implementation
|
||||
│ ├─ imagenet_dataset.py // dataset processor and provider
|
||||
│ └─ lr_scheduler.py // dpn learning rate scheduler
|
||||
├─ eval.py // evaluation script
|
||||
├─ train.py // training script
|
||||
└─ README.md // descriptions about this repository
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in `src/config.py`
|
||||
|
||||
- Configurations for DPN92 with ImageNet-1K dataset
|
||||
|
||||
```python
|
||||
# model config
|
||||
config.image_size = (224,224) # inpute image size
|
||||
config.num_classes = 1000 # dataset class number
|
||||
config.backbone = 'dpn92' # backbone network
|
||||
config.is_save_on_master = True
|
||||
|
||||
# parallel config
|
||||
config.num_parallel_workers = 4 # number of workers to read the data
|
||||
config.rank = 0 # local rank of distributed
|
||||
config.group_size = 1 # group size of distributed
|
||||
|
||||
# training config
|
||||
config.batch_size = 32 # batch_size
|
||||
config.global_step = 0 # start step of learning rate
|
||||
config.epoch_size = 180 # epoch_size
|
||||
config.loss_scale_num = 1024 # loss scale
|
||||
# optimizer config
|
||||
config.momentum = 0.9 # momentum (SGD)
|
||||
config.weight_decay = 1e-4 # weight_decay (SGD)
|
||||
# learning rate config
|
||||
config.lr_schedule = 'warmup' # learning rate schedule
|
||||
config.lr_init = 0.01 # init learning rate
|
||||
config.lr_max = 0.1 # max learning rate
|
||||
config.factor = 0.1 # factor of lr to drop
|
||||
config.epoch_number_to_drop = [5,15] # learing rate will drop after these epochs
|
||||
config.warmup_epochs = 5 # warmup epochs in learning rate schedule
|
||||
|
||||
# dataset config
|
||||
config.dataset = "imagenet-1K" # dataset
|
||||
config.label_smooth = False # label_smooth
|
||||
config.label_smooth_factor = 0.0 # label_smooth_factor
|
||||
|
||||
# parameter save config
|
||||
config.keep_checkpoint_max = 3 # only keep the last keep_checkpoint_max checkpoint
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### [Training](#contents)
|
||||
|
||||
#### Running on Ascend
|
||||
|
||||
Run `scripts/train_standalone.sh` to train the model standalone. The usage of the script is:
|
||||
|
||||
```shell
|
||||
sh scripts/train_standalone.sh [device_id] [dataset_dir] [ckpt_path_to_save] [eval_each_epoch] [pretrained_ckpt(optional)]
|
||||
```
|
||||
|
||||
For example, you can run the shell command below to launch the training procedure.
|
||||
|
||||
```shell
|
||||
sh scripts/train_standalone.sh 0 /data/dataset/imagenet/ scripts/pretrian/ 0
|
||||
```
|
||||
|
||||
If eval_each_epoch is 1, it will evaluate after each epoch and save the parameters with the max accurracy. But in this case, the time of one epoch will be longer.
|
||||
|
||||
If eval_each_epoch is 0, it will save parameters every some epochs instead of evaluating in the training process.
|
||||
|
||||
The script will run training in the background, you can view the results through the file `train_log.txt` as follows (eval_each_epoch = 0):
|
||||
|
||||
```text
|
||||
epoch: 1 step: 40036, loss is 3.6232593
|
||||
Epoch time: 10048893.336, per step time: 250.996
|
||||
...
|
||||
```
|
||||
|
||||
or as follows (eval_each_epoch = 1):
|
||||
|
||||
```text
|
||||
epoch: 1 step: 40036, loss is 3.6232593
|
||||
Epoch time: 10048893.336, per step time: 250.996
|
||||
Save the maximum accuracy checkpoint,the accuracy is 0.2629158669225848
|
||||
...
|
||||
```
|
||||
|
||||
The model checkpoint will be saved into `[ckpt_path_to_save]`.
|
||||
|
||||
### [Distributed Training](#contents)
|
||||
|
||||
#### Running on Ascend
|
||||
|
||||
Run `scripts/train_distributed.sh` to train the model distributed. The usage of the script is:
|
||||
|
||||
```text
|
||||
sh scripts/train_distributed.sh [rank_table] [dataset_dir] [ckpt_path_to_save] [rank_size] [eval_each_epoch] [pretrained_ckpt(optional)]
|
||||
```
|
||||
|
||||
For example, you can run the shell command below to launch the training procedure.
|
||||
|
||||
```shell
|
||||
sh scripts/train_distributed.sh /home/rank_table.json /data/dataset/imagenet/ ../scripts 8 0 ../pretrain/dpn92.ckpt
|
||||
```
|
||||
|
||||
The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log.txt` as follows:
|
||||
|
||||
```text
|
||||
train_parallel0/log:
|
||||
epoch: 1 step 20018, loss is 5.74429988861084
|
||||
Epoch time: 7908183.789, per step time: 395.054, avg loss: 5.744
|
||||
train_parallel0/log:
|
||||
epoch: 2 step 20018, loss is 4.53381872177124
|
||||
Epoch time: 5036189.547, per step time: 251.583, avg loss: 4.534
|
||||
...
|
||||
train_parallel1/log:
|
||||
poch: 1 step 20018, loss is 5.751555442810059
|
||||
Epoch time: 7895946.079, per step time: 394.442, avg loss: 5.752
|
||||
train_parallel1/log:
|
||||
epoch: 2 step 20018, loss is 4.875896453857422
|
||||
Epoch time: 5036190.008, per step time: 251.583, avg loss: 4.876
|
||||
...
|
||||
...
|
||||
```
|
||||
|
||||
The model checkpoint will be saved into `[ckpt_path_to_save]`.
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### [Running on Ascend](#contents)
|
||||
|
||||
Run `scripts/eval.sh` to evaluate the model with one Ascend processor. The usage of the script is:
|
||||
|
||||
```text
|
||||
sh scripts/eval.sh [device_id] [dataset_dir] [pretrained_ckpt]
|
||||
```
|
||||
|
||||
For example, you can run the shell command below to launch the validation procedure.
|
||||
|
||||
```text
|
||||
sh scripts/eval.sh 0 /data/dataset/imageNet/ pretrain/dpn92.ckpt
|
||||
```
|
||||
|
||||
The above shell script will run evaluation in the background. You can view the results through the file `eval_log.txt`. The result will be achieved as follows:
|
||||
|
||||
```text
|
||||
Evaluation result: {'top_5_accuracy': 0.9449223751600512, 'top_1_accuracy': 0.7911731754161332}.
|
||||
DPN evaluate success!
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
The evaluation of model performance is divided into two parts: accuracy and efficiency. The part of accuracy shows the accuracy of the model in classifying images on ImageNet-1K dataset, and it can be evaluated by top-k measure. The part of efficiency reveals the time cost by model training on ImageNet-1K.
|
||||
|
||||
All results are validated at image size of 224x224. The dataset preprocessing and training configurations are shown in [Dataset](#dataset) section.
|
||||
|
||||
### [Accuracy](#contents)
|
||||
|
||||
The `Pretrain` tag in the table above means that the model's weights are converted from MXNet directly without further training. Relatively, the `Fine tune` tag means that the model is fine tuned after converted from MXNet.
|
||||
|
||||
#### DPN92 (Pretrain)
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ----------------- | --------------------------- |
|
||||
| Model Version | DPN92 (Pretrain) |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 09/19/2020 (month/day/year) |
|
||||
| MindSpore Version | 0.5.0 |
|
||||
| Dataset | ImageNet-1K |
|
||||
| outputs | probability |
|
||||
| train performance | Top1:79.12%; Top5:94.49% |
|
||||
|
||||
#### DPN98 (Pretrain)
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ----------------- | --------------------------- |
|
||||
| Model Version | DPN98 (Pretrain) |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 09/19/2020 (month/day/year) |
|
||||
| MindSpore Version | 0.5.0 |
|
||||
| Dataset | ImageNet-1K |
|
||||
| outputs | probability |
|
||||
| train performance | Top1:79.90%; Top5:94.81% |
|
||||
|
||||
#### DPN131 (Pretrain)
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ----------------- | --------------------------- |
|
||||
| Model Version | DPN131 (Pretrain) |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 09/19/2020 (month/day/year) |
|
||||
| MindSpore Version | 0.5.0 |
|
||||
| Dataset | ImageNet-1K |
|
||||
| outputs | probability |
|
||||
| train performance | Top1:79.96%; Top5:94.81% |
|
||||
|
||||
#### DPN92 (Fine tune)
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ----------------- | --------------------------- |
|
||||
| Model Version | DPN92 (Pretrain) |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 09/19/2020 (month/day/year) |
|
||||
| MindSpore Version | 0.5.0 |
|
||||
| Dataset | ImageNet-1K |
|
||||
| epochs | 30 |
|
||||
| outputs | probability |
|
||||
| train performance | Top1:79.30%; Top5:94.58% |
|
||||
|
||||
#### DPN92 (Training)
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ----------------- | --------------------------- |
|
||||
| Model Version | DPN92 (Train) |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 11/13/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | ImageNet-1K |
|
||||
| epochs | 180 |
|
||||
| outputs | probability |
|
||||
| train performance | Top1:78.91%; Top5:94.53% |
|
||||
|
||||
### [Efficiency](#contents)
|
||||
|
||||
#### DPN92
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ----------------- | --------------------------------- |
|
||||
| Model Version | DPN92 |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 09/19/2020 (month/day/year) |
|
||||
| MindSpore Version | 0.5.0 |
|
||||
| Dataset | ImageNet-1K |
|
||||
| batch_size | 32 |
|
||||
| outputs | probability |
|
||||
| speed | 1pc:127.90 img/s;8pc:1023.2 img/s |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""DPN model eval with MindSpore"""
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.dpn import dpns
|
||||
from src.config import config
|
||||
from src.imagenet_dataset import classification_dataset
|
||||
set_seed(1)
|
||||
# set context
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""parameters"""
|
||||
parser = argparse.ArgumentParser('dpn evaluating')
|
||||
# dataset related
|
||||
parser.add_argument('--data_dir', type=str, default='', help='eval data dir')
|
||||
# network related
|
||||
parser.add_argument('--pretrained', type=str, default='', help='ckpt path to load')
|
||||
args, _ = parser.parse_known_args()
|
||||
args.image_size = config.image_size
|
||||
args.num_classes = config.num_classes
|
||||
args.batch_size = config.batch_size
|
||||
args.num_parallel_workers = config.num_parallel_workers
|
||||
args.backbone = config.backbone
|
||||
args.loss_scale_num = config.loss_scale_num
|
||||
args.rank = config.rank
|
||||
args.group_size = config.group_size
|
||||
args.dataset = config.dataset
|
||||
return args
|
||||
|
||||
|
||||
def dpn_evaluate(args):
|
||||
# create evaluate dataset
|
||||
eval_path = os.path.join(args.data_dir, 'val')
|
||||
eval_dataset = classification_dataset(eval_path,
|
||||
image_size=args.image_size,
|
||||
num_parallel_workers=args.num_parallel_workers,
|
||||
per_batch_size=args.batch_size,
|
||||
max_epoch=1,
|
||||
rank=args.rank,
|
||||
shuffle=False,
|
||||
group_size=args.group_size,
|
||||
mode='eval')
|
||||
|
||||
# create network
|
||||
net = dpns[args.backbone](num_classes=args.num_classes)
|
||||
# load checkpoint
|
||||
if os.path.isfile(args.pretrained):
|
||||
load_param_into_net(net, load_checkpoint(args.pretrained))
|
||||
# loss
|
||||
if args.dataset == "imagenet-1K":
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
else:
|
||||
if not args.label_smooth:
|
||||
args.label_smooth_factor = 0.0
|
||||
loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
|
||||
|
||||
# create model
|
||||
model = Model(net, amp_level="O2", keep_batchnorm_fp32=False, loss_fn=loss,
|
||||
metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||
# evaluate
|
||||
output = model.eval(eval_dataset)
|
||||
print(f'Evaluation result: {output}.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dpn_evaluate(parse_args())
|
||||
print('DPN evaluate success!')
|
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
export DEVICE_ID=$1
|
||||
DATA_DIR=$2
|
||||
PATH_CHECKPOINT=$3
|
||||
|
||||
python eval.py \
|
||||
--pretrained=$PATH_CHECKPOINT \
|
||||
--data_dir=$DATA_DIR > eval_log.txt 2>&1 &
|
|
@ -0,0 +1,69 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#Usage: sh train_distributed.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [SAVE_CKPT_PATH] [RANK_SIZE] [EVAL_EACH_EPOCH] [PRETRAINED_CKPT_PATH](optional)
|
||||
|
||||
DATA_DIR=$2
|
||||
export RANK_TABLE_FILE=$1
|
||||
echo "RaNK_TABLE_FiLE=$RANK_TABLE_FILE"
|
||||
export RANK_SIZE=$4
|
||||
SAVE_PATH=$3
|
||||
EVAL_EACH_EPOCH=$5
|
||||
PATH_CHECKPOINT=""
|
||||
if [ $# == 6 ]
|
||||
then
|
||||
PATH_CHECKPOINT=$6
|
||||
fi
|
||||
|
||||
device=(0 1 2 3 4 5 6 7)
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
export DEVICE_ID=${device[$i]}
|
||||
export RANK_ID=$i
|
||||
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ./train.py ./train_parallel$i
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
|
||||
cd ./train_parallel$i ||exit
|
||||
env > env.log
|
||||
if [ $# == 5 ]
|
||||
then
|
||||
python train.py \
|
||||
--is_distributed=1 \
|
||||
--ckpt_path=$SAVE_PATH \
|
||||
--eval_each_epoch=$EVAL_EACH_EPOCH\
|
||||
--data_dir=$DATA_DIR > log.txt 2>&1 &
|
||||
echo "python train.py \
|
||||
--is_distributed=1 \
|
||||
--ckpt_path=$SAVE_PATH \
|
||||
--eval_each_epoch=$EVAL_EACH_EPOCH\
|
||||
--data_dir=$DATA_DIR > log.txt 2>&1 &"
|
||||
fi
|
||||
|
||||
if [ $# == 6 ]
|
||||
then
|
||||
python train.py \
|
||||
--is_distributed=1 \
|
||||
--eval_each_epoch=$EVAL_EACH_EPOCH\
|
||||
--ckpt_path=$SAVE_PATH \
|
||||
--pretrained=$PATH_CHECKPOINT \
|
||||
--data_dir=$DATA_DIR > log.txt 2>&1 &
|
||||
fi
|
||||
|
||||
cd ../
|
||||
done
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#Usage: sh train_standalone.sh [DEVICE_ID] [DATA_DIR] [SAVE_CKPT_PATH] [EVAL_EACH_EPOCH] [PATH_CHECKPOINT]!
|
||||
export DEVICE_ID=$1
|
||||
DATA_DIR=$2
|
||||
SAVE_CKPT_PATH=$3
|
||||
EVAL_EACH_EPOCH=$4
|
||||
|
||||
if [ $# == 5 ]
|
||||
then
|
||||
PATH_CHECKPOINT=$5
|
||||
fi
|
||||
|
||||
if [ $# == 4 ]
|
||||
then
|
||||
python train.py \
|
||||
--is_distributed=0 \
|
||||
--ckpt_path=$SAVE_CKPT_PATH\
|
||||
--eval_each_epoch=$EVAL_EACH_EPOCH\
|
||||
--data_dir=$DATA_DIR > train_log.txt 2>&1 &
|
||||
echo " python train.py \
|
||||
--is_distributed=0 \
|
||||
--ckpt_path=$SAVE_CKPT_PATH\
|
||||
--eval_each_epoch=$EVAL_EACH_EPOCH\
|
||||
--data_dir=$DATA_DIR > train_log.txt 2>&1 &"
|
||||
fi
|
||||
if [ $# == 5 ]
|
||||
then
|
||||
python train.py \
|
||||
--is_distributed=0 \
|
||||
--ckpt_path=$SAVE_CKPT_PATH\
|
||||
--pretrained=$PATH_CHECKPOINT \
|
||||
--data_dir=$DATA_DIR\
|
||||
--eval_each_epoch=$EVAL_EACH_EPOCH > train_log.txt 2>&1 &
|
||||
fi
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
|
||||
|
||||
class SaveCallback(Callback):
|
||||
"""
|
||||
Evaluating on eval_dataset after each epoch.
|
||||
And it will save the parameters if the accuracy is better.
|
||||
"""
|
||||
|
||||
def __init__(self, model, eval_dataset, ckpt_path):
|
||||
super(SaveCallback, self).__init__()
|
||||
self.model = model
|
||||
self.eval_dataset = eval_dataset
|
||||
self.acc = 0.2
|
||||
self.ckpt_path = ckpt_path
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
epoch_num = cb_params.cur_epoch_num
|
||||
result = self.model.eval(self.eval_dataset)
|
||||
print("epoch", epoch_num, " top_1_accuracy:", result['top_1_accuracy'])
|
||||
if result['top_1_accuracy'] > self.acc:
|
||||
self.acc = result['top_1_accuracy']
|
||||
file_name = "max.ckpt"
|
||||
file_name = os.path.join(self.ckpt_path, file_name)
|
||||
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
|
||||
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
# config for dpn,imagenet-1K
|
||||
config = edict()
|
||||
|
||||
# model config
|
||||
config.image_size = (224, 224) # inpute image size
|
||||
config.num_classes = 1000 # dataset class number
|
||||
config.backbone = 'dpn92' # backbone network
|
||||
config.is_save_on_master = True
|
||||
|
||||
# parallel config
|
||||
config.num_parallel_workers = 4 # number of workers to read the data
|
||||
config.rank = 0 # local rank of distributed
|
||||
config.group_size = 1 # group size of distributed
|
||||
|
||||
# training config
|
||||
config.batch_size = 32 # batch_size
|
||||
config.global_step = 0 # start step of learning rate
|
||||
config.epoch_size = 180 # epoch_size
|
||||
config.loss_scale_num = 1024 # loss scale
|
||||
# optimizer config
|
||||
config.momentum = 0.9 # momentum (SGD)
|
||||
config.weight_decay = 1e-4 # weight_decay (SGD)
|
||||
# learning rate config
|
||||
config.lr_schedule = 'warmup' # learning rate schedule
|
||||
config.lr_init = 0.01 # init learning rate
|
||||
config.lr_max = 0.1 # max learning rate
|
||||
config.factor = 0.1 # factor of lr to drop
|
||||
config.epoch_number_to_drop = [5, 15] # learing rate will drop after these epochs
|
||||
config.warmup_epochs = 5 # warmup epochs in learning rate schedule
|
||||
|
||||
# dataset config
|
||||
config.dataset = "imagenet-1K" # dataset
|
||||
config.label_smooth = False # label_smooth
|
||||
config.label_smooth_factor = 0.0 # label_smooth_factor
|
||||
|
||||
# parameter save config
|
||||
config.keep_checkpoint_max = 3 # only keep the last keep_checkpoint_max checkpoint
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
define loss function for network.
|
||||
"""
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class CrossEntropy(_Loss):
|
||||
"""
|
||||
the redefined loss function with SoftmaxCrossEntropyWithLogits.
|
||||
"""
|
||||
|
||||
def __init__(self, smooth_factor=0., num_classes=1000):
|
||||
super(CrossEntropy, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
|
||||
def construct(self, logit, label):
|
||||
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(logit, one_hot_label)
|
||||
loss = self.mean(loss, 0)
|
||||
return loss
|
|
@ -0,0 +1,206 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from collections import OrderedDict
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as F
|
||||
|
||||
__all__ = ['DPN', 'dpn92', 'dpn98', 'dpn131', 'dpn107', 'dpns']
|
||||
|
||||
|
||||
def dpn92(num_classes=1000):
|
||||
return DPN(num_init_features=64, k_R=96, G=32, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
|
||||
num_classes=num_classes)
|
||||
|
||||
|
||||
def dpn98(num_classes=1000):
|
||||
return DPN(num_init_features=96, k_R=160, G=40, k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128),
|
||||
num_classes=num_classes)
|
||||
|
||||
|
||||
def dpn131(num_classes=1000):
|
||||
return DPN(num_init_features=128, k_R=160, G=40, k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128),
|
||||
num_classes=num_classes)
|
||||
|
||||
|
||||
def dpn107(num_classes=1000):
|
||||
return DPN(num_init_features=128, k_R=200, G=50, k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128),
|
||||
num_classes=num_classes)
|
||||
|
||||
|
||||
dpns = {
|
||||
'dpn92': dpn92,
|
||||
'dpn98': dpn98,
|
||||
'dpn107': dpn107,
|
||||
'dpn131': dpn131,
|
||||
}
|
||||
|
||||
|
||||
class BottleBlock(nn.Cell):
|
||||
def __init__(self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, G, key_stride):
|
||||
super(BottleBlock, self).__init__()
|
||||
self.G = G
|
||||
self.bn1 = nn.BatchNorm2d(in_chs, eps=1e-3, momentum=0.9)
|
||||
self.conv1 = nn.Conv2d(in_chs, num_1x1_a, 1, stride=1)
|
||||
self.bn2 = nn.BatchNorm2d(num_1x1_a, eps=1e-3, momentum=0.9)
|
||||
self.conv2 = nn.CellList()
|
||||
for _ in range(G):
|
||||
self.conv2.append(nn.Conv2d(num_1x1_a // G, num_3x3_b // G, 3, key_stride, pad_mode='pad', padding=1))
|
||||
self.bn3 = nn.BatchNorm2d(num_3x3_b, eps=1e-3, momentum=0.9)
|
||||
self.conv3_r = nn.Conv2d(num_3x3_b, num_1x1_c, 1, stride=1)
|
||||
self.conv3_d = nn.Conv2d(num_3x3_b, inc, 1, stride=1)
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
self.concat = F.Concat(axis=1)
|
||||
self.split = F.Split(axis=1, output_num=G)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv1(x)
|
||||
x = self.bn2(x)
|
||||
x = self.relu(x)
|
||||
group_x = ()
|
||||
input_x = self.split(x)
|
||||
for i in range(self.G):
|
||||
group_x = group_x + (self.conv2[i](input_x[i]),)
|
||||
x = self.concat(group_x)
|
||||
x = self.bn3(x)
|
||||
x = self.relu(x)
|
||||
return (self.conv3_r(x), self.conv3_d(x))
|
||||
|
||||
|
||||
class DualPathBlock(nn.Cell):
|
||||
def __init__(self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, G, _type='normal', cat_input=True):
|
||||
super(DualPathBlock, self).__init__()
|
||||
self.num_1x1_c = num_1x1_c
|
||||
|
||||
if _type == 'proj':
|
||||
key_stride = 1
|
||||
self.has_proj = True
|
||||
if _type == 'down':
|
||||
key_stride = 2
|
||||
self.has_proj = True
|
||||
if _type == 'normal':
|
||||
key_stride = 1
|
||||
self.has_proj = False
|
||||
|
||||
self.cat_input = cat_input
|
||||
|
||||
if self.has_proj:
|
||||
self.c1x1_w_bn = nn.BatchNorm2d(in_chs, eps=1e-3, momentum=0.9)
|
||||
self.c1x1_w_relu = nn.ReLU()
|
||||
self.c1x1_w_r = self.Conv1x1(in_chs=in_chs, out_chs=num_1x1_c, stride=key_stride)
|
||||
self.c1x1_w_d = self.Conv1x1(in_chs=in_chs, out_chs=2 * inc, stride=key_stride)
|
||||
|
||||
self.layers = BottleBlock(in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, G, key_stride)
|
||||
self.concat = F.Concat(axis=1)
|
||||
self.add = F.TensorAdd()
|
||||
|
||||
def Conv1x1(self, in_chs, out_chs, stride):
|
||||
return nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=stride, pad_mode='pad', padding=0)
|
||||
|
||||
def construct(self, x):
|
||||
if self.cat_input:
|
||||
data_in = self.concat(x)
|
||||
else:
|
||||
data_in = x
|
||||
|
||||
if self.has_proj:
|
||||
data_o = self.c1x1_w_bn(data_in)
|
||||
data_o = self.c1x1_w_relu(data_o)
|
||||
data_o1 = self.c1x1_w_r(data_o)
|
||||
data_o2 = self.c1x1_w_d(data_o)
|
||||
else:
|
||||
data_o1 = x[0]
|
||||
data_o2 = x[1]
|
||||
|
||||
out = self.layers(data_in)
|
||||
summ = self.add(data_o1, out[0])
|
||||
dense = self.concat((data_o2, out[1]))
|
||||
return (summ, dense)
|
||||
|
||||
|
||||
class DPN(nn.Cell):
|
||||
|
||||
def __init__(self, num_init_features=64, k_R=96, G=32,
|
||||
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), num_classes=1000):
|
||||
|
||||
super(DPN, self).__init__()
|
||||
blocks = OrderedDict()
|
||||
|
||||
# conv1
|
||||
blocks['conv1'] = nn.SequentialCell(OrderedDict([
|
||||
('conv', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, pad_mode='pad', padding=3)),
|
||||
('norm', nn.BatchNorm2d(num_init_features, eps=1e-3, momentum=0.9)),
|
||||
('relu', nn.ReLU()),
|
||||
('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')),
|
||||
]))
|
||||
|
||||
# conv2
|
||||
bw = 256
|
||||
inc = inc_sec[0]
|
||||
R = int((k_R * bw) / 256)
|
||||
blocks['conv2_1'] = DualPathBlock(num_init_features, R, R, bw, inc, G, 'proj', False)
|
||||
in_chs = bw + 3 * inc
|
||||
for i in range(2, k_sec[0] + 1):
|
||||
blocks['conv2_{}'.format(i)] = DualPathBlock(in_chs, R, R, bw, inc, G, 'normal')
|
||||
in_chs += inc
|
||||
|
||||
# conv3
|
||||
bw = 512
|
||||
inc = inc_sec[1]
|
||||
R = int((k_R * bw) / 256)
|
||||
blocks['conv3_1'] = DualPathBlock(in_chs, R, R, bw, inc, G, 'down')
|
||||
in_chs = bw + 3 * inc
|
||||
for i in range(2, k_sec[1] + 1):
|
||||
blocks['conv3_{}'.format(i)] = DualPathBlock(in_chs, R, R, bw, inc, G, 'normal')
|
||||
in_chs += inc
|
||||
|
||||
# conv4
|
||||
bw = 1024
|
||||
inc = inc_sec[2]
|
||||
R = int((k_R * bw) / 256)
|
||||
blocks['conv4_1'] = DualPathBlock(in_chs, R, R, bw, inc, G, 'down')
|
||||
in_chs = bw + 3 * inc
|
||||
for i in range(2, k_sec[2] + 1):
|
||||
blocks['conv4_{}'.format(i)] = DualPathBlock(in_chs, R, R, bw, inc, G, 'normal')
|
||||
in_chs += inc
|
||||
|
||||
# conv5
|
||||
bw = 2048
|
||||
inc = inc_sec[3]
|
||||
R = int((k_R * bw) / 256)
|
||||
blocks['conv5_1'] = DualPathBlock(in_chs, R, R, bw, inc, G, 'down')
|
||||
in_chs = bw + 3 * inc
|
||||
for i in range(2, k_sec[3] + 1):
|
||||
blocks['conv5_{}'.format(i)] = DualPathBlock(in_chs, R, R, bw, inc, G, 'normal')
|
||||
in_chs += inc
|
||||
|
||||
self.features = nn.SequentialCell(blocks)
|
||||
self.concat = F.Concat(axis=1)
|
||||
self.conv5_x = nn.SequentialCell(OrderedDict([
|
||||
('norm', nn.BatchNorm2d(in_chs, eps=1e-3, momentum=0.9)),
|
||||
('relu', nn.ReLU()),
|
||||
]))
|
||||
self.avgpool = F.ReduceMean(False)
|
||||
self.classifier = nn.Dense(in_chs, num_classes)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.features(x)
|
||||
x = self.concat(x)
|
||||
x = self.conv5_x(x)
|
||||
x = self.avgpool(x, (2, 3))
|
||||
x = self.classifier(x)
|
||||
return x
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
dataset processing.
|
||||
"""
|
||||
from PIL import ImageFile
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.dataset.vision.c_transforms as V_C
|
||||
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank, group_size,
|
||||
mode='train',
|
||||
num_parallel_workers=None,
|
||||
shuffle=None,
|
||||
sampler=None,
|
||||
class_indexing=None,
|
||||
transform=None,
|
||||
target_transform=None):
|
||||
"""
|
||||
A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt".
|
||||
If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images
|
||||
are written into a textfile.
|
||||
|
||||
Args:
|
||||
data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"".
|
||||
Or path of the textfile that contains every image's path of the dataset.
|
||||
image_size (str): Size of the input images.
|
||||
per_batch_size (int): the batch size of evey step during training.
|
||||
max_epoch (int): the number of epochs.
|
||||
rank (int): The shard ID within num_shards (default=None).
|
||||
group_size (int): Number of shards that the dataset should be divided
|
||||
into (default=None).
|
||||
mode (str): "train" or others. Default: " train".
|
||||
input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder".
|
||||
root (str): the images path for "input_mode="txt"". Default: " ".
|
||||
num_parallel_workers (int): Number of workers to read the data. Default: None.
|
||||
shuffle (bool): Whether or not to perform shuffle on the dataset
|
||||
(default=None, performs shuffle).
|
||||
sampler (Sampler): Object used to choose samples from the dataset. Default: None.
|
||||
class_indexing (dict): A str-to-int mapping from folder name to index
|
||||
(default=None, the folder names will be sorted
|
||||
alphabetically and each class will be given a
|
||||
unique index starting from 0).
|
||||
|
||||
Examples:
|
||||
>>> from mindvision.common.datasets.classification import classification_dataset
|
||||
>>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
|
||||
>>> dataset_dir = "/path/to/imagefolder_directory"
|
||||
>>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244],
|
||||
>>> per_batch_size=64, max_epoch=100,
|
||||
>>> rank=0, group_size=4)
|
||||
>>> # Path of the textfile that contains every image's path of the dataset.
|
||||
>>> dataset_dir = "/path/to/dataset/images/train.txt"
|
||||
>>> images_dir = "/path/to/dataset/images"
|
||||
>>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244],
|
||||
>>> per_batch_size=64, max_epoch=100,
|
||||
>>> rank=0, group_size=4,
|
||||
>>> input_mode="txt", root=images_dir)
|
||||
"""
|
||||
if mode == 'eval':
|
||||
drop_remainder = False
|
||||
else:
|
||||
drop_remainder = True
|
||||
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||
|
||||
std = [255 * 0.229, 255 * 0.224, 255 * 0.225]
|
||||
|
||||
if transform is None:
|
||||
if mode == 'train':
|
||||
transform_img = [
|
||||
V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||
V_C.RandomHorizontalFlip(prob=0.5),
|
||||
V_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
|
||||
V_C.Normalize(mean=mean, std=std),
|
||||
V_C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
transform_img = [
|
||||
V_C.Decode(),
|
||||
V_C.Resize((256, 256)),
|
||||
V_C.CenterCrop(image_size),
|
||||
V_C.Normalize(mean=mean, std=std),
|
||||
V_C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
transform_img = transform
|
||||
|
||||
if target_transform is None:
|
||||
transform_label = [C.TypeCast(mstype.int32)]
|
||||
else:
|
||||
transform_label = target_transform
|
||||
|
||||
if group_size == 1 or mode == 'eval':
|
||||
de_dataset = de.ImageFolderDataset(data_dir, num_parallel_workers=num_parallel_workers,
|
||||
shuffle=shuffle, sampler=sampler, class_indexing=class_indexing)
|
||||
else:
|
||||
de_dataset = de.ImageFolderDataset(data_dir, num_parallel_workers=num_parallel_workers,
|
||||
shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
|
||||
num_shards=group_size, shard_id=rank)
|
||||
|
||||
de_dataset = de_dataset.map(operations=transform_img, input_columns="image", num_parallel_workers=8)
|
||||
de_dataset = de_dataset.map(operations=transform_label, input_columns="label", num_parallel_workers=8)
|
||||
|
||||
columns_to_project = ["image", "label"]
|
||||
de_dataset = de_dataset.project(columns=columns_to_project)
|
||||
|
||||
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
|
||||
de_dataset = de_dataset.repeat(max_epoch)
|
||||
|
||||
return de_dataset
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_lr_drop(global_step,
|
||||
total_epochs,
|
||||
steps_per_epoch,
|
||||
lr_init=0.316,
|
||||
factor=0.1,
|
||||
epoch_number_to_drop=(5, 15)
|
||||
):
|
||||
"""
|
||||
Generate learning rate array.
|
||||
|
||||
Args:
|
||||
global_step (int): Initial step of training.
|
||||
total_epochs (int): Total epoch of training.
|
||||
steps_per_epoch (float): Steps of one epoch.
|
||||
lr_init (float): Initial learning rate. Default: 0.316.
|
||||
epoch_number_to_drop:Learing rate will drop after these epochs.
|
||||
factor:Factor of lr to drop.
|
||||
Returns:
|
||||
np.array, learning rate array.
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
step_number_to_drop = [steps_per_epoch * x for x in epoch_number_to_drop]
|
||||
for i in range(int(total_steps)):
|
||||
if i in step_number_to_drop:
|
||||
lr_init = lr_init * factor
|
||||
lr_each_step.append(lr_init)
|
||||
current_step = global_step
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[current_step:]
|
||||
return learning_rate
|
||||
|
||||
|
||||
def get_lr_warmup(global_step,
|
||||
total_epochs,
|
||||
steps_per_epoch,
|
||||
lr_init=0.01,
|
||||
lr_max=0.1,
|
||||
warmup_epochs=5):
|
||||
"""
|
||||
Generate learning rate array.
|
||||
|
||||
Args:
|
||||
global_step (int): Initial step of training.
|
||||
total_epochs (int): Total epoch of training.
|
||||
steps_per_epoch (float): Steps of one epoch.
|
||||
lr_init (float): Initial learning rate. Default: 0.01.
|
||||
lr_max (float): Maximum learning rate. Default: 0.1.
|
||||
warmup_epochs (int): The number of warming up epochs. Default: 5.
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array.
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
warmup_steps = steps_per_epoch * warmup_epochs
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(int(total_steps)):
|
||||
if i < warmup_steps:
|
||||
lr = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr = float(lr_max) * base * base
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
|
||||
current_step = global_step
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[current_step:]
|
||||
|
||||
return learning_rate
|
|
@ -0,0 +1,193 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""DPN model train with MindSpore"""
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import SGD
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.imagenet_dataset import classification_dataset
|
||||
from src.dpn import dpns
|
||||
from src.config import config
|
||||
from src.lr_scheduler import get_lr_drop, get_lr_warmup
|
||||
from src.crossentropy import CrossEntropy
|
||||
from src.callbacks import SaveCallback
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""parameters"""
|
||||
parser = argparse.ArgumentParser('dpn training')
|
||||
|
||||
# dataset related
|
||||
parser.add_argument('--data_dir', type=str, default='', help='Imagenet data dir')
|
||||
# network related
|
||||
parser.add_argument('--pretrained', default='', type=str, help='ckpt path to load')
|
||||
# distributed related
|
||||
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
|
||||
parser.add_argument('--ckpt_path', type=str, default='', help='ckpt path to save')
|
||||
parser.add_argument('--eval_each_epoch', type=int, default=0, help='ckpt path to save')
|
||||
args, _ = parser.parse_known_args()
|
||||
args.image_size = config.image_size
|
||||
args.num_classes = config.num_classes
|
||||
args.lr_init = config.lr_init
|
||||
args.lr_max = config.lr_max
|
||||
args.factor = config.factor
|
||||
args.global_step = config.global_step
|
||||
args.epoch_number_to_drop = config.epoch_number_to_drop
|
||||
args.epoch_size = config.epoch_size
|
||||
args.warmup_epochs = config.warmup_epochs
|
||||
args.weight_decay = config.weight_decay
|
||||
args.momentum = config.momentum
|
||||
args.batch_size = config.batch_size
|
||||
args.num_parallel_workers = config.num_parallel_workers
|
||||
args.backbone = config.backbone
|
||||
args.loss_scale_num = config.loss_scale_num
|
||||
args.is_save_on_master = config.is_save_on_master
|
||||
args.rank = config.rank
|
||||
args.group_size = config.group_size
|
||||
args.dataset = config.dataset
|
||||
args.label_smooth = config.label_smooth
|
||||
args.label_smooth_factor = config.label_smooth_factor
|
||||
args.keep_checkpoint_max = config.keep_checkpoint_max
|
||||
args.lr_schedule = config.lr_schedule
|
||||
return args
|
||||
|
||||
|
||||
def dpn_train(args):
|
||||
# init context
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
# init distributed
|
||||
if args.is_distributed:
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
context.set_auto_parallel_context(device_num=args.group_size, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
|
||||
# select for master rank save ckpt or all rank save, compatiable for model parallel
|
||||
args.rank_save_ckpt_flag = 0
|
||||
if args.is_save_on_master:
|
||||
if args.rank == 0:
|
||||
args.rank_save_ckpt_flag = 1
|
||||
else:
|
||||
args.rank_save_ckpt_flag = 1
|
||||
# create dataset
|
||||
args.train_dir = os.path.join(args.data_dir, 'train')
|
||||
args.eval_dir = os.path.join(args.data_dir, 'val')
|
||||
train_dataset = classification_dataset(args.train_dir,
|
||||
image_size=args.image_size,
|
||||
per_batch_size=args.batch_size,
|
||||
max_epoch=1,
|
||||
num_parallel_workers=args.num_parallel_workers,
|
||||
shuffle=True,
|
||||
rank=args.rank,
|
||||
group_size=args.group_size)
|
||||
if args.eval_each_epoch:
|
||||
print("create eval_dataset")
|
||||
eval_dataset = classification_dataset(args.eval_dir,
|
||||
image_size=args.image_size,
|
||||
per_batch_size=args.batch_size,
|
||||
max_epoch=1,
|
||||
num_parallel_workers=args.num_parallel_workers,
|
||||
shuffle=False,
|
||||
rank=args.rank,
|
||||
group_size=args.group_size,
|
||||
mode='eval')
|
||||
train_step_size = train_dataset.get_dataset_size()
|
||||
|
||||
# choose net
|
||||
net = dpns[args.backbone](num_classes=args.num_classes)
|
||||
|
||||
# load checkpoint
|
||||
if os.path.isfile(args.pretrained):
|
||||
print("load ckpt")
|
||||
load_param_into_net(net, load_checkpoint(args.pretrained))
|
||||
# learing rate schedule
|
||||
if args.lr_schedule == 'drop':
|
||||
print("lr_schedule:drop")
|
||||
lr = Tensor(get_lr_drop(global_step=args.global_step,
|
||||
total_epochs=args.epoch_size,
|
||||
steps_per_epoch=train_step_size,
|
||||
lr_init=args.lr_init,
|
||||
factor=args.factor))
|
||||
elif args.lr_schedule == 'warmup':
|
||||
print("lr_schedule:warmup")
|
||||
lr = Tensor(get_lr_warmup(global_step=args.global_step,
|
||||
total_epochs=args.epoch_size,
|
||||
steps_per_epoch=train_step_size,
|
||||
lr_init=args.lr_init,
|
||||
lr_max=args.lr_max,
|
||||
warmup_epochs=args.warmup_epochs))
|
||||
|
||||
# optimizer
|
||||
opt = SGD(net.trainable_params(),
|
||||
lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
loss_scale=args.loss_scale_num)
|
||||
# loss scale
|
||||
loss_scale = FixedLossScaleManager(args.loss_scale_num, False)
|
||||
# loss function
|
||||
if args.dataset == "imagenet-1K":
|
||||
print("Use SoftmaxCrossEntropyWithLogits")
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
else:
|
||||
if not args.label_smooth:
|
||||
args.label_smooth_factor = 0.0
|
||||
print("Use Label_smooth CrossEntropy")
|
||||
loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
|
||||
# create model
|
||||
model = Model(net, amp_level="O2",
|
||||
keep_batchnorm_fp32=False,
|
||||
loss_fn=loss,
|
||||
optimizer=opt,
|
||||
loss_scale_manager=loss_scale,
|
||||
metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||
|
||||
# loss/time monitor & ckpt save callback
|
||||
loss_cb = LossMonitor()
|
||||
time_cb = TimeMonitor(data_size=train_step_size)
|
||||
cb = [loss_cb, time_cb]
|
||||
if args.rank_save_ckpt_flag:
|
||||
if args.eval_each_epoch:
|
||||
save_cb = SaveCallback(model, eval_dataset, args.ckpt_path)
|
||||
cb += [save_cb]
|
||||
else:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=train_step_size,
|
||||
keep_checkpoint_max=args.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="dpn", directory=args.ckpt_path, config=config_ck)
|
||||
cb.append(ckpoint_cb)
|
||||
# train model
|
||||
model.train(args.epoch_size, train_dataset, callbacks=cb)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dpn_train(parse_args())
|
||||
print('DPN training success!')
|
Loading…
Reference in New Issue