!23336 WideResNet to master

Merge pull request !23336 from yangwm/master
i-robot 2021-09-14 09:20:59 +00:00 committed by Gitee
commit c8c57ed04f
5 changed files with 46 additions and 18 deletions


@ -59,12 +59,14 @@ The overall network architecture of WideResNet is shown here: [link](https://arxiv.org/abs/1605.0714
- Download the dataset. The directory structure is as follows:
```text
└─cifar-10-batches-bin
└─cifar10
├── train
├─data_batch_1.bin # training dataset
├─data_batch_2.bin # training dataset
├─data_batch_3.bin # training dataset
├─data_batch_4.bin # training dataset
├─data_batch_5.bin # training dataset
├── eval
└─test_batch.bin # evaluation dataset
```
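The layout above can be checked before training starts. The following is a minimal sketch, not part of the repository: the helper name `missing_dataset_files` is hypothetical, and the expected file names are taken directly from the tree shown above.

```python
from pathlib import Path

# Expected layout, copied from the directory tree in the README.
EXPECTED_FILES = {
    "train": [f"data_batch_{i}.bin" for i in range(1, 6)],
    "eval": ["test_batch.bin"],
}


def missing_dataset_files(root):
    """Return the expected relative paths that are absent under `root`."""
    root = Path(root)
    missing = []
    for split, names in EXPECTED_FILES.items():
        for name in names:
            if not (root / split / name).is_file():
                missing.append(f"{split}/{name}")
    return missing
```

An empty return value means all six `.bin` files are in place; otherwise the list names what still has to be downloaded or moved.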
@ -86,13 +88,23 @@ The overall network architecture of WideResNet is shown here: [link](https://arxiv.org/abs/1605.0714
```Shell
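# Distributed training
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version runs; [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version runs; pass "" for [CKPT_URL] to omit it. Only the best ckpt is kept, saved as WideResNet_best.ckpt
# Standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
Usage: bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version runs; [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version runs; pass "" for [CKPT_URL] to omit it. Only the best ckpt is kept, saved as WideResNet_best.ckpt
# Evaluation example
Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
[CKPT_URL] is the trained ckpt file.
When [MODELART] is True, the ModelArts cloud version runs; when False, the offline script runs.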
```
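The three positional arguments can be modelled as follows. This is a sketch only: the diff does not show how the scripts parse `[MODELART]`, so treating it as the literal strings `True` and `False` is an assumption, and `RunArgs` and `parse_run_args` are hypothetical names that do not exist in the repository.

```python
from dataclasses import dataclass


@dataclass
class RunArgs:
    data_url: str   # path to the dataset
    ckpt_url: str   # ckpt save path (ModelArts) or "" (offline training)
    modelart: bool  # True: ModelArts cloud version, False: offline version


def parse_run_args(argv):
    """Interpret [DATA_URL] [CKPT_URL] [MODELART] as the usage text describes.

    Assumption: MODELART is passed as the literal string "True" or "False".
    """
    if len(argv) != 3:
        raise ValueError("expected: [DATA_URL] [CKPT_URL] [MODELART]")
    data_url, ckpt_url, flag = argv
    if flag not in ("True", "False"):
        raise ValueError(f"MODELART must be True or False, got {flag!r}")
    return RunArgs(data_url, ckpt_url, flag == "True")
```

For an offline run, `parse_run_args(["/cifar10", "", "False"])` yields an empty `ckpt_url`, which corresponds to passing `""` on the command line.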
# Script Description
@ -164,11 +176,16 @@ The overall network architecture of WideResNet is shown here: [link](https://arxiv.org/abs/1605.0714
```Shell
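# Distributed training
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version runs; [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version runs; pass "" for [CKPT_URL] to omit it. Only the best ckpt is kept, saved as WideResNet_best.ckpt
# Standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
Usage: bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version runs; [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version runs; pass "" for [CKPT_URL] to omit it. Only the best ckpt is kept, saved as WideResNet_best.ckpt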
```
Distributed training requires an HCCL configuration file in JSON format, created in advance.
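A quick sanity check on that file can catch a malformed rank table before the job is launched. The sketch below assumes the common rank-table layout, in which a top-level `server_list` holds entries that each carry a `device` list; those key names are an assumption and are not confirmed by this diff, and `count_rank_table_devices` is a hypothetical helper.

```python
import json


def count_rank_table_devices(path):
    """Load a rank table and count its device entries.

    Assumed layout (not confirmed by this repository):
    {"server_list": [{"device": [{...}, {...}]}, ...], ...}
    """
    with open(path, encoding="utf-8") as f:
        table = json.load(f)  # raises json.JSONDecodeError on invalid JSON
    return sum(
        len(server.get("device", []))
        for server in table.get("server_list", [])
    )
```

If the count differs from the number of devices the job is meant to use, the file passed as `[RANK_TABLE_FILE]` is probably not the intended one.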
@ -218,12 +235,15 @@ epoch: 4 step: 195, loss is 1.221174
```Shell
# Evaluation
Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
[CKPT_URL] is the trained ckpt file.
When [MODELART] is True, the ModelArts cloud version runs; when False, the offline script runs.
```
```Shell
# Evaluation example
bash run_eval.sh /cifar10 WideResNet_best.ckpt
bash run_eval.sh /cifar10 WideResNet_best.ckpt False
```
Checkpoints can be generated during training.
@ -244,6 +264,8 @@ result: {'top_1_accuracy': 0.9622395833333334}
```shell
python export.py --ckpt_file [CKPT_PATH] --file_format [FILE_FORMAT] --device_id [0]
[CKPT_PATH] is the ckpt file saved after training.
```
The ckpt_file parameter is required.


@ -14,6 +14,12 @@
# limitations under the License.
# ==========================================================================
if [ $# != 4 ]
then
echo "Usage: bash run_standalone_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"


@ -14,11 +14,11 @@
# limitations under the License.
# ============================================================================
#if [$# != 3]
#then
#echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]"
#exit 1
#fi
if [ $# != 3 ]
then
echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then


@ -40,9 +40,9 @@ then
exit 1
fi
if [ ! -f $PATH2 ]
if [ ! -d $PATH2 ]
then
echo "error: CKPT_URL=$PATH2 is not a file"
echo "error: CKPT_URL=$PATH2 is not a directory"
exit 1
fi


@ -27,7 +27,7 @@ class SaveCallback(Callback):
super(SaveCallback, self).__init__()
self.model = model
self.eval_dataset = eval_dataset
self.cpkt_path = ckpt_path
self.ckpt_path = ckpt_path
self.acc = 0.96
self.cur_acc = 0.0
self.modelart = modelart
@ -47,5 +47,5 @@ class SaveCallback(Callback):
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
if self.modelart:
import moxing as mox
mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.cpkt_path)
mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.ckpt_path)
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)
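
The hunks above rename `cpkt_path` to `ckpt_path` in a callback that keeps only the best checkpoint. The comparison itself is outside the shown context, so the following framework-free sketch assumes that a checkpoint is saved whenever the evaluated accuracy strictly exceeds the stored best, starting from the `0.96` threshold visible in `__init__`. `BestCheckpointTracker` is a hypothetical name and is not part of the repository.

```python
class BestCheckpointTracker:
    """Sketch of keep-the-best logic, independent of MindSpore."""

    def __init__(self, threshold=0.96):
        # Mirrors `self.acc = 0.96` in the diff: nothing is saved
        # until accuracy rises above this starting value.
        self.best_acc = threshold

    def should_save(self, cur_acc):
        """Return True and record cur_acc when it beats the best so far.

        Assumption: strict improvement is required; the real callback's
        comparison is not shown in the diff.
        """
        if cur_acc > self.best_acc:
            self.best_acc = cur_acc
            return True
        return False
```

In the real callback, a `True` result would correspond to calling `save_checkpoint` and, when `modelart` is set, copying the saved files to `ckpt_path`.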