fix cnn efficientnet bugs

This commit is contained in:
panfengfeng 2021-02-07 16:44:50 +08:00
parent 027152b6ac
commit d72feda64b
6 changed files with 38 additions and 33 deletions

View File

@ -15,7 +15,7 @@
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
@ -76,13 +76,13 @@ do
if [ $# == 2 ]
then
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 > train.log 2>&1 &
fi
if [ $# == 3 ]
then
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log &
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 > train.log 2>&1 &
fi
cd ..
done
done

View File

@ -57,6 +57,6 @@ cd ./eval || exit
echo "start evaluation for device $DEVICE_ID"
env > env.log
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 #&> log &
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 > eval.log 2>&1 &
cd ..
cd ..

View File

@ -66,7 +66,7 @@ fi
if [ $# == 2 ]
then
python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 > train.log 2>&1 &
fi
cd ..
cd ..

View File

@ -41,6 +41,7 @@ parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
args_opt = parser.parse_args()
@ -70,6 +71,13 @@ if __name__ == '__main__':
context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL)
init()
args_opt.rank_save_ckpt_flag = 0
if args_opt.is_save_on_master:
if rank_id == 0:
args_opt.rank_save_ckpt_flag = 1
else:
args_opt.rank_save_ckpt_flag = 1
# create dataset
dataset_name = config.dataset_name
dataset = create_dataset_train(args_opt.dataset_path + "/" + dataset_name +
@ -100,10 +108,11 @@ if __name__ == '__main__':
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
if args_opt.rank_save_ckpt_flag == 1:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
# train model
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)

View File

@ -18,7 +18,6 @@
# [EfficientNet-B0 Description](#contents)
[Paper](https://arxiv.org/abs/1905.11946): Mingxing Tan, Quoc V. Le. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. 2019.
# [Model architecture](#contents)
@ -27,27 +26,25 @@ The overall network architecture of EfficientNet-B0 is show below:
[Link](https://arxiv.org/abs/1905.11946)
# [Dataset](#contents)
Dataset used: [imagenet](http://www.image-net.org/)
- Dataset size: ~125G, 1.28 million color images in 1000 classes
- Train: 120G, 1.2W images
- Test: 5G, 50000 images
- Train: 120G, 1.28 million images
- Test: 5G, 50000 images
- Data format: RGB images.
- Note: Data will be processed in src/dataset.py
- Note: Data will be processed in src/dataset.py
# [Environment Requirements](#contents)
- Hardware GPU
- Prepare hardware environment with GPU processor.
- Prepare hardware environment with GPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents)
@ -77,7 +74,7 @@ Dataset used: [imagenet](http://www.image-net.org/)
Parameters for both training and evaluating can be set in config.py.
```
```python
'random_seed': 1, # fix random seed
'model': 'efficientnet_b0', # model name
'drop': 0.2, # dropout rate
@ -106,17 +103,17 @@ Parameters for both training and evaluating can be set in config.py.
## [Training Process](#contents)
#### Usage
### Usage
```
```bash
GPU:
# distribute training example(8p)
sh run_distribute_train_for_gpu.sh
sh run_distribute_train_for_gpu.sh
# standalone training
sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
```
#### Launch
### Launch
```bash
# distributed training example(8p) for GPU
@ -133,7 +130,7 @@ You can find checkpoint file together with result in log.
### Usage
```
```bash
# Evaluation
sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
```
@ -148,9 +145,9 @@ sh run_eval_for_gpu.sh /dataset/eval ./checkpoint/efficientnet_b0-600_1251.ckpt
#### Result
Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log.
Evaluation results are stored in the scripts path, where you can find results like the following in the log.
```
```text
acc=76.96%(TOP1)
```
@ -186,7 +183,6 @@ acc=76.96%(TOP1)
| outputs | probability |
| Accuracy | acc=76.96%(TOP1) |
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -31,7 +31,7 @@ class RandAugment:
self.hparams = hparams
def __call__(self, imgs, labels, batchInfo):
# assert the imgs objetc are pil_images
# assert the imgs object are pil_images
ret_imgs = []
ret_labels = []
py_to_pil_op = P.ToPIL()