From fe4efe5c74cd4e06f2846490d0869838219069f8 Mon Sep 17 00:00:00 2001 From: yangwm Date: Mon, 13 Sep 2021 14:40:35 +0800 Subject: [PATCH] WideResNet --- model_zoo/research/cv/wideresnet/README_CN.md | 40 ++++++++++++++----- .../scripts/run_distribute_train.sh | 6 +++ .../cv/wideresnet/scripts/run_eval.sh | 10 ++--- .../scripts/run_standalone_train.sh | 4 +- .../cv/wideresnet/src/save_callback.py | 4 +- 5 files changed, 46 insertions(+), 18 deletions(-) diff --git a/model_zoo/research/cv/wideresnet/README_CN.md b/model_zoo/research/cv/wideresnet/README_CN.md index 75b4117ce4c..c8d8b1f2232 100644 --- a/model_zoo/research/cv/wideresnet/README_CN.md +++ b/model_zoo/research/cv/wideresnet/README_CN.md @@ -59,12 +59,14 @@ WideResNet的总体网络架构如下:[链接](https://arxiv.org/abs/1605.0714 - 下载数据集,目录结构如下: ```text -└─cifar-10-batches-bin +└─cifar10 + ├── train ├─data_batch_1.bin # 训练数据集 ├─data_batch_2.bin # 训练数据集 ├─data_batch_3.bin # 训练数据集 ├─data_batch_4.bin # 训练数据集 ├─data_batch_5.bin # 训练数据集 + ├── eval └─test_batch.bin # 评估数据集 ``` @@ -86,13 +88,23 @@ WideResNet的总体网络架构如下:[链接](https://arxiv.org/abs/1605.0714 ```Shell # 分布式训练 -用法:bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选) +用法:bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART] +[DATA_URL]是数据集的路径。 +[MODELART]为True时执行ModelArts云上版本,[CKPT_URL]是训练过程中保存ckpt文件的路径。 +[MODELART]为False时执行线下版本,[CKPT_URL]用“”省略,只保留最佳ckpt结果,文件名为‘WideResNet_best.ckpt’。 +。 # 单机训练 -用法:bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选) +用法:bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART] +[DATA_URL]是数据集的路径。 +[MODELART]为True时执行ModelArts云上版本,[CKPT_URL]是训练过程中保存ckpt文件的路径。 +[MODELART]为False时执行线下版本,[CKPT_URL]用“”省略,只保留最佳ckpt结果,文件名为‘WideResNet_best.ckpt’。 # 运行评估示例 -用法:bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] +用法:bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART] +[DATA_URL]是数据集的路径。 +[CKPT_URL]是训练好的ckpt文件。 +[MODELART]为True时执行ModelArts云上版本,为False时执行线下脚本。 ``` # 脚本说明 @@ 
-164,11 +176,16 @@ WideResNet的总体网络架构如下:[链接](https://arxiv.org/abs/1605.0714 ```Shell # 分布式训练 -用法:bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选) +用法:bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART] +[DATA_URL]是数据集的路径。 +[MODELART]为True时执行ModelArts云上版本,[CKPT_URL]是训练过程中保存ckpt文件的路径。 +[MODELART]为False时执行线下版本,[CKPT_URL]用“”省略,只保留最佳ckpt结果,文件名为‘WideResNet_best.ckpt’。 # 单机训练 -用法:bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选) - +用法:bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART] +[DATA_URL]是数据集的路径。 +[MODELART]为True时执行ModelArts云上版本,[CKPT_URL]是训练过程中保存ckpt文件的路径。 +[MODELART]为False时执行线下版本,[CKPT_URL]用“”省略,只保留最佳ckpt结果,文件名为‘WideResNet_best.ckpt’。 ``` 分布式训练需要提前创建JSON格式的HCCL配置文件。 @@ -218,12 +235,15 @@ epoch: 4 step: 195, loss is 1.221174 ```Shell # 评估 -Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] +Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART] +[DATA_URL]是数据集的路径。 +[CKPT_URL]是训练好的ckpt文件。 +[MODELART]为True时执行ModelArts云上版本,为False时执行线下脚本。 ``` ```Shell # 评估示例 -bash run_eval.sh /cifar10 WideResNet_best.ckpt +bash run_eval.sh /cifar10 WideResNet_best.ckpt False ``` 训练过程中可以生成检查点。 @@ -244,6 +264,8 @@ result: {'top_1_accuracy': 0.9622395833333334} ```shell python export.py --ckpt_file [CKPT_PATH] --file_format [FILE_FORMAT] --device_id [0] + +[CKPT_PATH]是训练后保存的ckpt文件 ``` 参数ckpt_file为必填项, diff --git a/model_zoo/research/cv/wideresnet/scripts/run_distribute_train.sh b/model_zoo/research/cv/wideresnet/scripts/run_distribute_train.sh index bff739691bd..2aea9583fd8 100644 --- a/model_zoo/research/cv/wideresnet/scripts/run_distribute_train.sh +++ b/model_zoo/research/cv/wideresnet/scripts/run_distribute_train.sh @@ -14,6 +14,12 @@ # limitations under the License. 
# ========================================================================== +if [ $# != 4 ] +then + echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]" +exit 1 +fi + get_real_path(){ if [ "${1:0:1}" == "/" ]; then echo "$1" diff --git a/model_zoo/research/cv/wideresnet/scripts/run_eval.sh b/model_zoo/research/cv/wideresnet/scripts/run_eval.sh index 2db1e9f3176..715e4c3d307 100644 --- a/model_zoo/research/cv/wideresnet/scripts/run_eval.sh +++ b/model_zoo/research/cv/wideresnet/scripts/run_eval.sh @@ -14,11 +14,11 @@ # limitations under the License. # ============================================================================ -#if [$# != 3] -#then - #echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]" -#exit 1 -#fi +if [ $# != 3 ] +then + echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]" +exit 1 +fi get_real_path(){ if [ "${1:0:1}" == "/" ]; then diff --git a/model_zoo/research/cv/wideresnet/scripts/run_standalone_train.sh b/model_zoo/research/cv/wideresnet/scripts/run_standalone_train.sh index 26201e70133..fc8a9df36c0 100644 --- a/model_zoo/research/cv/wideresnet/scripts/run_standalone_train.sh +++ b/model_zoo/research/cv/wideresnet/scripts/run_standalone_train.sh @@ -40,9 +40,9 @@ then exit 1 fi -if [ ! -f $PATH2 ] +if [ ! 
-d $PATH2 ] then - echo "error: CKPT_URL=$PATH2 is not a file" + echo "error: CKPT_URL=$PATH2 is not a directory" exit 1 fi diff --git a/model_zoo/research/cv/wideresnet/src/save_callback.py b/model_zoo/research/cv/wideresnet/src/save_callback.py index 061c644df17..065c87cb479 100644 --- a/model_zoo/research/cv/wideresnet/src/save_callback.py +++ b/model_zoo/research/cv/wideresnet/src/save_callback.py @@ -27,7 +27,7 @@ class SaveCallback(Callback): super(SaveCallback, self).__init__() self.model = model self.eval_dataset = eval_dataset - self.cpkt_path = ckpt_path + self.ckpt_path = ckpt_path self.acc = 0.96 self.cur_acc = 0.0 self.modelart = modelart @@ -47,5 +47,5 @@ class SaveCallback(Callback): save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name) if self.modelart: import moxing as mox - mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.cpkt_path) + mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.ckpt_path) print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)