fix cnn efficientnet bugs

This commit is contained in:
panfengfeng 2021-02-07 16:44:50 +08:00
parent 027152b6ac
commit d72feda64b
6 changed files with 38 additions and 33 deletions

View File

@ -15,7 +15,7 @@
# ============================================================================ # ============================================================================
if [ $# != 2 ] && [ $# != 3 ] if [ $# != 2 ] && [ $# != 3 ]
then then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1 exit 1
fi fi
@ -76,13 +76,13 @@ do
if [ $# == 2 ] if [ $# == 2 ]
then then
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log & python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 > train.log 2>&1 &
fi fi
if [ $# == 3 ] if [ $# == 3 ]
then then
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log & python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 > train.log 2>&1 &
fi fi
cd .. cd ..
done done

View File

@ -57,6 +57,6 @@ cd ./eval || exit
echo "start evaluation for device $DEVICE_ID" echo "start evaluation for device $DEVICE_ID"
env > env.log env > env.log
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 #&> log & python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 > eval.log 2>&1 &
cd .. cd ..

View File

@ -66,7 +66,7 @@ fi
if [ $# == 2 ] if [ $# == 2 ]
then then
python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 > train.log 2>&1 &
fi fi
cd .. cd ..

View File

@ -41,6 +41,7 @@ parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
args_opt = parser.parse_args() args_opt = parser.parse_args()
@ -70,6 +71,13 @@ if __name__ == '__main__':
context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL) context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL)
init() init()
args_opt.rank_save_ckpt_flag = 0
if args_opt.is_save_on_master:
if rank_id == 0:
args_opt.rank_save_ckpt_flag = 1
else:
args_opt.rank_save_ckpt_flag = 1
# create dataset # create dataset
dataset_name = config.dataset_name dataset_name = config.dataset_name
dataset = create_dataset_train(args_opt.dataset_path + "/" + dataset_name + dataset = create_dataset_train(args_opt.dataset_path + "/" + dataset_name +
@ -100,10 +108,11 @@ if __name__ == '__main__':
loss_cb = LossMonitor() loss_cb = LossMonitor()
cb = [time_cb, loss_cb] cb = [time_cb, loss_cb]
if config.save_checkpoint: if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, if args_opt.rank_save_ckpt_flag == 1:
keep_checkpoint_max=config.keep_checkpoint_max) config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck) keep_checkpoint_max=config.keep_checkpoint_max)
cb += [ckpt_cb] ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
# train model # train model
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False) model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)

View File

@ -18,7 +18,6 @@
# [EfficientNet-B0 Description](#contents) # [EfficientNet-B0 Description](#contents)
[Paper](https://arxiv.org/abs/1905.11946): Mingxing Tan, Quoc V. Le. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. 2019. [Paper](https://arxiv.org/abs/1905.11946): Mingxing Tan, Quoc V. Le. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. 2019.
# [Model architecture](#contents) # [Model architecture](#contents)
@ -27,27 +26,25 @@ The overall network architecture of EfficientNet-B0 is show below:
[Link](https://arxiv.org/abs/1905.11946) [Link](https://arxiv.org/abs/1905.11946)
# [Dataset](#contents) # [Dataset](#contents)
Dataset used: [imagenet](http://www.image-net.org/) Dataset used: [imagenet](http://www.image-net.org/)
- Dataset size: ~125G, 1.2W colorful images in 1000 classes - Dataset size: ~125G, 1.2W colorful images in 1000 classes
- Train: 120G, 1.2W images - Train: 120G, 1.2W images
- Test: 5G, 50000 images - Test: 5G, 50000 images
- Data format: RGB images. - Data format: RGB images.
- Note: Data will be processed in src/dataset.py - Note: Data will be processed in src/dataset.py
# [Environment Requirements](#contents) # [Environment Requirements](#contents)
- Hardware GPU - Hardware GPU
- Prepare hardware environment with GPU processor. - Prepare hardware environment with GPU processor.
- Framework - Framework
- [MindSpore](https://www.mindspore.cn/install/en) - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below - For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents) # [Script description](#contents)
@ -77,7 +74,7 @@ Dataset used: [imagenet](http://www.image-net.org/)
Parameters for both training and evaluating can be set in config.py. Parameters for both training and evaluating can be set in config.py.
``` ```python
'random_seed': 1, # fix random seed 'random_seed': 1, # fix random seed
'model': 'efficientnet_b0', # model name 'model': 'efficientnet_b0', # model name
'drop': 0.2, # dropout rate 'drop': 0.2, # dropout rate
@ -106,17 +103,17 @@ Parameters for both training and evaluating can be set in config.py.
## [Training Process](#contents) ## [Training Process](#contents)
#### Usage ### Usage
``` ```python
GPU: GPU:
# distribute training example(8p) # distribute training example(8p)
sh run_distribute_train_for_gpu.sh sh run_distribute_train_for_gpu.sh
# standalone training # standalone training
sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
``` ```
#### Launch ### Launch
```bash ```bash
# distributed training example(8p) for GPU # distributed training example(8p) for GPU
@ -133,7 +130,7 @@ You can find checkpoint file together with result in log.
### Usage ### Usage
``` ```bash
# Evaluation # Evaluation
sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
``` ```
@ -148,9 +145,9 @@ sh run_eval_for_gpu.sh /dataset/eval ./checkpoint/efficientnet_b0-600_1251.ckpt
#### Result #### Result
Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log. Evaluation result will be stored in the scripts path. Under this, you can find result like the following in log.
``` ```python
acc=76.96%(TOP1) acc=76.96%(TOP1)
``` ```
@ -186,7 +183,6 @@ acc=76.96%(TOP1)
| outputs | probability | | outputs | probability |
| Accuracy | acc=76.96%(TOP1) | | Accuracy | acc=76.96%(TOP1) |
# [ModelZoo Homepage](#contents) # [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -31,7 +31,7 @@ class RandAugment:
self.hparams = hparams self.hparams = hparams
def __call__(self, imgs, labels, batchInfo): def __call__(self, imgs, labels, batchInfo):
# assert the imgs objetc are pil_images # assert the imgs object are pil_images
ret_imgs = [] ret_imgs = []
ret_labels = [] ret_labels = []
py_to_pil_op = P.ToPIL() py_to_pil_op = P.ToPIL()