support xception lr slice and fix usability problems

pengyanjun 2020-12-29 21:24:42 +08:00 committed by Yanjun Peng
parent b0794bb5f6
commit 806d4d304c
8 changed files with 43 additions and 33 deletions

View File

@ -81,9 +81,10 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will
├─config.py # parameter configuration
├─dataset.py # data preprocessing
├─Xception.py # network definition
├─CrossEntropySmooth.py # Customized CrossEntropy loss function
├─loss.py # Customized CrossEntropy loss function
└─lr_generator.py # learning rate generator
├─train.py # train net
├─export.py # export net
└─eval.py # eval net
```
@ -110,7 +111,6 @@ Major parameters in train.py and config.py are:
'lr_init': 0.00004 # initiate learning rate
'lr_max': 0.4 # max bound of learning rate
'lr_end': 0.00004 # min bound of learning rate
"weight_init": 'xavier_uniform' # Weight initialization mode
```
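Together, `lr_init`, `lr_max`, and `lr_end` describe a warmup-then-decay curve: the rate climbs from `lr_init` to `lr_max` over the warmup epochs, then falls back to `lr_end`. A minimal sketch of the linear variant (illustrative only; `lr_generator.py` supports several `lr_decay_mode` settings, and the warmup formula here is an assumption):

```python
import numpy as np

def warmup_then_linear(lr_init, lr_max, lr_end, warmup_steps, total_steps):
    """Ramp lr_init -> lr_max during warmup, then decay linearly to lr_end."""
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
        lr_each_step.append(lr)
    return np.array(lr_each_step, dtype=np.float32)

# e.g. a tiny 8-step schedule with 2 warmup steps
print(warmup_then_linear(0.00004, 0.4, 0.00004, 2, 8))
```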
## [Training process](#contents)
@ -149,13 +149,13 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
### Result
Training results will be stored in the example path. Checkpoints will be stored at `./model_0` by default, and the training log will be redirected to `log.txt` as follows.
Training results will be stored in the example path. Checkpoints will be stored at `./ckpt_0` by default, and the training log will be redirected to `log.txt` as follows.
``` shell
epoch: [ 0/250], step:[ 1250/ 1251], loss:[4.761/5.613], time:[529.305], lr:[0.400]
epoch time: 1128662.862, per step time: 902.209, avg loss: 5.609
epoch: [ 1/250], step:[ 1250/ 1251], loss:[4.164/4.318], time:[503.708], lr:[0.398]
epoch time: 889163.081, per step time: 710.762, avg loss: 4.312
epoch: 1 step: 1251, loss is 4.8427444
epoch time: 701242.350 ms, per step time: 560.545 ms
epoch: 2 step: 1251, loss is 4.0637593
epoch time: 598591.422 ms, per step time: 478.490 ms
```
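The per-step figure in the new log format is just the epoch time divided by the number of steps, which makes it easy to sanity-check:

```python
# Reproduce the per-step time reported in the first log line above.
epoch_time_ms = 701242.350
steps_per_epoch = 1251
print(epoch_time_ms / steps_per_epoch)  # ~560.545 ms, matching the log
```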
## [Eval process](#contents)
@ -199,18 +199,19 @@ result: {'Loss': 1.7797744848789312, 'Top_1_Acc': 0.7985777243589743, 'Top_5_Acc
| -------------------------- | ---------------------------------------------- |
| Model Version | Xception |
| Resource | HUAWEI CLOUD Modelarts |
| Uploaded Date | 11/15/2020 |
| MindSpore Version | 1.0.0 |
| Uploaded Date | 12/10/2020 |
| MindSpore Version | 1.1.0 |
| Dataset | 1200k images |
| Batch_size | 128 |
| Training Parameters | src/config.py |
| Optimizer | Momentum |
| Loss Function | CrossEntropySmooth |
| Loss | 1.78 |
| Accuracy (8p) | Top1[79.9%] Top5[94.9%] |
| Total time (8p) | 63h |
| Accuracy (8p) | Top1[79.8%] Top5[94.8%] |
| Per step time (8p) | 479 ms/step |
| Total time (8p) | 42h |
| Params (M) | 180M |
| Scripts | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/Xception) |
| Scripts | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/xception) |
#### Inference Performance
@ -231,4 +232,4 @@ In `dataset.py`, we set the seed inside `create_dataset` function. We also use r
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -27,13 +27,10 @@ echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf train_parallel$i
mkdir ./train_parallel$i
@ -42,7 +39,7 @@ do
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
taskset -c $cmdopt python ../train.py \
python ../train.py \
--is_distributed \
--device_target=Ascend \
--dataset_path=$DATA_DIR > log.txt 2>&1 &
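The loop now just exports `DEVICE_ID` and `RANK_ID` per process instead of pinning CPU cores with `taskset`. On the Python side, a training script typically picks these up along the following lines (a sketch, not the exact code in this diff):

```python
import os
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size

device_id = int(os.getenv("DEVICE_ID", "0"))
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend",
                    device_id=device_id)
init()                         # reads RANK_ID / rank table info from the env
rank = get_rank()              # this process's rank within the job
group_size = get_group_size()  # total number of devices
```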

View File

@ -18,8 +18,14 @@ export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
python ./eval.py \
rm -rf eval_output
mkdir ./eval_output
cd ./eval_output || exit
echo "start evaluating model..."
python ../eval.py \
--device_target=Ascend \
--device_id=$DEVICE_ID \
--checkpoint_path=$PATH_CHECKPOINT \
--dataset_path=$DATA_DIR > eval.log 2>&1 &
cd ../

View File

@ -16,7 +16,13 @@
export DEVICE_ID=$1
DATA_DIR=$2
python ./train.py \
rm -rf train_standalone
mkdir ./train_standalone
cd ./train_standalone || exit
echo "start training standalone on device $DEVICE_ID"
python ../train.py \
--device_target=Ascend \
--dataset_path=$DATA_DIR > log.txt 2>&1 &
cd ../

View File

@ -15,15 +15,14 @@
"""Xception."""
import mindspore.nn as nn
import mindspore.ops.operations as P
from src.config import config
class SeparableConv2d(nn.Cell):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super(SeparableConv2d, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, group=in_channels, pad_mode='pad',
padding=padding, weight_init=config.weight_init)
padding=padding, weight_init='xavier_uniform')
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, pad_mode='valid',
weight_init=config.weight_init)
weight_init='xavier_uniform')
def construct(self, x):
x = self.conv1(x)
@ -37,7 +36,7 @@ class Block(nn.Cell):
if out_filters != in_filters or strides != 1:
self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, pad_mode='valid', has_bias=False,
weight_init=config.weight_init)
weight_init='xavier_uniform')
self.skipbn = nn.BatchNorm2d(out_filters, momentum=0.9)
else:
self.skip = None
@ -96,10 +95,10 @@ class Xception(nn.Cell):
"""
super(Xception, self).__init__()
self.num_classes = num_classes
self.conv1 = nn.Conv2d(3, 32, 3, 2, pad_mode='valid', weight_init=config.weight_init)
self.conv1 = nn.Conv2d(3, 32, 3, 2, pad_mode='valid', weight_init='xavier_uniform')
self.bn1 = nn.BatchNorm2d(32, momentum=0.9)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(32, 64, 3, pad_mode='valid', weight_init=config.weight_init)
self.conv2 = nn.Conv2d(32, 64, 3, pad_mode='valid', weight_init='xavier_uniform')
self.bn2 = nn.BatchNorm2d(64, momentum=0.9)
# Entry flow
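For context, `SeparableConv2d` is a depthwise convolution (`group=in_channels`) followed by a 1x1 pointwise convolution, which is much cheaper than a dense convolution of the same shape. A quick back-of-the-envelope comparison (illustrative, ignoring biases):

```python
def dense_conv_params(k, c_in, c_out):
    """Weights in a dense k x k convolution."""
    return k * k * c_in * c_out

def separable_conv_params(k, c_in, c_out):
    """Depthwise k x k (one filter per input channel) plus 1x1 pointwise."""
    return k * k * c_in + c_in * c_out

# e.g. a 3x3 layer mapping 128 -> 256 channels
print(dense_conv_params(3, 128, 256))      # 294912
print(separable_conv_params(3, 128, 256))  # 33920
```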

View File

@ -36,6 +36,5 @@ config = ed({
"label_smooth_factor": 0.1,
"lr_init": 0.00004,
"lr_max": 0.4,
"lr_end": 0.00004,
"weight_init": 'xavier_uniform'
"lr_end": 0.00004
})

View File

@ -17,7 +17,7 @@ import math
import numpy as np
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode, global_step=0):
"""
generate learning rate array
@ -82,6 +82,6 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)
lr_each_step = np.array(lr_each_step).astype(np.float32)
lr_each_step = np.array(lr_each_step[global_step:]).astype(np.float32)
return lr_each_step
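The new `global_step` argument is the "lr slice" from the commit title: `get_lr` still builds the full schedule, then returns only the tail from `global_step` onward, so a resumed run picks up the curve at the right point. A minimal illustration with made-up numbers:

```python
import numpy as np

steps_per_epoch = 4
total_steps = 5 * steps_per_epoch                # a 5-epoch schedule
full = np.linspace(0.4, 0.00004, total_steps).astype(np.float32)

finish_epoch = 2                                 # epochs already trained
global_step = finish_epoch * steps_per_epoch
resumed = full[global_step:]                     # what get_lr now returns

print(len(full), len(resumed))                   # 20 12
print(resumed[0])                                # lr at the resume point
```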

View File

@ -59,7 +59,8 @@ if __name__ == '__main__':
else:
rank = 0
group_size = 1
context.set_context(device_id=0)
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
# define network
net = xception(class_num=config.class_num)
@ -88,7 +89,8 @@ if __name__ == '__main__':
warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size,
steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode))
lr_decay_mode=config.lr_decay_mode,
global_step=config.finish_epoch * step_size))
# define optimization
opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay, config.loss_scale)
@ -100,7 +102,7 @@ if __name__ == '__main__':
# define callbacks
cb = [TimeMonitor(), LossMonitor()]
if config.save_checkpoint:
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'model_' + str(rank) + '/')
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(f"Xception-rank{rank}", directory=save_ckpt_path, config=config_ck)
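Putting the pieces together, a resumed run would restore weights from the earlier run's checkpoint and rely on the sliced lr array; a rough sketch that reuses `net`, `model`, `dataset`, and `cb` from train.py above (the `finish_epoch` config entry and the checkpoint filename are illustrative):

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Hypothetical resume: restore the weights saved by the earlier run ...
param_dict = load_checkpoint("./ckpt_0/Xception-rank0-50_1251.ckpt")
load_param_into_net(net, param_dict)

# ... then train only the remaining epochs; the lr array fed to Momentum was
# already sliced by global_step = config.finish_epoch * step_size, so the
# schedule lines up with the resumed step count.
model.train(config.epoch_size - config.finish_epoch, dataset,
            callbacks=cb, dataset_sink_mode=True)
```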