!23336 WideResNet to master

Merge pull request !23336 from yangwm/master
i-robot 2021-09-14 09:20:59 +00:00 committed by Gitee
commit c8c57ed04f
5 changed files with 46 additions and 18 deletions


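@ -59,12 +59,14 @@ The overall network architecture of WideResNet is as follows: [link](https://arxiv.org/abs/1605.0714
- Download the dataset; the directory structure is as follows: - Download the dataset; the directory structure is as follows: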
```text ```text
└─cifar-10-batches-bin └─cifar10
├── train
├─data_batch_1.bin # training dataset ├─data_batch_1.bin # training dataset
├─data_batch_2.bin # training dataset ├─data_batch_2.bin # training dataset
├─data_batch_3.bin # training dataset ├─data_batch_3.bin # training dataset
├─data_batch_4.bin # training dataset ├─data_batch_4.bin # training dataset
├─data_batch_5.bin # training dataset ├─data_batch_5.bin # training dataset
├── eval
└─test_batch.bin # evaluation dataset └─test_batch.bin # evaluation dataset
``` ```
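The new layout on the right-hand side of this hunk can be checked before training with a small helper. This is a minimal sketch and not part of the commit: the function name `missing_cifar10_files` is hypothetical, and the file names are taken from the tree above.

```python
import os

# Expected files, relative to the dataset root, per the new layout in this hunk.
EXPECTED_FILES = (
    [os.path.join("train", f"data_batch_{i}.bin") for i in range(1, 6)]
    + [os.path.join("eval", "test_batch.bin")]
)


def missing_cifar10_files(root):
    """Return the expected files that are absent under `root` (hypothetical helper)."""
    return [rel for rel in EXPECTED_FILES
            if not os.path.isfile(os.path.join(root, rel))]
```

An empty return value means the directory matches the layout; otherwise the list names the files still to be moved into `train` or `eval`.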
@ -86,13 +88,23 @@ The overall network architecture of WideResNet is as follows: [link](https://arxiv.org/abs/1605.0714
```Shell ```Shell
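# Distributed training # Distributed training
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional) Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version is run, and [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version is run; pass "" to omit [CKPT_URL]. Only the best ckpt is kept, under the file name WideResNet_best.ckpt
# Standalone training # Standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional) Usage: bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version is run, and [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version is run; pass "" to omit [CKPT_URL]. Only the best ckpt is kept, under the file name WideResNet_best.ckpt
# Run evaluation example # Run evaluation example
Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
[CKPT_URL] is the trained ckpt file.
When [MODELART] is True, the ModelArts cloud version is run; when it is False, the offline script is run.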
``` ```
# Script description # Script description
@ -164,11 +176,16 @@ The overall network architecture of WideResNet is as follows: [link](https://arxiv.org/abs/1605.0714
```Shell ```Shell
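# Distributed training # Distributed training
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional) Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version is run, and [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version is run; pass "" to omit [CKPT_URL]. Only the best ckpt is kept, under the file name WideResNet_best.ckpt
# Standalone training # Standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional) Usage: bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version is run, and [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version is run; pass "" to omit [CKPT_URL]. Only the best ckpt is kept, under the file name WideResNet_best.ckpt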
``` ```
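Distributed training requires an HCCL configuration file in JSON format to be created in advance. Distributed training requires an HCCL configuration file in JSON format to be created in advance.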
@ -218,12 +235,15 @@ epoch: 4 step: 195, loss is 1.221174
```Shell ```Shell
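# Evaluation # Evaluation
Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
[CKPT_URL] is the trained ckpt file.
When [MODELART] is True, the ModelArts cloud version is run; when it is False, the offline script is run.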
``` ```
```Shell ```Shell
# Evaluation example # Evaluation example
bash run_eval.sh /cifar10 WideResNet_best.ckpt bash run_eval.sh /cifar10 WideResNet_best.ckpt False
``` ```
Checkpoints can be generated during training. Checkpoints can be generated during training.
@ -244,6 +264,8 @@ result: {'top_1_accuracy': 0.9622395833333334}
```shell ```shell
python export.py --ckpt_file [CKPT_PATH] --file_format [FILE_FORMAT] --device_id [0] python export.py --ckpt_file [CKPT_PATH] --file_format [FILE_FORMAT] --device_id [0]
[CKPT_PATH] is the ckpt file saved after training
``` ```
The ckpt_file parameter is required The ckpt_file parameter is required
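The README hunks above pass [MODELART] as the literal string True or False. The sketch below shows one way a Python entry point could parse such a flag strictly; it does not reproduce the repository's actual argument handling. The option names `--data_url`, `--ckpt_url` and `--modelart` mirror the README placeholders and are assumptions, as are the helpers `str2bool` and `build_parser`.

```python
import argparse


def str2bool(value):
    """Map the literal strings "True"/"False" (any case) to a bool."""
    lowered = value.strip().lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


def build_parser():
    # Option names are assumptions modelled on the README placeholders,
    # not the options actually defined in the repository's scripts.
    parser = argparse.ArgumentParser(description="hypothetical WideResNet launcher")
    parser.add_argument("--data_url", required=True, help="dataset path")
    parser.add_argument("--ckpt_url", default="", help="ckpt path; empty when offline")
    parser.add_argument("--modelart", type=str2bool, default=False)
    return parser


args = build_parser().parse_args(
    ["--data_url", "/cifar10", "--ckpt_url", "", "--modelart", "False"]
)
```

Strict parsing matters here because `bool("False")` is `True` in Python. With `str2bool`, a misspelled value is rejected with an error instead of silently enabling the cloud path.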


@ -14,6 +14,12 @@
# limitations under the License. # limitations under the License.
# ========================================================================== # ==========================================================================
if [ $# != 4 ]
then
echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]"
exit 1
fi
get_real_path(){ get_real_path(){
if [ "${1:0:1}" == "/" ]; then if [ "${1:0:1}" == "/" ]; then
echo "$1" echo "$1"


@ -14,11 +14,11 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
#if [$# != 3] if [ $# != 3 ]
#then then
#echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]" echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]"
#exit 1 exit 1
#fi fi
get_real_path(){ get_real_path(){
if [ "${1:0:1}" == "/" ]; then if [ "${1:0:1}" == "/" ]; then


@ -40,9 +40,9 @@ then
exit 1 exit 1
fi fi
if [ ! -f $PATH2 ] if [ ! -d $PATH2 ]
then then
echo "error: CKPT_URL=$PATH2 is not a file" echo "error: CKPT_URL=$PATH2 is not a directory"
exit 1 exit 1
fi fi


@ -27,7 +27,7 @@ class SaveCallback(Callback):
super(SaveCallback, self).__init__() super(SaveCallback, self).__init__()
self.model = model self.model = model
self.eval_dataset = eval_dataset self.eval_dataset = eval_dataset
self.cpkt_path = ckpt_path self.ckpt_path = ckpt_path
self.acc = 0.96 self.acc = 0.96
self.cur_acc = 0.0 self.cur_acc = 0.0
self.modelart = modelart self.modelart = modelart
@ -47,5 +47,5 @@ class SaveCallback(Callback):
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name) save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
if self.modelart: if self.modelart:
import moxing as mox import moxing as mox
mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.cpkt_path) mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.ckpt_path)
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc) print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)
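The keep-only-the-best behaviour that `SaveCallback` appears to implement can be sketched without any framework. This is a minimal illustration, not the commit's code: `BestCheckpointTracker` and `save_fn` are hypothetical names, and the strictly-greater comparison is an assumption because the comparison itself is not visible in this hunk. The 0.96 starting threshold matches `self.acc` in the diff.

```python
class BestCheckpointTracker:
    """Framework-free sketch of keep-only-the-best checkpointing (hypothetical class)."""

    def __init__(self, save_fn, threshold=0.96):
        self.save_fn = save_fn  # called with the new best accuracy
        self.acc = threshold    # best accuracy so far; 0.96 as in the diff
        self.cur_acc = 0.0

    def epoch_end(self, accuracy):
        """Record this epoch's accuracy; save and return True on a new best."""
        self.cur_acc = accuracy
        if self.cur_acc > self.acc:  # assumed comparison, not shown in the hunk
            self.acc = self.cur_acc
            self.save_fn(self.acc)
            return True
        return False


saved = []
tracker = BestCheckpointTracker(save_fn=saved.append)
for acc in (0.95, 0.961, 0.96, 0.962):
    tracker.epoch_end(acc)
print(saved)  # [0.961, 0.962]
```

In the real callback, `save_fn` would correspond to the `save_checkpoint` call followed by the optional `mox.file.copy_parallel` upload shown in the hunk.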