forked from mindspore-Ecosystem/mindspore
!23336 WideResNet to master
Merge pull request !23336 from yangwm/master
This commit is contained in:
commit
c8c57ed04f
|
@ -59,12 +59,14 @@ WideResNet的总体网络架构如下:[链接](https://arxiv.org/abs/1605.0714
|
||||||
- 下载数据集,目录结构如下:
|
- 下载数据集,目录结构如下:
|
||||||
|
|
||||||
```text
|
```text
|
||||||
└─cifar-10-batches-bin
|
└─cifar10
|
||||||
|
├── train
|
||||||
├─data_batch_1.bin # 训练数据集
|
├─data_batch_1.bin # 训练数据集
|
||||||
├─data_batch_2.bin # 训练数据集
|
├─data_batch_2.bin # 训练数据集
|
||||||
├─data_batch_3.bin # 训练数据集
|
├─data_batch_3.bin # 训练数据集
|
||||||
├─data_batch_4.bin # 训练数据集
|
├─data_batch_4.bin # 训练数据集
|
||||||
├─data_batch_5.bin # 训练数据集
|
├─data_batch_5.bin # 训练数据集
|
||||||
|
├── eval
|
||||||
└─test_batch.bin # 评估数据集
|
└─test_batch.bin # 评估数据集
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -86,13 +88,23 @@ WideResNet的总体网络架构如下:[链接](https://arxiv.org/abs/1605.0714
|
||||||
|
|
||||||
```Shell
|
```Shell
|
||||||
# 分布式训练
|
# 分布式训练
|
||||||
用法:bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
用法:bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]
|
||||||
|
[DATA_URL]是数据集的路径。
|
||||||
|
[MODELART]为True时执行ModelArts云上版本,[CKPT_URL]是训练过程中保存ckpt文件的路径。
|
||||||
|
[MODELART]为False时执行线下版本,[CKPT_URL]用“”省略,只保留最佳ckpt结果,文件名为‘WideResNet_best.ckpt’。
|
||||||
|
。
|
||||||
|
|
||||||
# 单机训练
|
# 单机训练
|
||||||
用法:bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
用法:bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]
|
||||||
|
[DATA_URL]是数据集的路径。
|
||||||
|
[MODELART]为True时执行ModelArts云上版本,[CKPT_URL]是训练过程中保存ckpt文件的路径。
|
||||||
|
[MODELART]为False时执行线下版本,[CKPT_URL]用“”省略,只保留最佳ckpt结果,文件名为‘WideResNet_best.ckpt’。
|
||||||
|
|
||||||
# 运行评估示例
|
# 运行评估示例
|
||||||
用法:bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
用法:bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]
|
||||||
|
[DATA_URL]是数据集的路径。
|
||||||
|
[CKPT_URL]是训练好的ckpt文件。
|
||||||
|
[MODELART]为True时执行ModelArts云上版本,为False时执行线下脚本。
|
||||||
```
|
```
|
||||||
|
|
||||||
# 脚本说明
|
# 脚本说明
|
||||||
|
@ -164,11 +176,16 @@ WideResNet的总体网络架构如下:[链接](https://arxiv.org/abs/1605.0714
|
||||||
|
|
||||||
```Shell
|
```Shell
|
||||||
# 分布式训练
|
# 分布式训练
|
||||||
用法:bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
用法:bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]
|
||||||
|
[DATA_URL]是数据集的路径。
|
||||||
|
[MODELART]为True时执行ModelArts云上版本,[CKPT_URL]是训练过程中保存ckpt文件的路径。
|
||||||
|
[MODELART]为False时执行线下版本,[CKPT_URL]用“”省略,只保留最佳ckpt结果,文件名为‘WideResNet_best.ckpt’。
|
||||||
|
|
||||||
# 单机训练
|
# 单机训练
|
||||||
用法:bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选)
|
用法:bash run_standalone_train.sh [DATA_URL] [CKPT_URL] [MODELART]
|
||||||
|
[DATA_URL]是数据集的路径。
|
||||||
|
[MODELART]为True时执行ModelArts云上版本,[CKPT_URL]是训练过程中保存ckpt文件的路径。
|
||||||
|
[MODELART]为False时执行线下版本,[CKPT_URL]用“”省略,只保留最佳ckpt结果,文件名为‘WideResNet_best.ckpt’。
|
||||||
```
|
```
|
||||||
|
|
||||||
分布式训练需要提前创建JSON格式的HCCL配置文件。
|
分布式训练需要提前创建JSON格式的HCCL配置文件。
|
||||||
|
@ -218,12 +235,15 @@ epoch: 4 step: 195, loss is 1.221174
|
||||||
|
|
||||||
```Shell
|
```Shell
|
||||||
# 评估
|
# 评估
|
||||||
Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]
|
||||||
|
[DATA_URL]是数据集的路径。
|
||||||
|
[CKPT_URL]是训练好的ckpt文件。
|
||||||
|
[MODELART]为True时执行ModelArts云上版本,为False时执行线下脚本。
|
||||||
```
|
```
|
||||||
|
|
||||||
```Shell
|
```Shell
|
||||||
# 评估示例
|
# 评估示例
|
||||||
bash run_eval.sh /cifar10 WideResNet_best.ckpt
|
bash run_eval.sh /cifar10 WideResNet_best.ckpt False
|
||||||
```
|
```
|
||||||
|
|
||||||
训练过程中可以生成检查点。
|
训练过程中可以生成检查点。
|
||||||
|
@ -244,6 +264,8 @@ result: {'top_1_accuracy': 0.9622395833333334}
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
python export.py --ckpt_file [CKPT_PATH] --file_format [FILE_FORMAT] --device_id [0]
|
python export.py --ckpt_file [CKPT_PATH] --file_format [FILE_FORMAT] --device_id [0]
|
||||||
|
|
||||||
|
[CKPT_PATH]是训练后保存的ckpt文件
|
||||||
```
|
```
|
||||||
|
|
||||||
参数ckpt_file为必填项,
|
参数ckpt_file为必填项,
|
||||||
|
|
|
@ -14,6 +14,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
|
|
||||||
|
if [ $# != 4 ]
|
||||||
|
then
|
||||||
|
echo "Usage: bash run_standalone_train.sh [RANK_TABLE_FILE] [DATA_URL] [CKPT_URL] [MODELART]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
get_real_path(){
|
get_real_path(){
|
||||||
if [ "${1:0:1}" == "/" ]; then
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
echo "$1"
|
echo "$1"
|
||||||
|
|
|
@ -14,11 +14,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
#if [$# != 3]
|
if [ $# != 3 ]
|
||||||
#then
|
then
|
||||||
#echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]"
|
echo "Usage: bash run_eval.sh [DATA_URL] [CKPT_URL] [MODELART]"
|
||||||
#exit 1
|
exit 1
|
||||||
#fi
|
fi
|
||||||
|
|
||||||
get_real_path(){
|
get_real_path(){
|
||||||
if [ "${1:0:1}" == "/" ]; then
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
|
|
@ -40,9 +40,9 @@ then
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ! -f $PATH2 ]
|
if [ ! -d $PATH2 ]
|
||||||
then
|
then
|
||||||
echo "error: CKPT_URL=$PATH2 is not a file"
|
echo "error: CKPT_URL=$PATH2 is not a directory"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ class SaveCallback(Callback):
|
||||||
super(SaveCallback, self).__init__()
|
super(SaveCallback, self).__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
self.cpkt_path = ckpt_path
|
self.ckpt_path = ckpt_path
|
||||||
self.acc = 0.96
|
self.acc = 0.96
|
||||||
self.cur_acc = 0.0
|
self.cur_acc = 0.0
|
||||||
self.modelart = modelart
|
self.modelart = modelart
|
||||||
|
@ -47,5 +47,5 @@ class SaveCallback(Callback):
|
||||||
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
|
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
|
||||||
if self.modelart:
|
if self.modelart:
|
||||||
import moxing as mox
|
import moxing as mox
|
||||||
mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.cpkt_path)
|
mox.file.copy_parallel(src_url=cfg.save_checkpoint_path, dst_url=self.ckpt_path)
|
||||||
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)
|
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)
|
||||||
|
|
Loading…
Reference in New Issue