forked from mindspore-Ecosystem/mindspore
fix cnn efficientnet bugs
This commit is contained in:
parent
027152b6ac
commit
d72feda64b
|
@ -15,7 +15,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
if [ $# != 2 ] && [ $# != 3 ]
|
if [ $# != 2 ] && [ $# != 3 ]
|
||||||
then
|
then
|
||||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
|
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -76,13 +76,13 @@ do
|
||||||
|
|
||||||
if [ $# == 2 ]
|
if [ $# == 2 ]
|
||||||
then
|
then
|
||||||
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
|
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 > train.log 2>&1 &
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $# == 3 ]
|
if [ $# == 3 ]
|
||||||
then
|
then
|
||||||
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log &
|
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 > train.log 2>&1 &
|
||||||
fi
|
fi
|
||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
done
|
done
|
||||||
|
|
|
@ -57,6 +57,6 @@ cd ./eval || exit
|
||||||
echo "start evaluation for device $DEVICE_ID"
|
echo "start evaluation for device $DEVICE_ID"
|
||||||
env > env.log
|
env > env.log
|
||||||
|
|
||||||
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 #&> log &
|
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 > eval.log 2>&1 &
|
||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
|
|
|
@ -66,7 +66,7 @@ fi
|
||||||
|
|
||||||
if [ $# == 2 ]
|
if [ $# == 2 ]
|
||||||
then
|
then
|
||||||
python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
|
python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 > train.log 2>&1 &
|
||||||
fi
|
fi
|
||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
|
|
|
@ -41,6 +41,7 @@ parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||||
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
|
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
|
||||||
|
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
|
||||||
|
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
@ -70,6 +71,13 @@ if __name__ == '__main__':
|
||||||
context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL)
|
context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||||
init()
|
init()
|
||||||
|
|
||||||
|
args_opt.rank_save_ckpt_flag = 0
|
||||||
|
if args_opt.is_save_on_master:
|
||||||
|
if rank_id == 0:
|
||||||
|
args_opt.rank_save_ckpt_flag = 1
|
||||||
|
else:
|
||||||
|
args_opt.rank_save_ckpt_flag = 1
|
||||||
|
|
||||||
# create dataset
|
# create dataset
|
||||||
dataset_name = config.dataset_name
|
dataset_name = config.dataset_name
|
||||||
dataset = create_dataset_train(args_opt.dataset_path + "/" + dataset_name +
|
dataset = create_dataset_train(args_opt.dataset_path + "/" + dataset_name +
|
||||||
|
@ -100,10 +108,11 @@ if __name__ == '__main__':
|
||||||
loss_cb = LossMonitor()
|
loss_cb = LossMonitor()
|
||||||
cb = [time_cb, loss_cb]
|
cb = [time_cb, loss_cb]
|
||||||
if config.save_checkpoint:
|
if config.save_checkpoint:
|
||||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
if args_opt.rank_save_ckpt_flag == 1:
|
||||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||||
ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck)
|
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||||
cb += [ckpt_cb]
|
ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck)
|
||||||
|
cb += [ckpt_cb]
|
||||||
|
|
||||||
# train model
|
# train model
|
||||||
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)
|
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)
|
||||||
|
|
|
@ -18,7 +18,6 @@
|
||||||
|
|
||||||
# [EfficientNet-B0 Description](#contents)
|
# [EfficientNet-B0 Description](#contents)
|
||||||
|
|
||||||
|
|
||||||
[Paper](https://arxiv.org/abs/1905.11946): Mingxing Tan, Quoc V. Le. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. 2019.
|
[Paper](https://arxiv.org/abs/1905.11946): Mingxing Tan, Quoc V. Le. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. 2019.
|
||||||
|
|
||||||
# [Model architecture](#contents)
|
# [Model architecture](#contents)
|
||||||
|
@ -27,27 +26,25 @@ The overall network architecture of EfficientNet-B0 is shown below:
|
||||||
|
|
||||||
[Link](https://arxiv.org/abs/1905.11946)
|
[Link](https://arxiv.org/abs/1905.11946)
|
||||||
|
|
||||||
|
|
||||||
# [Dataset](#contents)
|
# [Dataset](#contents)
|
||||||
|
|
||||||
Dataset used: [imagenet](http://www.image-net.org/)
|
Dataset used: [imagenet](http://www.image-net.org/)
|
||||||
|
|
||||||
- Dataset size: ~125G, about 1.33 million color images in 1000 classes
|
- Dataset size: ~125G, about 1.33 million color images in 1000 classes
|
||||||
- Train: 120G, about 1.28 million images
|
- Train: 120G, about 1.28 million images
|
||||||
- Test: 5G, 50000 images
|
- Test: 5G, 50000 images
|
||||||
- Data format: RGB images.
|
- Data format: RGB images.
|
||||||
- Note: Data will be processed in src/dataset.py
|
- Note: Data will be processed in src/dataset.py
|
||||||
|
|
||||||
|
|
||||||
# [Environment Requirements](#contents)
|
# [Environment Requirements](#contents)
|
||||||
|
|
||||||
- Hardware GPU
|
- Hardware GPU
|
||||||
- Prepare hardware environment with GPU processor.
|
- Prepare hardware environment with GPU processor.
|
||||||
- Framework
|
- Framework
|
||||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
- For more information, please check the resources below:
|
- For more information, please check the resources below:
|
||||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||||
|
|
||||||
# [Script description](#contents)
|
# [Script description](#contents)
|
||||||
|
|
||||||
|
@ -77,7 +74,7 @@ Dataset used: [imagenet](http://www.image-net.org/)
|
||||||
|
|
||||||
Parameters for both training and evaluating can be set in config.py.
|
Parameters for both training and evaluating can be set in config.py.
|
||||||
|
|
||||||
```
|
```python
|
||||||
'random_seed': 1, # fix random seed
|
'random_seed': 1, # fix random seed
|
||||||
'model': 'efficientnet_b0', # model name
|
'model': 'efficientnet_b0', # model name
|
||||||
'drop': 0.2, # dropout rate
|
'drop': 0.2, # dropout rate
|
||||||
|
@ -106,17 +103,17 @@ Parameters for both training and evaluating can be set in config.py.
|
||||||
|
|
||||||
## [Training Process](#contents)
|
## [Training Process](#contents)
|
||||||
|
|
||||||
#### Usage
|
### Usage
|
||||||
|
|
||||||
```
|
```python
|
||||||
GPU:
|
GPU:
|
||||||
# distribute training example(8p)
|
# distribute training example(8p)
|
||||||
sh run_distribute_train_for_gpu.sh
|
sh run_distribute_train_for_gpu.sh
|
||||||
# standalone training
|
# standalone training
|
||||||
sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
|
sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Launch
|
### Launch
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# distributed training example(8p) for GPU
|
# distributed training example(8p) for GPU
|
||||||
|
@ -133,7 +130,7 @@ You can find checkpoint file together with result in log.
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
```
|
```bash
|
||||||
# Evaluation
|
# Evaluation
|
||||||
sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
|
sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
|
||||||
```
|
```
|
||||||
|
@ -148,9 +145,9 @@ sh run_eval_for_gpu.sh /dataset/eval ./checkpoint/efficientnet_b0-600_1251.ckpt
|
||||||
|
|
||||||
#### Result
|
#### Result
|
||||||
|
|
||||||
Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log.
|
Evaluation result will be stored in the scripts path. Under this, you can find result like the following in log.
|
||||||
|
|
||||||
```
|
```python
|
||||||
acc=76.96%(TOP1)
|
acc=76.96%(TOP1)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -186,7 +183,6 @@ acc=76.96%(TOP1)
|
||||||
| outputs | probability |
|
| outputs | probability |
|
||||||
| Accuracy | acc=76.96%(TOP1) |
|
| Accuracy | acc=76.96%(TOP1) |
|
||||||
|
|
||||||
|
|
||||||
# [ModelZoo Homepage](#contents)
|
# [ModelZoo Homepage](#contents)
|
||||||
|
|
||||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||||
|
|
|
@ -31,7 +31,7 @@ class RandAugment:
|
||||||
self.hparams = hparams
|
self.hparams = hparams
|
||||||
|
|
||||||
def __call__(self, imgs, labels, batchInfo):
|
def __call__(self, imgs, labels, batchInfo):
|
||||||
# assert the imgs objetc are pil_images
|
# assert the imgs object are pil_images
|
||||||
ret_imgs = []
|
ret_imgs = []
|
||||||
ret_labels = []
|
ret_labels = []
|
||||||
py_to_pil_op = P.ToPIL()
|
py_to_pil_op = P.ToPIL()
|
||||||
|
|
Loading…
Reference in New Issue