!23336 WideResNet to master
Merge pull request !23336 from yangwm/master
commit c8c57ed04f

@ -59,12 +59,14 @@ The overall WideResNet network architecture is as follows: [link](https://arxiv.org/abs/1605.0714
- Download the dataset. The directory structure is as follows:

```text
└─cifar-10-batches-bin
└─cifar10
├── train
├─data_batch_1.bin # training dataset
├─data_batch_2.bin # training dataset
├─data_batch_3.bin # training dataset
├─data_batch_4.bin # training dataset
├─data_batch_5.bin # training dataset
├── eval
└─test_batch.bin # evaluation dataset
```

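The `train`/`eval` layout listed above can be produced with a few lines of Python. This is only a sketch: the helper name `arrange_cifar10` is hypothetical and not part of the repository, and it assumes the five training batches and the test batch all sit in one source directory.

```python
import shutil
from pathlib import Path

# File names taken from the directory listing in the README.
TRAIN_FILES = [f"data_batch_{i}.bin" for i in range(1, 6)]
EVAL_FILES = ["test_batch.bin"]


def arrange_cifar10(src_dir, dst_dir):
    """Copy CIFAR-10 binary batches into dst_dir/train and dst_dir/eval.

    Hypothetical helper: src_dir is assumed to hold all six .bin files.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    for split, names in (("train", TRAIN_FILES), ("eval", EVAL_FILES)):
        split_dir = dst / split
        split_dir.mkdir(parents=True, exist_ok=True)
        for name in names:
            shutil.copy2(src / name, split_dir / name)
    return dst
```

Calling `arrange_cifar10("cifar-10-batches-bin", "cifar10")` would then give a `cifar10` directory that can be passed as `[DATA_URL]`.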
@ -86,13 +88,23 @@ The overall WideResNet network architecture is as follows: [link](https://arxiv.org/abs/1605.0714

```Shell
# Distributed training
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version is run and [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version is run; pass "" for [CKPT_URL]. Only the best ckpt is kept, under the file name 'WideResNet_best.ckpt'.

# Standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
Usage: bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version is run and [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version is run; pass "" for [CKPT_URL]. Only the best ckpt is kept, under the file name 'WideResNet_best.ckpt'.

# Run evaluation example
Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
[CKPT_URL] is the trained ckpt file.
When [MODELART] is True, the ModelArts cloud version is run; when it is False, the offline script is run.
```

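The shell scripts pass `[MODELART]` as the literal strings `True` or `False`, so the Python side has to turn that string into a boolean and decide whether `[CKPT_URL]` may be empty. A minimal sketch of one way to do this with `argparse` follows. The flag names `--data_url`, `--ckpt_url`, `--modelart` and the helper `str2bool` are assumptions for illustration; the commit does not show how `train.py` parses its arguments.

```python
import argparse


def str2bool(value):
    """Convert the strings 'True'/'False' (any case) to a bool, rejecting anything else."""
    text = value.strip().lower()
    if text == "true":
        return True
    if text == "false":
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


def parse_args(argv=None):
    # Hypothetical flag names; they mirror the placeholders used in the usage lines.
    parser = argparse.ArgumentParser(description="WideResNet argument sketch")
    parser.add_argument("--data_url", required=True, help="path to the dataset")
    parser.add_argument("--ckpt_url", default="", help="checkpoint path; may be empty offline")
    parser.add_argument("--modelart", type=str2bool, default=False,
                        help="True for the ModelArts cloud version, False for offline")
    args = parser.parse_args(argv)
    # On ModelArts the checkpoints are copied to ckpt_url, so it must be set.
    if args.modelart and not args.ckpt_url:
        parser.error("--ckpt_url is required when --modelart is True")
    return args
```

With a strict converter like this, a misspelled value is rejected at startup instead of silently being treated as true or false.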
# Script description

@ -164,11 +176,16 @@ The overall WideResNet network architecture is as follows: [link](https://arxiv.org/abs/1605.0714

```Shell
# Distributed training
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version is run and [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version is run; pass "" for [CKPT_URL]. Only the best ckpt is kept, under the file name 'WideResNet_best.ckpt'.

# Standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)

Usage: bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
When [MODELART] is True, the ModelArts cloud version is run and [CKPT_URL] is the path where ckpt files are saved during training.
When [MODELART] is False, the offline version is run; pass "" for [CKPT_URL]. Only the best ckpt is kept, under the file name 'WideResNet_best.ckpt'.
```

Distributed training requires an HCCL configuration file in JSON format to be created in advance.

@ -218,12 +235,15 @@ epoch: 4 step: 195, loss is 1.221174

```Shell
# Evaluation
Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]
[DATA_URL] is the path to the dataset.
[CKPT_URL] is the trained ckpt file.
When [MODELART] is True, the ModelArts cloud version is run; when it is False, the offline script is run.
```

```Shell
# Evaluation example
bash run_eval.sh /cifar10 WideResNet_best.ckpt
bash run_eval.sh /cifar10 WideResNet_best.ckpt False
```

Checkpoints can be generated during training.

@ -244,6 +264,8 @@ result: {'top_1_accuracy': 0.9622395833333334}

```shell
python export.py --ckpt_file [CKPT_PATH] --file_format [FILE_FORMAT] --device_id [0]

[CKPT_PATH] is the ckpt file saved after training.
```

The ckpt_file parameter is required,

@ -14,6 +14,12 @@
# limitations under the License.
# ==========================================================================

if [ $# != 4 ]
then
echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]"
exit 1
fi

get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"

@ -14,11 +14,11 @@
# limitations under the License.
# ============================================================================

#if [$# != 3]
#then
#echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]"
#exit 1
#fi
if [ $# != 3 ]
then
echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]"
exit 1
fi

get_real_path(){
if [ "${1:0:1}" == "/" ]; then

@ -40,9 +40,9 @@ then
exit 1
fi

if [ ! -f $PATH2 ]
if [ ! -d $PATH2 ]
then
echo "error: CKPT_URL=$PATH2 is not a file"
echo "error: CKPT_URL=$PATH2 is not a directory"
exit 1
fi

@ -27,7 +27,7 @@ class SaveCallback(Callback):
super(SaveCallback, self).__init__()
self.model = model
self.eval_dataset = eval_dataset
self.cpkt_path = ckpt_path
self.ckpt_path = ckpt_path
self.acc = 0.96
self.cur_acc = 0.0
self.modelart = modelart

@ -47,5 +47,5 @@ class SaveCallback(Callback):
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
if self.modelart:
import moxing as mox
mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.ckpt_path)
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)

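The hunks above show only fragments of `SaveCallback`: an initial `self.acc = 0.96`, a `save_checkpoint` call, and a message about the maximum accuracy. The comparison that decides when to save is not visible in this diff. The framework-free sketch below shows one plausible reading, in which a checkpoint is written only when the evaluated accuracy exceeds the best value seen so far. The class `BestCheckpointTracker` and the `save_fn` parameter are hypothetical stand-ins, not MindSpore APIs.

```python
class BestCheckpointTracker:
    """Keep only the best checkpoint, starting from a minimum accuracy threshold.

    Assumption: a checkpoint is saved only when accuracy strictly exceeds
    the current best, which starts at the threshold (0.96 in the diff).
    """

    def __init__(self, threshold=0.96, file_name="WideResNet_best.ckpt"):
        self.best_acc = threshold
        self.file_name = file_name

    def update(self, acc, save_fn):
        """Call save_fn(file_name) if acc beats the best so far; return whether it saved."""
        if acc > self.best_acc:
            self.best_acc = acc
            save_fn(self.file_name)  # stand-in for the real checkpoint writer
            return True
        return False


if __name__ == "__main__":
    saved = []
    tracker = BestCheckpointTracker()
    for epoch_acc in (0.95, 0.962, 0.961):
        tracker.update(epoch_acc, saved.append)
    print(saved, tracker.best_acc)
```

Because the same file name is reused on every save, each improvement overwrites the previous best file, which matches the README statement that only the best ckpt is kept.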