forked from mindspore-Ecosystem/mindspore
!15947 modify deeplabv3 network for clould
From: @zhanghuiyao Reviewed-by: @c_34,@oacjiewen Signed-off-by: @c_34
This commit is contained in:
commit
3e4b2e049f
|
@ -104,7 +104,7 @@ For single device training, please config parameters, training script is:
|
||||||
run_standalone_train.sh
|
run_standalone_train.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
For 8 devices training, training steps are as follows:
|
- For 8 devices training, training steps are as follows:
|
||||||
|
|
||||||
1. Train s16 with vocaug dataset, finetuning from resnet101 pretrained model, script is:
|
1. Train s16 with vocaug dataset, finetuning from resnet101 pretrained model, script is:
|
||||||
|
|
||||||
|
@ -124,7 +124,7 @@ For 8 devices training, training steps are as follows:
|
||||||
run_distribute_train_s8_r2.sh
|
run_distribute_train_s8_r2.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
For evaluation, evaluating steps are as follows:
|
- For evaluation, evaluating steps are as follows:
|
||||||
|
|
||||||
1. Eval s16 with voc val dataset, eval script is:
|
1. Eval s16 with voc val dataset, eval script is:
|
||||||
|
|
||||||
|
@ -150,6 +150,238 @@ For evaluation, evaluating steps are as follows:
|
||||||
run_eval_s8_multiscale_flip.sh
|
run_eval_s8_multiscale_flip.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- Train on ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows)
|
||||||
|
|
||||||
|
1. Train s16 with vocaug dataset on modelarts, finetuning from resnet101 pretrained model, training steps are as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) Perform a or b.
|
||||||
|
# a. Set "enable_modelarts=True" on base_config.yaml file.
|
||||||
|
# Set "data_file='/cache/data/vocaug/vocaug_mindrecord/vocaug_mindrecord0'" on base_config.yaml file.
|
||||||
|
# Set "checkpoint_url=/The path of checkpoint in S3/" on beta_config.yaml file.
|
||||||
|
# Set "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/resnet101.ckpt" on base_config.yaml file.
|
||||||
|
# Set "base_lr=0.08" on base_config.yaml file.
|
||||||
|
# Set "is_distributed=True" on base_config.yaml file.
|
||||||
|
# Set "save_steps=410" on base_config.yaml file.
|
||||||
|
# Set other parameters on base_config.yaml file you need.
|
||||||
|
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||||
|
# Add "data_file=/cache/data/vocaug/vocaug_mindrecord/vocaug_mindrecord0" on the website UI interface.
|
||||||
|
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
||||||
|
# Add "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/resnet101.ckpt" on the website UI interface.
|
||||||
|
# Add "base_lr=0.08" on the website UI interface.
|
||||||
|
# Add "is_distributed=True" on the website UI interface.
|
||||||
|
# Add "save_steps=410" on the website UI interface.
|
||||||
|
# Add other parameters on the website UI interface.
|
||||||
|
# (2) Upload or copy your pretrained model to S3 bucket.
|
||||||
|
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||||
|
# (4) Set the code directory to "/path/deeplabv3" on the website UI interface.
|
||||||
|
# (5) Set the startup file to "train.py" on the website UI interface.
|
||||||
|
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (7) Create your job.
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Train s8 with vocaug dataset on modelarts, finetuning from model in previous step, training steps are as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) Perform a or b.
|
||||||
|
# a. Set "enable_modelarts=True" on base_config.yaml file.
|
||||||
|
# Set "model='deeplab_v3_s8'" on base_config.yaml file.
|
||||||
|
# Set "train_epochs=800" on base_config.yaml file.
|
||||||
|
# Set "batch_size=16" on base_config.yaml file.
|
||||||
|
# Set "base_lr=0.02" on base_config.yaml file.
|
||||||
|
# Set "loss_scale=2048" on base_config.yaml file.
|
||||||
|
# Set "data_file='/cache/data/vocaug/vocaug_mindrecord/vocaug_mindrecord0'" on base_config.yaml file.
|
||||||
|
# Set "checkpoint_url=/The path of checkpoint in S3/" on beta_config.yaml file.
|
||||||
|
# Set "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s16-300_41.ckpt" on base_config.yaml file.
|
||||||
|
# Set "is_distributed=True" on base_config.yaml file.
|
||||||
|
# Set "save_steps=820" on base_config.yaml file.
|
||||||
|
# Set other parameters on base_config.yaml file you need.
|
||||||
|
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||||
|
# Add "model='deeplab_v3_s8'" on the website UI interface.
|
||||||
|
# Add "train_epochs=800" on the website UI interface.
|
||||||
|
# Add "batch_size=16" on the website UI interface.
|
||||||
|
# Add "base_lr=0.02" on the website UI interface.
|
||||||
|
# Add "loss_scale=2048" on the website UI interface.
|
||||||
|
# Add "data_file='/cache/data/vocaug/vocaug_mindrecord/vocaug_mindrecord0'" on the website UI interface.
|
||||||
|
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
||||||
|
# Add "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s16-300_41.ckpt" on the website UI interface.
|
||||||
|
# Add "is_distributed=True" on the website UI interface.
|
||||||
|
# Add "save_steps=820" on the website UI interface.
|
||||||
|
# Add other parameters on the website UI interface.
|
||||||
|
# (2) Upload or copy your pretrained model to S3 bucket.
|
||||||
|
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||||
|
# (4) Set the code directory to "/path/deeplabv3" on the website UI interface.
|
||||||
|
# (5) Set the startup file to "train.py" on the website UI interface.
|
||||||
|
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (7) Create your job.
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Train s8 with voctrain dataset on modelarts, finetuning from model in previous step, training steps are as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) Perform a or b.
|
||||||
|
# a. Set "enable_modelarts=True" on base_config.yaml file.
|
||||||
|
# Set "model='deeplab_v3_s8'" on base_config.yaml file.
|
||||||
|
# Set "batch_size=16" on base_config.yaml file.
|
||||||
|
# Set "base_lr=0.008" on base_config.yaml file.
|
||||||
|
# Set "loss_scale=2048" on base_config.yaml file.
|
||||||
|
# Set "data_file='/cache/data/vocaug/voctrain_mindrecord/voctrain_mindrecord00'" on base_config.yaml file.
|
||||||
|
# Set "checkpoint_url=/The path of checkpoint in S3/" on beta_config.yaml file.
|
||||||
|
# Set "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-800_82.ckpt" on base_config.yaml file.
|
||||||
|
# Set "is_distributed=True" on base_config.yaml file.
|
||||||
|
# Set "save_steps=110" on base_config.yaml file.
|
||||||
|
# Set other parameters on base_config.yaml file you need.
|
||||||
|
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||||
|
# Add "model='deeplab_v3_s8'" on the website UI interface.
|
||||||
|
# Add "batch_size=16" on the website UI interface.
|
||||||
|
# Add "base_lr=0.008" on the website UI interface.
|
||||||
|
# Add "loss_scale=2048" on the website UI interface.
|
||||||
|
# Add "data_file='/cache/data/vocaug/voctrain_mindrecord/voctrain_mindrecord00'" on the website UI interface.
|
||||||
|
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
||||||
|
# Add "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-800_82.ckpt" on the website UI interface.
|
||||||
|
# Add "is_distributed=True" on the website UI interface.
|
||||||
|
# Add "save_steps=110" on the website UI interface.
|
||||||
|
# Add other parameters on the website UI interface.
|
||||||
|
# (2) Upload or copy your pretrained model to S3 bucket.
|
||||||
|
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||||
|
# (4) Set the code directory to "/path/deeplabv3" on the website UI interface.
|
||||||
|
# (5) Set the startup file to "train.py" on the website UI interface.
|
||||||
|
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (7) Create your job.
|
||||||
|
```
|
||||||
|
|
||||||
|
- Eval on ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start evaluating as follows)
|
||||||
|
|
||||||
|
1. Eval s16 with voc val dataset on modelarts, evaluating steps are as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) Perform a or b.
|
||||||
|
# a. Set "enable_modelarts=True" on base_config.yaml file.
|
||||||
|
# Set "model='deeplab_v3_s16'" on base_config.yaml file.
|
||||||
|
# Set "batch_size=32" on base_config.yaml file.
|
||||||
|
# Set "scales_type=0" on base_config.yaml file.
|
||||||
|
# Set "freeze_bn=True" on base_config.yaml file.
|
||||||
|
# Set "data_root='/cache/data/vocaug'" on base_config.yaml file.
|
||||||
|
# Set "data_lst='/cache/data/vocaug/voc_val_lst.txt'" on base_config.yaml file.
|
||||||
|
# Set "checkpoint_url=/The path of checkpoint in S3/" on beta_config.yaml file.
|
||||||
|
# Set "ckpt_path='/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s16-300_41.ckpt'" on base_config.yaml file.
|
||||||
|
# Set other parameters on base_config.yaml file you need.
|
||||||
|
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||||
|
# Add "model=deeplab_v3_s16" on the website UI interface.
|
||||||
|
# Add "batch_size=32" on the website UI interface.
|
||||||
|
# Add "scales_type=0" on the website UI interface.
|
||||||
|
# Add "freeze_bn=True" on the website UI interface.
|
||||||
|
# Add "data_root=/cache/data/vocaug" on the website UI interface.
|
||||||
|
# Add "data_lst=/cache/data/vocaug/voc_val_lst.txt" on the website UI interface.
|
||||||
|
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
||||||
|
# Add "ckpt_path=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s16-300_41.ckpt" on the website UI interface.
|
||||||
|
# Add other parameters on the website UI interface.
|
||||||
|
# (2) Upload or copy your pretrained model to S3 bucket.
|
||||||
|
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||||
|
# (4) Set the code directory to "/path/deeplabv3" on the website UI interface.
|
||||||
|
# (5) Set the startup file to "eval.py" on the website UI interface.
|
||||||
|
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (7) Create your job.
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Eval s8 with voc val dataset on modelarts, evaluating steps are as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) Perform a or b.
|
||||||
|
# a. Set "enable_modelarts=True" on base_config.yaml file.
|
||||||
|
# Set "model='deeplab_v3_s8'" on base_config.yaml file.
|
||||||
|
# Set "batch_size=16" on base_config.yaml file.
|
||||||
|
# Set "scales_type=0" on base_config.yaml file.
|
||||||
|
# Set "freeze_bn=True" on base_config.yaml file.
|
||||||
|
# Set "data_root='/cache/data/vocaug'" on base_config.yaml file.
|
||||||
|
# Set "data_lst='/cache/data/vocaug/voc_val_lst.txt'" on base_config.yaml file.
|
||||||
|
# Set "checkpoint_url='/The path of checkpoint in S3/'" on beta_config.yaml file.
|
||||||
|
# Set "ckpt_path='/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt'" on base_config.yaml file.
|
||||||
|
# Set other parameters on base_config.yaml file you need.
|
||||||
|
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||||
|
# Add "model=deeplab_v3_s8" on the website UI interface.
|
||||||
|
# Add "batch_size=16" on the website UI interface.
|
||||||
|
# Add "scales_type=0" on the website UI interface.
|
||||||
|
# Add "freeze_bn=True" on the website UI interface.
|
||||||
|
# Add "data_root=/cache/data/vocaug" on the website UI interface.
|
||||||
|
# Add "data_lst=/cache/data/vocaug/voc_val_lst.txt" on the website UI interface.
|
||||||
|
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
||||||
|
# Add "ckpt_path=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt" on the website UI interface.
|
||||||
|
# Add other parameters on the website UI interface.
|
||||||
|
# (2) Upload or copy your pretrained model to S3 bucket.
|
||||||
|
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||||
|
# (4) Set the code directory to "/path/deeplabv3" on the website UI interface.
|
||||||
|
# (5) Set the startup file to "eval.py" on the website UI interface.
|
||||||
|
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (7) Create your job.
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Eval s8 multiscale with voc val dataset on modelarts, evaluating steps are as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) Perform a or b.
|
||||||
|
# a. Set "enable_modelarts=True" on base_config.yaml file.
|
||||||
|
# Set "model='deeplab_v3_s8'" on base_config.yaml file.
|
||||||
|
# Set "batch_size=16" on base_config.yaml file.
|
||||||
|
# Set "scales_type=1" on base_config.yaml file.
|
||||||
|
# Set "freeze_bn=True" on base_config.yaml file.
|
||||||
|
# Set "data_root='/cache/data/vocaug'" on base_config.yaml file.
|
||||||
|
# Set "data_lst='/cache/data/vocaug/voc_val_lst.txt'" on base_config.yaml file.
|
||||||
|
# Set "checkpoint_url='/The path of checkpoint in S3/'" on beta_config.yaml file.
|
||||||
|
# Set "ckpt_path='/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt'" on base_config.yaml file.
|
||||||
|
# Set other parameters on base_config.yaml file you need.
|
||||||
|
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||||
|
# Add "model=deeplab_v3_s8" on the website UI interface.
|
||||||
|
# Add "batch_size=16" on the website UI interface.
|
||||||
|
# Add "scales_type=1" on the website UI interface.
|
||||||
|
# Add "freeze_bn=True" on the website UI interface.
|
||||||
|
# Add "data_root=/cache/data/vocaug" on the website UI interface.
|
||||||
|
# Add "data_lst=/cache/data/vocaug/voc_val_lst.txt" on the website UI interface.
|
||||||
|
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
||||||
|
# Add "ckpt_path=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt" on the website UI interface.
|
||||||
|
# Add other parameters on the website UI interface.
|
||||||
|
# (2) Upload or copy your pretrained model to S3 bucket.
|
||||||
|
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||||
|
# (4) Set the code directory to "/path/deeplabv3" on the website UI interface.
|
||||||
|
# (5) Set the startup file to "eval.py" on the website UI interface.
|
||||||
|
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (7) Create your job.
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Eval s8 multiscale and flip with voc val dataset on modelarts, evaluating steps are as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) Perform a or b.
|
||||||
|
# a. Set "enable_modelarts=True" on base_config.yaml file.
|
||||||
|
# Set "model='deeplab_v3_s8'" on base_config.yaml file.
|
||||||
|
# Set "batch_size=16" on base_config.yaml file.
|
||||||
|
# Set "scales_type=1" on base_config.yaml file.
|
||||||
|
# Set "freeze_bn=True" on base_config.yaml file.
|
||||||
|
# Set "flip=True" on base_config.yaml file.
|
||||||
|
# Set "data_root='/cache/data/vocaug'" on base_config.yaml file.
|
||||||
|
# Set "data_lst='/cache/data/vocaug/voc_val_lst.txt'" on base_config.yaml file.
|
||||||
|
# Set "checkpoint_url='/The path of checkpoint in S3/'" on beta_config.yaml file.
|
||||||
|
# Set "ckpt_path='/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt'" on base_config.yaml file.
|
||||||
|
# Set other parameters on base_config.yaml file you need.
|
||||||
|
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||||
|
# Add "model=deeplab_v3_s8" on the website UI interface.
|
||||||
|
# Add "batch_size=16" on the website UI interface.
|
||||||
|
# Add "scales_type=1" on the website UI interface.
|
||||||
|
# Add "freeze_bn=True" on the website UI interface.
|
||||||
|
# Add "flip=True" on the website UI interface.
|
||||||
|
# Add "data_root=/cache/data/vocaug" on the website UI interface.
|
||||||
|
# Add "data_lst=/cache/data/vocaug/voc_val_lst.txt" on the website UI interface.
|
||||||
|
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
||||||
|
# Add "ckpt_path=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt" on the website UI interface.
|
||||||
|
# Add other parameters on the website UI interface.
|
||||||
|
# (2) Upload or copy your pretrained model to S3 bucket.
|
||||||
|
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
|
||||||
|
# (4) Set the code directory to "/path/deeplabv3" on the website UI interface.
|
||||||
|
# (5) Set the startup file to "eval.py" on the website UI interface.
|
||||||
|
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||||
|
# (7) Create your job.
|
||||||
|
```
|
||||||
|
|
||||||
# [Script Description](#contents)
|
# [Script Description](#contents)
|
||||||
|
|
||||||
## [Script and Sample Code](#contents)
|
## [Script and Sample Code](#contents)
|
||||||
|
|
|
@ -119,7 +119,7 @@ Pascal VOC数据集和语义边界数据集(Semantic Boundaries Dataset,SBD
|
||||||
run_standalone_train.sh
|
run_standalone_train.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
按照以下训练步骤进行8卡训练:
|
- 按照以下训练步骤进行8卡训练:
|
||||||
|
|
||||||
1. 使用VOCaug数据集训练s16,微调ResNet-101预训练模型。脚本如下:
|
1. 使用VOCaug数据集训练s16,微调ResNet-101预训练模型。脚本如下:
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ run_standalone_train.sh
|
||||||
run_distribute_train_s8_r2.sh
|
run_distribute_train_s8_r2.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
评估步骤如下:
|
- 评估步骤如下:
|
||||||
|
|
||||||
1. 使用voc val数据集评估s16。评估脚本如下:
|
1. 使用voc val数据集评估s16。评估脚本如下:
|
||||||
|
|
||||||
|
@ -165,6 +165,238 @@ run_standalone_train.sh
|
||||||
run_eval_s8_multiscale_flip.sh
|
run_eval_s8_multiscale_flip.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- 在 ModelArts 进行训练 (如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/))
|
||||||
|
|
||||||
|
1. 在 modelarts 使用VOCaug数据集训练s16,微调ResNet-101预训练模型。训练步骤如下:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) 执行 a 或者 b.
|
||||||
|
# a. 在 base_config.yaml 文件中设置 "enable_modelarts=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_file='/cache/data/vocaug/vocaug_mindrecord/vocaug_mindrecord0'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在 base_config.yaml 文件中设置 "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/resnet101.ckpt"
|
||||||
|
# 在 base_config.yaml 文件中设置 "base_lr=0.08"
|
||||||
|
# 在 base_config.yaml 文件中设置 "is_distributed=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "save_steps=410"
|
||||||
|
# 在 base_config.yaml 文件中设置 其他参数
|
||||||
|
# b. 在网页上设置 "enable_modelarts=True"
|
||||||
|
# 在网页上设置 "data_file=/cache/data/vocaug/vocaug_mindrecord/vocaug_mindrecord0"
|
||||||
|
# 在网页上设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在网页上设置 "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/resnet101.ckpt"
|
||||||
|
# 在网页上设置 "base_lr=0.08"
|
||||||
|
# 在网页上设置 "is_distributed=True"
|
||||||
|
# 在网页上设置 "save_steps=410"
|
||||||
|
# 在网页上设置 其他参数
|
||||||
|
# (2) 上传你的预训练模型到 S3 桶上
|
||||||
|
# (3) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
|
||||||
|
# (4) 在网页上设置你的代码路径为 "/path/deeplabv3"
|
||||||
|
# (5) 在网页上设置启动文件为 "train.py"
|
||||||
|
# (6) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||||
|
# (7) 创建训练作业
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 使用VOCaug数据集训练s8,微调上一步的模型。训练步骤如下:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) 执行 a 或者 b.
|
||||||
|
# a. 在 base_config.yaml 文件中设置 "enable_modelarts=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "model='deeplab_v3_s8'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "train_epochs=800"
|
||||||
|
# 在 base_config.yaml 文件中设置 "batch_size=16"
|
||||||
|
# 在 base_config.yaml 文件中设置 "base_lr=0.02"
|
||||||
|
# 在 base_config.yaml 文件中设置 "loss_scale=2048"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_file='/cache/data/vocaug/vocaug_mindrecord/vocaug_mindrecord0'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在 base_config.yaml 文件中设置 "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s16-300_41.ckpt"
|
||||||
|
# 在 base_config.yaml 文件中设置 "is_distributed=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "save_steps=820"
|
||||||
|
# 在 base_config.yaml 文件中设置 其他参数
|
||||||
|
# b. 在网页上设置 "enable_modelarts=True"
|
||||||
|
# 在网页上设置 "model='deeplab_v3_s8'"
|
||||||
|
# 在网页上设置 "train_epochs=800"
|
||||||
|
# 在网页上设置 "batch_size=16"
|
||||||
|
# 在网页上设置 "base_lr=0.02"
|
||||||
|
# 在网页上设置 "loss_scale=2048"
|
||||||
|
# 在网页上设置 "data_file='/cache/data/vocaug/vocaug_mindrecord/vocaug_mindrecord0'"
|
||||||
|
# 在网页上设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在网页上设置 "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s16-300_41.ckpt"
|
||||||
|
# 在网页上设置 "is_distributed=True"
|
||||||
|
# 在网页上设置 "save_steps=820"
|
||||||
|
# 在网页上设置 其他参数
|
||||||
|
# (2) 上传你的预训练模型到 S3 桶上
|
||||||
|
# (3) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
|
||||||
|
# (4) 在网页上设置你的代码路径为 "/path/deeplabv3"
|
||||||
|
# (5) 在网页上设置启动文件为 "train.py"
|
||||||
|
# (6) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||||
|
# (7) 创建训练作业
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 使用VOCtrain数据集训练s8,微调上一步的模型。训练步骤如下:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) 执行 a 或者 b.
|
||||||
|
# a. 在 base_config.yaml 文件中设置 "enable_modelarts=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "model='deeplab_v3_s8'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "batch_size=16"
|
||||||
|
# 在 base_config.yaml 文件中设置 "base_lr=0.008"
|
||||||
|
# 在 base_config.yaml 文件中设置 "loss_scale=2048"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_file='/cache/data/vocaug/voctrain_mindrecord/voctrain_mindrecord00'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在 base_config.yaml 文件中设置 "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-800_82.ckpt"
|
||||||
|
# 在 base_config.yaml 文件中设置 "is_distributed=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "save_steps=110"
|
||||||
|
# 在 base_config.yaml 文件中设置 其他参数
|
||||||
|
# b. 在网页上设置 "enable_modelarts=True"
|
||||||
|
# 在网页上设置 "model='deeplab_v3_s8'"
|
||||||
|
# 在网页上设置 "batch_size=16"
|
||||||
|
# 在网页上设置 "base_lr=0.008"
|
||||||
|
# 在网页上设置 "loss_scale=2048"
|
||||||
|
# 在网页上设置 "data_file='/cache/data/vocaug/voctrain_mindrecord/voctrain_mindrecord00'"
|
||||||
|
# 在网页上设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在网页上设置 "ckpt_pre_trained=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-800_82.ckpt"
|
||||||
|
# 在网页上设置 "is_distributed=True"
|
||||||
|
# 在网页上设置 "save_steps=110"
|
||||||
|
# 在网页上设置 其他参数
|
||||||
|
# (2) 上传你的预训练模型到 S3 桶上
|
||||||
|
# (3) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
|
||||||
|
# (4) 在网页上设置你的代码路径为 "/path/deeplabv3"
|
||||||
|
# (5) 在网页上设置启动文件为 "train.py"
|
||||||
|
# (6) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||||
|
# (7) 创建训练作业
|
||||||
|
```
|
||||||
|
|
||||||
|
- 在 ModelArts 进行验证 (如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/))
|
||||||
|
|
||||||
|
1. 使用voc val数据集评估s16。评估步骤如下:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) 执行 a 或者 b.
|
||||||
|
# a. 在 base_config.yaml 文件中设置 "enable_modelarts=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "model='deeplab_v3_s16'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "batch_size=32"
|
||||||
|
# 在 base_config.yaml 文件中设置 "scales_type=0"
|
||||||
|
# 在 base_config.yaml 文件中设置 "freeze_bn=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_root='/cache/data/vocaug'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_lst='/cache/data/vocaug/voc_val_lst.txt'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在 base_config.yaml 文件中设置 "ckpt_path='/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s16-300_41.ckpt'"
|
||||||
|
# 在 base_config.yaml 文件中设置 其他参数
|
||||||
|
# b. 在网页上设置 "enable_modelarts=True"
|
||||||
|
# 在网页上设置 "model=deeplab_v3_s16"
|
||||||
|
# 在网页上设置 "batch_size=32"
|
||||||
|
# 在网页上设置 "scales_type=0"
|
||||||
|
# 在网页上设置 "freeze_bn=True"
|
||||||
|
# 在网页上设置 "data_root=/cache/data/vocaug"
|
||||||
|
# 在网页上设置 "data_lst=/cache/data/vocaug/voc_val_lst.txt"
|
||||||
|
# 在网页上设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在网页上设置 "ckpt_path=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s16-300_41.ckpt"
|
||||||
|
# 在网页上设置 其他参数
|
||||||
|
# (2) 上传你的预训练模型到 S3 桶上
|
||||||
|
# (3) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
|
||||||
|
# (4) 在网页上设置你的代码路径为 "/path/deeplabv3"
|
||||||
|
# (5) 在网页上设置启动文件为 "eval.py"
|
||||||
|
# (6) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||||
|
# (7) 创建训练作业
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 使用voc val数据集评估s8。评估步骤如下:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) 执行 a 或者 b.
|
||||||
|
# a. 在 base_config.yaml 文件中设置 "enable_modelarts=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "model='deeplab_v3_s8'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "batch_size=16"
|
||||||
|
# 在 base_config.yaml 文件中设置 "scales_type=0"
|
||||||
|
# 在 base_config.yaml 文件中设置 "freeze_bn=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_root='/cache/data/vocaug'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_lst='/cache/data/vocaug/voc_val_lst.txt'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "checkpoint_url='/The path of checkpoint in S3/'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "ckpt_path='/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt'"
|
||||||
|
# 在 base_config.yaml 文件中设置 其他参数
|
||||||
|
# b. 在网页上设置 "enable_modelarts=True"
|
||||||
|
# 在网页上设置 "model=deeplab_v3_s8"
|
||||||
|
# 在网页上设置 "batch_size=16"
|
||||||
|
# 在网页上设置 "scales_type=0"
|
||||||
|
# 在网页上设置 "freeze_bn=True"
|
||||||
|
# 在网页上设置 "data_root=/cache/data/vocaug"
|
||||||
|
# 在网页上设置 "data_lst=/cache/data/vocaug/voc_val_lst.txt"
|
||||||
|
# 在网页上设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在网页上设置 "ckpt_path=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt"
|
||||||
|
# 在网页上设置 其他参数.
|
||||||
|
# (2) 上传你的预训练模型到 S3 桶上
|
||||||
|
# (3) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
|
||||||
|
# (4) 在网页上设置你的代码路径为 "/path/deeplabv3"
|
||||||
|
# (5) 在网页上设置启动文件为 "eval.py"
|
||||||
|
# (6) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||||
|
# (7) 创建训练作业
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 使用voc val数据集评估多尺度s8。评估步骤如下:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) 执行 a 或者 b.
|
||||||
|
# a. 在 base_config.yaml 文件中设置 "enable_modelarts=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "model='deeplab_v3_s8'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "batch_size=16"
|
||||||
|
# 在 base_config.yaml 文件中设置 "scales_type=1"
|
||||||
|
# 在 base_config.yaml 文件中设置 "freeze_bn=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_root='/cache/data/vocaug'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_lst='/cache/data/vocaug/voc_val_lst.txt'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "checkpoint_url='/The path of checkpoint in S3/'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "ckpt_path='/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt'"
|
||||||
|
# 在 base_config.yaml 文件中设置 其他参数
|
||||||
|
# b. 在网页上设置 "enable_modelarts=True"
|
||||||
|
# 在网页上设置 "model=deeplab_v3_s8"
|
||||||
|
# 在网页上设置 "batch_size=16"
|
||||||
|
# 在网页上设置 "scales_type=1"
|
||||||
|
# 在网页上设置 "freeze_bn=True"
|
||||||
|
# 在网页上设置 "data_root=/cache/data/vocaug"
|
||||||
|
# 在网页上设置 "data_lst=/cache/data/vocaug/voc_val_lst.txt"
|
||||||
|
# 在网页上设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在网页上设置 "ckpt_path=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt"
|
||||||
|
# 在网页上设置 其他参数
|
||||||
|
# (2) 上传你的预训练模型到 S3 桶上
|
||||||
|
# (3) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
|
||||||
|
# (4) 在网页上设置你的代码路径为 "/path/deeplabv3"
|
||||||
|
# (5) 在网页上设置启动文件为 "eval.py"
|
||||||
|
# (6) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||||
|
# (7) 创建训练作业
|
||||||
|
```
|
||||||
|
|
||||||
|
4. 使用voc val数据集评估多尺度和翻转s8。评估步骤如下:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# (1) 执行 a 或者 b.
|
||||||
|
# a. 在 base_config.yaml 文件中设置 "enable_modelarts=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "model='deeplab_v3_s8'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "batch_size=16"
|
||||||
|
# 在 base_config.yaml 文件中设置 "scales_type=1"
|
||||||
|
# 在 base_config.yaml 文件中设置 "freeze_bn=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "flip=True"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_root='/cache/data/vocaug'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "data_lst='/cache/data/vocaug/voc_val_lst.txt'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "checkpoint_url='/The path of checkpoint in S3/'"
|
||||||
|
# 在 base_config.yaml 文件中设置 "ckpt_path='/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt'"
|
||||||
|
# 在 base_config.yaml 文件中设置 其他参数
|
||||||
|
# b. 在网页上设置 "enable_modelarts=True"
|
||||||
|
# 在网页上设置 "model=deeplab_v3_s8"
|
||||||
|
# 在网页上设置 "batch_size=16"
|
||||||
|
# 在网页上设置 "scales_type=1"
|
||||||
|
# 在网页上设置 "freeze_bn=True"
|
||||||
|
# 在网页上设置 "flip=True"
|
||||||
|
# 在网页上设置 "data_root=/cache/data/vocaug"
|
||||||
|
# 在网页上设置 "data_lst=/cache/data/vocaug/voc_val_lst.txt"
|
||||||
|
# 在网页上设置 "checkpoint_url=/The path of checkpoint in S3/"
|
||||||
|
# 在网页上设置 "ckpt_path=/cache/checkpoint_path/path_to_pretrain/deeplab_v3_s8-300_11.ckpt"
|
||||||
|
# 在网页上设置 其他参数
|
||||||
|
# (2) 上传你的预训练模型到 S3 桶上
|
||||||
|
# (3) 上传你的压缩数据集到 S3 桶上 (你也可以上传原始的数据集,但那可能会很慢。)
|
||||||
|
# (4) 在网页上设置你的代码路径为 "/path/deeplabv3"
|
||||||
|
# (5) 在网页上设置启动文件为 "eval.py"
|
||||||
|
# (6) 在网页上设置"训练数据集"、"训练输出文件路径"、"作业日志路径"等
|
||||||
|
# (7) 创建训练作业
|
||||||
|
```
|
||||||
|
|
||||||
# 脚本说明
|
# 脚本说明
|
||||||
|
|
||||||
## 脚本及样例代码
|
## 脚本及样例代码
|
||||||
|
|
|
@ -0,0 +1,100 @@
|
||||||
|
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||||
|
enable_modelarts: False
|
||||||
|
# Url for modelarts
|
||||||
|
data_url: ""
|
||||||
|
train_url: ""
|
||||||
|
checkpoint_url: ""
|
||||||
|
# Path for local
|
||||||
|
data_path: "/cache/data"
|
||||||
|
output_path: "/cache/train"
|
||||||
|
load_path: "/cache/checkpoint_path"
|
||||||
|
device_target: "Ascend" # ['Ascend', 'CPU']
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Training options
|
||||||
|
train_dir: "/cache/train/ckpt"
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
need_modelarts_dataset_unzip: True
|
||||||
|
data_file: ""
|
||||||
|
batch_size: 32
|
||||||
|
crop_size: 513
|
||||||
|
image_mean: [103.53, 116.28, 123.675]
|
||||||
|
image_std: [57.375, 57.120, 58.395]
|
||||||
|
min_scale: 0.5
|
||||||
|
max_scale: 2.0
|
||||||
|
ignore_label: 255
|
||||||
|
num_classes: 21
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
train_epochs: 300
|
||||||
|
lr_type: "cos"
|
||||||
|
base_lr: 0.015
|
||||||
|
lr_decay_step: 40000
|
||||||
|
lr_decay_rate: 0.1
|
||||||
|
loss_scale: 3072.0
|
||||||
|
|
||||||
|
# model
|
||||||
|
model: "deeplab_v3_s16"
|
||||||
|
freeze_bn: False
|
||||||
|
ckpt_pre_trained: ""
|
||||||
|
filter_weight: False
|
||||||
|
|
||||||
|
# train
|
||||||
|
is_distributed: False
|
||||||
|
rank: 0
|
||||||
|
group_size: 1
|
||||||
|
save_steps: 3000
|
||||||
|
keep_checkpoint_max: 1
|
||||||
|
|
||||||
|
# eval param
|
||||||
|
data_root: ""
|
||||||
|
data_lst: ""
|
||||||
|
scales: [1.0,]
|
||||||
|
scales_list: [[1.0,], [0.5, 0.75, 1.0, 1.25, 1.75]]
|
||||||
|
scales_type: 0
|
||||||
|
flip: False
|
||||||
|
ckpt_path: ""
|
||||||
|
input_format: "NCHW" # ["NCHW", "NHWC"]
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# Help description for each configuration
|
||||||
|
enable_modelarts: "Whether training on modelarts, default: False"
|
||||||
|
data_url: "Url for modelarts"
|
||||||
|
train_url: "Url for modelarts"
|
||||||
|
data_path: "The location of the input data."
|
||||||
|
output_path: "The location of the output file."
|
||||||
|
device_target: 'Target device type'
|
||||||
|
train_dir: "where training log and ckpts saved"
|
||||||
|
data_file: "path and name of one mindrecord file"
|
||||||
|
batch_size: "batch size"
|
||||||
|
crop_size: "crop size"
|
||||||
|
image_mean: "image mean"
|
||||||
|
image_std: "image std"
|
||||||
|
min_scale: "minimum scale of data argumentation"
|
||||||
|
max_scale: "maximum scale of data argumentation"
|
||||||
|
ignore_label: "ignore label"
|
||||||
|
num_classes: "number of classes"
|
||||||
|
train_epochs: "epoch"
|
||||||
|
lr_type: "type of learning rate"
|
||||||
|
base_lr: "base learning rate"
|
||||||
|
lr_decay_step: "learning rate decay step"
|
||||||
|
lr_decay_rate: "learning rate decay rate"
|
||||||
|
loss_scale: "loss scale"
|
||||||
|
model: "select model"
|
||||||
|
freeze_bn: "freeze bn"
|
||||||
|
ckpt_pre_trained: "pretrained model"
|
||||||
|
filter_weight: "Filter the last weight parameters, default is False."
|
||||||
|
is_distributed: "distributed training"
|
||||||
|
rank: "local rank of distributed"
|
||||||
|
group_size: "world size of distributed"
|
||||||
|
save_steps: "steps interval for saving"
|
||||||
|
keep_checkpoint_max: "max checkpoint for saving"
|
||||||
|
|
||||||
|
data_root: "root path of val data"
|
||||||
|
data_lst: "list of val data"
|
||||||
|
scales: "scales of evaluation"
|
||||||
|
flip: "perform left-right flip"
|
||||||
|
ckpt_path: "model to evaluat"
|
||||||
|
input_format: "NCHW or NHWC"
|
|
@ -15,7 +15,7 @@
|
||||||
"""eval deeplabv3."""
|
"""eval deeplabv3."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
|
@ -25,34 +25,18 @@ import mindspore.ops as ops
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from src.nets import net_factory
|
from src.nets import net_factory
|
||||||
|
|
||||||
|
from utils.config import config
|
||||||
|
from utils.moxing_adapter import moxing_wrapper
|
||||||
|
from utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||||
device_id=int(os.getenv('DEVICE_ID')))
|
device_id=get_device_id())
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
# parser.add_argument('--scales', type=float, action='append', help='scales of evaluation')
|
||||||
parser = argparse.ArgumentParser('mindspore deeplabv3 eval')
|
# parser.add_argument('--flip', action='store_true', help='perform left-right flip')
|
||||||
|
|
||||||
# val data
|
|
||||||
parser.add_argument('--data_root', type=str, default='', help='root path of val data')
|
|
||||||
parser.add_argument('--data_lst', type=str, default='', help='list of val data')
|
|
||||||
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
|
|
||||||
parser.add_argument('--crop_size', type=int, default=513, help='crop size')
|
|
||||||
parser.add_argument('--image_mean', type=list, default=[103.53, 116.28, 123.675], help='image mean')
|
|
||||||
parser.add_argument('--image_std', type=list, default=[57.375, 57.120, 58.395], help='image std')
|
|
||||||
parser.add_argument('--scales', type=float, action='append', help='scales of evaluation')
|
|
||||||
parser.add_argument('--flip', action='store_true', help='perform left-right flip')
|
|
||||||
parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
|
|
||||||
parser.add_argument('--num_classes', type=int, default=21, help='number of classes')
|
|
||||||
|
|
||||||
# model
|
|
||||||
parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
|
|
||||||
parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
|
|
||||||
parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')
|
|
||||||
parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
|
|
||||||
help="NCHW or NHWC")
|
|
||||||
|
|
||||||
args, _ = parser.parse_known_args()
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def cal_hist(a, b, n):
|
def cal_hist(a, b, n):
|
||||||
|
@ -153,8 +137,63 @@ def eval_batch_scales(args, eval_net, img_lst, scales,
|
||||||
return result_msk
|
return result_msk
|
||||||
|
|
||||||
|
|
||||||
|
def modelarts_pre_process():
|
||||||
|
'''modelarts pre process function.'''
|
||||||
|
def unzip(zip_file, save_dir):
|
||||||
|
import zipfile
|
||||||
|
s_time = time.time()
|
||||||
|
if not os.path.exists(os.path.join(save_dir, "vocaug")):
|
||||||
|
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||||
|
if zip_isexist:
|
||||||
|
fz = zipfile.ZipFile(zip_file, 'r')
|
||||||
|
data_num = len(fz.namelist())
|
||||||
|
print("Extract Start...")
|
||||||
|
print("unzip file num: {}".format(data_num))
|
||||||
|
i = 0
|
||||||
|
for file in fz.namelist():
|
||||||
|
if i % int(data_num / 100) == 0:
|
||||||
|
print("unzip percent: {}%".format(i / int(data_num / 100)), flush=True)
|
||||||
|
i += 1
|
||||||
|
fz.extract(file, save_dir)
|
||||||
|
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||||
|
int(int(time.time() - s_time) % 60)))
|
||||||
|
print("Extract Done.")
|
||||||
|
else:
|
||||||
|
print("This is not zip.")
|
||||||
|
else:
|
||||||
|
print("Zip has been extracted.")
|
||||||
|
|
||||||
|
if config.need_modelarts_dataset_unzip:
|
||||||
|
zip_file_1 = os.path.join(config.data_path, "vocaug.zip")
|
||||||
|
save_dir_1 = os.path.join(config.data_path)
|
||||||
|
|
||||||
|
sync_lock = "/tmp/unzip_sync.lock"
|
||||||
|
|
||||||
|
# Each server contains 8 devices as most.
|
||||||
|
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||||
|
print("Zip file path: ", zip_file_1)
|
||||||
|
print("Unzip file save dir: ", save_dir_1)
|
||||||
|
unzip(zip_file_1, save_dir_1)
|
||||||
|
print("===Finish extract data synchronization===")
|
||||||
|
try:
|
||||||
|
os.mknod(sync_lock)
|
||||||
|
except IOError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if os.path.exists(sync_lock):
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||||
|
|
||||||
|
config.train_dir = os.path.join(config.output_path, str(get_rank_id()), config.train_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||||
def net_eval():
|
def net_eval():
|
||||||
args = parse_args()
|
config.scales = config.scales_list[config.scales_type]
|
||||||
|
args = config
|
||||||
|
|
||||||
# data list
|
# data list
|
||||||
with open(args.data_lst) as f:
|
with open(args.data_lst) as f:
|
||||||
|
|
|
@ -15,7 +15,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
export DEVICE_ID=7
|
export DEVICE_ID=7
|
||||||
python /PATH/TO/MODEL_ZOO_CODE/data/build_seg_data.py --data_root=/PATH/TO/DATA_ROOT \
|
EXECUTE_PATH=$(pwd)
|
||||||
|
|
||||||
|
python ${EXECUTE_PATH}/../src/data/build_seg_data.py --data_root=/PATH/TO/DATA_ROOT \
|
||||||
--data_lst=/PATH/TO/DATA_lst.txt \
|
--data_lst=/PATH/TO/DATA_lst.txt \
|
||||||
--dst_path=/PATH/TO/MINDRECORED_NAME.mindrecord \
|
--dst_path=/PATH/TO/MINDRECORED_NAME.mindrecord \
|
||||||
--num_shards=8 \
|
--num_shards=8 \
|
||||||
|
|
|
@ -14,11 +14,34 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 1 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_distribute_train_base.sh [RANK_TABLE_FILE]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
PATH1=$(get_real_path $1)
|
||||||
|
echo $PATH1
|
||||||
|
|
||||||
|
if [ ! -f $PATH1 ]
|
||||||
|
then
|
||||||
|
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
ulimit -c unlimited
|
ulimit -c unlimited
|
||||||
train_path=/PATH/TO/EXPERIMENTS_DIR
|
EXECUTE_PATH=$(pwd)
|
||||||
|
train_path=${EXECUTE_PATH}/s16_aug_train
|
||||||
export SLOG_PRINT_TO_STDOUT=0
|
export SLOG_PRINT_TO_STDOUT=0
|
||||||
train_code_path=/PATH/TO/MODEL_ZOO_CODE
|
export RANK_TABLE_FILE=$PATH1
|
||||||
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
|
|
||||||
export RANK_SIZE=8
|
export RANK_SIZE=8
|
||||||
export RANK_START_ID=0
|
export RANK_START_ID=0
|
||||||
|
|
||||||
|
@ -35,8 +58,8 @@ do
|
||||||
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
|
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
|
||||||
mkdir ${train_path}/device${DEVICE_ID}
|
mkdir ${train_path}/device${DEVICE_ID}
|
||||||
cd ${train_path}/device${DEVICE_ID} || exit
|
cd ${train_path}/device${DEVICE_ID} || exit
|
||||||
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
|
python ${EXECUTE_PATH}/../train.py --train_dir=${train_path}/ckpt \
|
||||||
--data_file=/PATH/TO/MINDRECORD_NAME \
|
--data_file=/PATH_TO_DATA/vocaug/vocaug_mindrecord/vocaug_mindrecord0 \
|
||||||
--train_epochs=300 \
|
--train_epochs=300 \
|
||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--crop_size=513 \
|
--crop_size=513 \
|
||||||
|
@ -48,7 +71,7 @@ do
|
||||||
--num_classes=21 \
|
--num_classes=21 \
|
||||||
--model=deeplab_v3_s16 \
|
--model=deeplab_v3_s16 \
|
||||||
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
|
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
|
||||||
--is_distributed \
|
--is_distributed=True \
|
||||||
--save_steps=410 \
|
--save_steps=410 \
|
||||||
--keep_checkpoint_max=200 >log 2>&1 &
|
--keep_checkpoint_max=1 >log 2>&1 &
|
||||||
done
|
done
|
||||||
|
|
|
@ -14,11 +14,34 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 1 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_distribute_train_base.sh [RANK_TABLE_FILE]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
PATH1=$(get_real_path $1)
|
||||||
|
echo $PATH1
|
||||||
|
|
||||||
|
if [ ! -f $PATH1 ]
|
||||||
|
then
|
||||||
|
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
ulimit -c unlimited
|
ulimit -c unlimited
|
||||||
train_path=/PATH/TO/EXPERIMENTS_DIR
|
EXECUTE_PATH=$(pwd)
|
||||||
|
train_path=${EXECUTE_PATH}/s8_aug_train
|
||||||
export SLOG_PRINT_TO_STDOUT=0
|
export SLOG_PRINT_TO_STDOUT=0
|
||||||
train_code_path=/PATH/TO/MODEL_ZOO_CODE
|
export RANK_TABLE_FILE=$PATH1
|
||||||
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
|
|
||||||
export RANK_SIZE=8
|
export RANK_SIZE=8
|
||||||
export RANK_START_ID=0
|
export RANK_START_ID=0
|
||||||
|
|
||||||
|
@ -35,8 +58,8 @@ do
|
||||||
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
|
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
|
||||||
mkdir ${train_path}/device${DEVICE_ID}
|
mkdir ${train_path}/device${DEVICE_ID}
|
||||||
cd ${train_path}/device${DEVICE_ID} || exit
|
cd ${train_path}/device${DEVICE_ID} || exit
|
||||||
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
|
python ${EXECUTE_PATH}/../train.py --train_dir=${train_path}/ckpt \
|
||||||
--data_file=/PATH/TO/MINDRECORD_NAME \
|
--data_file=/PATH_TO_DATA/vocaug/vocaug_mindrecord/vocaug_mindrecord0 \
|
||||||
--train_epochs=800 \
|
--train_epochs=800 \
|
||||||
--batch_size=16 \
|
--batch_size=16 \
|
||||||
--crop_size=513 \
|
--crop_size=513 \
|
||||||
|
@ -49,7 +72,7 @@ do
|
||||||
--model=deeplab_v3_s8 \
|
--model=deeplab_v3_s8 \
|
||||||
--loss_scale=2048 \
|
--loss_scale=2048 \
|
||||||
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
|
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
|
||||||
--is_distributed \
|
--is_distributed=True \
|
||||||
--save_steps=820 \
|
--save_steps=820 \
|
||||||
--keep_checkpoint_max=200 >log 2>&1 &
|
--keep_checkpoint_max=1 >log 2>&1 &
|
||||||
done
|
done
|
||||||
|
|
|
@ -14,11 +14,34 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 1 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_distribute_train_base.sh [RANK_TABLE_FILE]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
PATH1=$(get_real_path $1)
|
||||||
|
echo $PATH1
|
||||||
|
|
||||||
|
if [ ! -f $PATH1 ]
|
||||||
|
then
|
||||||
|
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
ulimit -c unlimited
|
ulimit -c unlimited
|
||||||
train_path=/PATH/TO/EXPERIMENTS_DIR
|
EXECUTE_PATH=$(pwd)
|
||||||
|
train_path=${EXECUTE_PATH}/s8_voc_train
|
||||||
export SLOG_PRINT_TO_STDOUT=0
|
export SLOG_PRINT_TO_STDOUT=0
|
||||||
train_code_path=/PATH/TO/MODEL_ZOO_CODE
|
export RANK_TABLE_FILE=$PATH1
|
||||||
export RANK_TABLE_FILE=${train_code_path}/src/tools/rank_table_8p.json
|
|
||||||
export RANK_SIZE=8
|
export RANK_SIZE=8
|
||||||
export RANK_START_ID=0
|
export RANK_START_ID=0
|
||||||
|
|
||||||
|
@ -35,8 +58,8 @@ do
|
||||||
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
|
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
|
||||||
mkdir ${train_path}/device${DEVICE_ID}
|
mkdir ${train_path}/device${DEVICE_ID}
|
||||||
cd ${train_path}/device${DEVICE_ID} || exit
|
cd ${train_path}/device${DEVICE_ID} || exit
|
||||||
python ${train_code_path}/train.py --train_dir=${train_path}/ckpt \
|
python ${EXECUTE_PATH}/../train.py --train_dir=${train_path}/ckpt \
|
||||||
--data_file=/PATH/TO/MINDRECORD_NAME \
|
--data_file=/PATH_TO_DATA/vocaug/voctrain_mindrecord/voctrain_mindrecord00 \
|
||||||
--train_epochs=300 \
|
--train_epochs=300 \
|
||||||
--batch_size=16 \
|
--batch_size=16 \
|
||||||
--crop_size=513 \
|
--crop_size=513 \
|
||||||
|
@ -49,7 +72,7 @@ do
|
||||||
--model=deeplab_v3_s8 \
|
--model=deeplab_v3_s8 \
|
||||||
--loss_scale=2048 \
|
--loss_scale=2048 \
|
||||||
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
|
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
|
||||||
--is_distributed \
|
--is_distributed=True \
|
||||||
--save_steps=110 \
|
--save_steps=110 \
|
||||||
--keep_checkpoint_max=200 >log 2>&1 &
|
--keep_checkpoint_max=1 >log 2>&1 &
|
||||||
done
|
done
|
||||||
|
|
|
@ -14,24 +14,24 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
export DEVICE_ID=3
|
export DEVICE_ID=0
|
||||||
export SLOG_PRINT_TO_STDOUT=0
|
export SLOG_PRINT_TO_STDOUT=0
|
||||||
train_code_path=/PATH/TO/MODEL_ZOO_CODE
|
EXECUTE_PATH=$(pwd)
|
||||||
eval_path=/PATH/TO/EVAL
|
eval_path=${EXECUTE_PATH}/s16_eval
|
||||||
|
|
||||||
if [ -d ${eval_path} ]; then
|
if [ -d ${eval_path} ]; then
|
||||||
rm -rf ${eval_path}
|
rm -rf ${eval_path}
|
||||||
fi
|
fi
|
||||||
mkdir -p ${eval_path}
|
mkdir -p ${eval_path}
|
||||||
|
|
||||||
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
|
python ${EXECUTE_PATH}/../eval.py --data_root=/PATH_TO_DATA/vocaug \
|
||||||
--data_lst=/PATH/TO/DATA_lst.txt \
|
--data_lst=/PATH_TO_DATA/vocaug/voc_val_lst.txt \
|
||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--crop_size=513 \
|
--crop_size=513 \
|
||||||
--ignore_label=255 \
|
--ignore_label=255 \
|
||||||
--num_classes=21 \
|
--num_classes=21 \
|
||||||
--model=deeplab_v3_s16 \
|
--model=deeplab_v3_s16 \
|
||||||
--scales=1.0 \
|
--scales_type=0 \
|
||||||
--freeze_bn \
|
--freeze_bn=True \
|
||||||
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
|
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
|
||||||
|
|
||||||
|
|
|
@ -14,24 +14,24 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
export DEVICE_ID=3
|
export DEVICE_ID=1
|
||||||
export SLOG_PRINT_TO_STDOUT=0
|
export SLOG_PRINT_TO_STDOUT=0
|
||||||
train_code_path=/PATH/TO/MODEL_ZOO_CODE
|
EXECUTE_PATH=$(pwd)
|
||||||
eval_path=/PATH/TO/EVAL
|
eval_path=${EXECUTE_PATH}/s8_eval
|
||||||
|
|
||||||
if [ -d ${eval_path} ]; then
|
if [ -d ${eval_path} ]; then
|
||||||
rm -rf ${eval_path}
|
rm -rf ${eval_path}
|
||||||
fi
|
fi
|
||||||
mkdir -p ${eval_path}
|
mkdir -p ${eval_path}
|
||||||
|
|
||||||
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
|
python ${EXECUTE_PATH}/../eval.py --data_root=/PATH_TO_DATA/vocaug \
|
||||||
--data_lst=/PATH/TO/DATA_lst.txt \
|
--data_lst=/PATH_TO_DATA/vocaug/voc_val_lst.txt \
|
||||||
--batch_size=16 \
|
--batch_size=16 \
|
||||||
--crop_size=513 \
|
--crop_size=513 \
|
||||||
--ignore_label=255 \
|
--ignore_label=255 \
|
||||||
--num_classes=21 \
|
--num_classes=21 \
|
||||||
--model=deeplab_v3_s8 \
|
--model=deeplab_v3_s8 \
|
||||||
--scales=1.0 \
|
--scales_type=0 \
|
||||||
--freeze_bn \
|
--freeze_bn=True \
|
||||||
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
|
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
|
||||||
|
|
||||||
|
|
|
@ -14,28 +14,24 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
export DEVICE_ID=3
|
export DEVICE_ID=2
|
||||||
export SLOG_PRINT_TO_STDOUT=0
|
export SLOG_PRINT_TO_STDOUT=0
|
||||||
train_code_path=/PATH/TO/MODEL_ZOO_CODE
|
EXECUTE_PATH=$(pwd)
|
||||||
eval_path=/PATH/TO/EVAL
|
eval_path=${EXECUTE_PATH}/multiscale_eval
|
||||||
|
|
||||||
if [ -d ${eval_path} ]; then
|
if [ -d ${eval_path} ]; then
|
||||||
rm -rf ${eval_path}
|
rm -rf ${eval_path}
|
||||||
fi
|
fi
|
||||||
mkdir -p ${eval_path}
|
mkdir -p ${eval_path}
|
||||||
|
|
||||||
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
|
python ${EXECUTE_PATH}/../eval.py --data_root=/PATH_TO_DATA/vocaug \
|
||||||
--data_lst=/PATH/TO/DATA_lst.txt \
|
--data_lst=/PATH_TO_DATA/vocaug/voc_val_lst.txt \
|
||||||
--batch_size=16 \
|
--batch_size=16 \
|
||||||
--crop_size=513 \
|
--crop_size=513 \
|
||||||
--ignore_label=255 \
|
--ignore_label=255 \
|
||||||
--num_classes=21 \
|
--num_classes=21 \
|
||||||
--model=deeplab_v3_s8 \
|
--model=deeplab_v3_s8 \
|
||||||
--scales=0.5 \
|
--scales_type=1 \
|
||||||
--scales=0.75 \
|
--freeze_bn=True \
|
||||||
--scales=1.0 \
|
|
||||||
--scales=1.25 \
|
|
||||||
--scales=1.75 \
|
|
||||||
--freeze_bn \
|
|
||||||
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
|
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
|
||||||
|
|
||||||
|
|
|
@ -16,27 +16,23 @@
|
||||||
|
|
||||||
export DEVICE_ID=3
|
export DEVICE_ID=3
|
||||||
export SLOG_PRINT_TO_STDOUT=0
|
export SLOG_PRINT_TO_STDOUT=0
|
||||||
train_code_path=/PATH/TO/MODEL_ZOO_CODE
|
EXECUTE_PATH=$(pwd)
|
||||||
eval_path=/PATH/TO/EVAL
|
eval_path=${EXECUTE_PATH}/multiscale_flip_eval
|
||||||
|
|
||||||
if [ -d ${eval_path} ]; then
|
if [ -d ${eval_path} ]; then
|
||||||
rm -rf ${eval_path}
|
rm -rf ${eval_path}
|
||||||
fi
|
fi
|
||||||
mkdir -p ${eval_path}
|
mkdir -p ${eval_path}
|
||||||
|
|
||||||
python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
|
python ${EXECUTE_PATH}/../eval.py --data_root=/PATH_TO_DATA/vocaug \
|
||||||
--data_lst=/PATH/TO/DATA_lst.txt \
|
--data_lst=/PATH_TO_DATA/vocaug/voc_val_lst.txt \
|
||||||
--batch_size=16 \
|
--batch_size=16 \
|
||||||
--crop_size=513 \
|
--crop_size=513 \
|
||||||
--ignore_label=255 \
|
--ignore_label=255 \
|
||||||
--num_classes=21 \
|
--num_classes=21 \
|
||||||
--model=deeplab_v3_s8 \
|
--model=deeplab_v3_s8 \
|
||||||
--scales=0.5 \
|
--scales_type=1 \
|
||||||
--scales=0.75 \
|
--flip=True \
|
||||||
--scales=1.0 \
|
--freeze_bn=True \
|
||||||
--scales=1.25 \
|
|
||||||
--scales=1.75 \
|
|
||||||
--flip \
|
|
||||||
--freeze_bn \
|
|
||||||
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
|
--ckpt_path=/PATH/TO/PRETRAIN_MODEL >${eval_path}/eval_log 2>&1 &
|
||||||
|
|
||||||
|
|
|
@ -14,10 +14,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
export DEVICE_ID=5
|
export DEVICE_ID=0
|
||||||
export SLOG_PRINT_TO_STDOUT=0
|
export SLOG_PRINT_TO_STDOUT=0
|
||||||
train_path=/PATH/TO/EXPERIMENTS_DIR
|
EXECUTE_PATH=$(pwd)
|
||||||
train_code_path=/PATH/TO/MODEL_ZOO_CODE
|
train_path=${EXECUTE_PATH}/s16_aug_train_1p
|
||||||
|
|
||||||
if [ -d ${train_path} ]; then
|
if [ -d ${train_path} ]; then
|
||||||
rm -rf ${train_path}
|
rm -rf ${train_path}
|
||||||
|
@ -27,7 +27,7 @@ mkdir ${train_path}/device${DEVICE_ID}
|
||||||
mkdir ${train_path}/ckpt
|
mkdir ${train_path}/ckpt
|
||||||
cd ${train_path}/device${DEVICE_ID} || exit
|
cd ${train_path}/device${DEVICE_ID} || exit
|
||||||
|
|
||||||
python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
|
python ${EXECUTE_PATH}/../train.py --data_file=/PATH_TO_DATA/vocaug/vocaug_mindrecord/vocaug_mindrecord0 \
|
||||||
--train_dir=${train_path}/ckpt \
|
--train_dir=${train_path}/ckpt \
|
||||||
--train_epochs=200 \
|
--train_epochs=200 \
|
||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
|
|
||||||
export DEVICE_ID=0
|
export DEVICE_ID=0
|
||||||
export SLOG_PRINT_TO_STDOUT=0
|
export SLOG_PRINT_TO_STDOUT=0
|
||||||
train_path=/PATH/TO/EXPERIMENTS_DIR
|
EXECUTE_PATH=$(pwd)
|
||||||
train_code_path=/PATH/TO/MODEL_ZOO_CODE
|
train_path=${EXECUTE_PATH}/s16_aug_train_cpu
|
||||||
|
|
||||||
if [ -d ${train_path} ]; then
|
if [ -d ${train_path} ]; then
|
||||||
rm -rf ${train_path}
|
rm -rf ${train_path}
|
||||||
|
@ -27,7 +27,7 @@ mkdir ${train_path}/device${DEVICE_ID}
|
||||||
mkdir ${train_path}/ckpt
|
mkdir ${train_path}/ckpt
|
||||||
cd ${train_path}/device${DEVICE_ID} || exit
|
cd ${train_path}/device${DEVICE_ID} || exit
|
||||||
|
|
||||||
python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
|
python ${EXECUTE_PATH}/../train.py --data_file=/PATH_TO_DATA/vocaug/vocaug_mindrecord/vocaug_mindrecord0 \
|
||||||
--device_target=CPU \
|
--device_target=CPU \
|
||||||
--train_dir=${train_path}/ckpt \
|
--train_dir=${train_path}/ckpt \
|
||||||
--train_epochs=200 \
|
--train_epochs=200 \
|
||||||
|
|
|
@ -15,8 +15,7 @@
|
||||||
"""train deeplabv3."""
|
"""train deeplabv3."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import time
|
||||||
import ast
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
|
@ -31,6 +30,9 @@ from src.data import dataset as data_generator
|
||||||
from src.loss import loss
|
from src.loss import loss
|
||||||
from src.nets import net_factory
|
from src.nets import net_factory
|
||||||
from src.utils import learning_rates
|
from src.utils import learning_rates
|
||||||
|
from utils.config import config
|
||||||
|
from utils.moxing_adapter import moxing_wrapper
|
||||||
|
from utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
|
|
||||||
|
@ -47,57 +49,68 @@ class BuildTrainNetwork(nn.Cell):
|
||||||
return net_loss
|
return net_loss
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def modelarts_pre_process():
|
||||||
parser = argparse.ArgumentParser('mindspore deeplabv3 training')
|
'''modelarts pre process function.'''
|
||||||
parser.add_argument('--train_dir', type=str, default='', help='where training log and ckpts saved')
|
def unzip(zip_file, save_dir):
|
||||||
|
import zipfile
|
||||||
|
s_time = time.time()
|
||||||
|
if not os.path.exists(os.path.join(save_dir, "vocaug")):
|
||||||
|
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||||
|
if zip_isexist:
|
||||||
|
fz = zipfile.ZipFile(zip_file, 'r')
|
||||||
|
data_num = len(fz.namelist())
|
||||||
|
print("Extract Start...")
|
||||||
|
print("unzip file num: {}".format(data_num))
|
||||||
|
i = 0
|
||||||
|
for file in fz.namelist():
|
||||||
|
if i % int(data_num / 100) == 0:
|
||||||
|
print("unzip percent: {}%".format(i / int(data_num / 100)), flush=True)
|
||||||
|
i += 1
|
||||||
|
fz.extract(file, save_dir)
|
||||||
|
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||||
|
int(int(time.time() - s_time) % 60)))
|
||||||
|
print("Extract Done.")
|
||||||
|
else:
|
||||||
|
print("This is not zip.")
|
||||||
|
else:
|
||||||
|
print("Zip has been extracted.")
|
||||||
|
|
||||||
# dataset
|
if config.need_modelarts_dataset_unzip:
|
||||||
parser.add_argument('--data_file', type=str, default='', help='path and name of one mindrecord file')
|
zip_file_1 = os.path.join(config.data_path, "vocaug.zip")
|
||||||
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
|
save_dir_1 = os.path.join(config.data_path)
|
||||||
parser.add_argument('--crop_size', type=int, default=513, help='crop size')
|
|
||||||
parser.add_argument('--image_mean', type=list, default=[103.53, 116.28, 123.675], help='image mean')
|
|
||||||
parser.add_argument('--image_std', type=list, default=[57.375, 57.120, 58.395], help='image std')
|
|
||||||
parser.add_argument('--min_scale', type=float, default=0.5, help='minimum scale of data argumentation')
|
|
||||||
parser.add_argument('--max_scale', type=float, default=2.0, help='maximum scale of data argumentation')
|
|
||||||
parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
|
|
||||||
parser.add_argument('--num_classes', type=int, default=21, help='number of classes')
|
|
||||||
|
|
||||||
# optimizer
|
sync_lock = "/tmp/unzip_sync.lock"
|
||||||
parser.add_argument('--train_epochs', type=int, default=300, help='epoch')
|
|
||||||
parser.add_argument('--lr_type', type=str, default='cos', help='type of learning rate')
|
|
||||||
parser.add_argument('--base_lr', type=float, default=0.015, help='base learning rate')
|
|
||||||
parser.add_argument('--lr_decay_step', type=int, default=40000, help='learning rate decay step')
|
|
||||||
parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='learning rate decay rate')
|
|
||||||
parser.add_argument('--loss_scale', type=float, default=3072.0, help='loss scale')
|
|
||||||
|
|
||||||
# model
|
# Each server contains 8 devices as most.
|
||||||
parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
|
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||||
parser.add_argument('--freeze_bn', action='store_true', help='freeze bn')
|
print("Zip file path: ", zip_file_1)
|
||||||
parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model')
|
print("Unzip file save dir: ", save_dir_1)
|
||||||
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
|
unzip(zip_file_1, save_dir_1)
|
||||||
help="Filter the last weight parameters, default is False.")
|
print("===Finish extract data synchronization===")
|
||||||
|
try:
|
||||||
|
os.mknod(sync_lock)
|
||||||
|
except IOError:
|
||||||
|
pass
|
||||||
|
|
||||||
# train
|
while True:
|
||||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
|
if os.path.exists(sync_lock):
|
||||||
help='device where the code will be implemented. (Default: Ascend)')
|
break
|
||||||
parser.add_argument('--is_distributed', action='store_true', help='distributed training')
|
time.sleep(1)
|
||||||
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
|
|
||||||
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
|
|
||||||
parser.add_argument('--save_steps', type=int, default=3000, help='steps interval for saving')
|
|
||||||
parser.add_argument('--keep_checkpoint_max', type=int, default=int, help='max checkpoint for saving')
|
|
||||||
|
|
||||||
args, _ = parser.parse_known_args()
|
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||||
return args
|
|
||||||
|
config.train_dir = os.path.join(config.output_path, str(get_rank_id()), config.train_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||||
def train():
|
def train():
|
||||||
args = parse_args()
|
args = config
|
||||||
|
|
||||||
if args.device_target == "CPU":
|
if args.device_target == "CPU":
|
||||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
|
||||||
else:
|
else:
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
||||||
device_target="Ascend", device_id=int(os.getenv('DEVICE_ID')))
|
device_target="Ascend", device_id=get_device_id())
|
||||||
|
|
||||||
# init multicards training
|
# init multicards training
|
||||||
if args.is_distributed:
|
if args.is_distributed:
|
||||||
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Parse arguments"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import ast
|
||||||
|
import argparse
|
||||||
|
from pprint import pprint, pformat
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""
|
||||||
|
Configuration namespace. Convert dictionary to members.
|
||||||
|
"""
|
||||||
|
def __init__(self, cfg_dict):
|
||||||
|
for k, v in cfg_dict.items():
|
||||||
|
if isinstance(v, (list, tuple)):
|
||||||
|
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||||
|
else:
|
||||||
|
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return pformat(self.__dict__)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__str__()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||||
|
"""
|
||||||
|
Parse command line arguments to the configuration according to the default yaml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser: Parent parser.
|
||||||
|
cfg: Base configuration.
|
||||||
|
helper: Helper description.
|
||||||
|
cfg_path: Path to the default yaml config.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||||
|
parents=[parser])
|
||||||
|
helper = {} if helper is None else helper
|
||||||
|
choices = {} if choices is None else choices
|
||||||
|
for item in cfg:
|
||||||
|
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||||
|
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||||
|
choice = choices[item] if item in choices else None
|
||||||
|
if isinstance(cfg[item], bool):
|
||||||
|
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||||
|
help=help_description)
|
||||||
|
else:
|
||||||
|
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||||
|
help=help_description)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def parse_yaml(yaml_path):
|
||||||
|
"""
|
||||||
|
Parse the yaml config file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
yaml_path: Path to the yaml config.
|
||||||
|
"""
|
||||||
|
with open(yaml_path, 'r') as fin:
|
||||||
|
try:
|
||||||
|
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||||
|
cfgs = [x for x in cfgs]
|
||||||
|
if len(cfgs) == 1:
|
||||||
|
cfg_helper = {}
|
||||||
|
cfg = cfgs[0]
|
||||||
|
cfg_choices = {}
|
||||||
|
elif len(cfgs) == 2:
|
||||||
|
cfg, cfg_helper = cfgs
|
||||||
|
cfg_choices = {}
|
||||||
|
elif len(cfgs) == 3:
|
||||||
|
cfg, cfg_helper, cfg_choices = cfgs
|
||||||
|
else:
|
||||||
|
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||||
|
print(cfg_helper)
|
||||||
|
except:
|
||||||
|
raise ValueError("Failed to parse yaml")
|
||||||
|
return cfg, cfg_helper, cfg_choices
|
||||||
|
|
||||||
|
|
||||||
|
def merge(args, cfg):
|
||||||
|
"""
|
||||||
|
Merge the base config from yaml file and command line arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Command line arguments.
|
||||||
|
cfg: Base configuration.
|
||||||
|
"""
|
||||||
|
args_var = vars(args)
|
||||||
|
for item in args_var:
|
||||||
|
cfg[item] = args_var[item]
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
"""
|
||||||
|
Get Config according to the yaml file and cli arguments.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
|
||||||
|
help="Config file path")
|
||||||
|
path_args, _ = parser.parse_known_args()
|
||||||
|
default, helper, choices = parse_yaml(path_args.config_path)
|
||||||
|
pprint(default)
|
||||||
|
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||||
|
final_config = merge(args, default)
|
||||||
|
return Config(final_config)
|
||||||
|
|
||||||
|
config = get_config()
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Device adapter for ModelArts"""
|
||||||
|
|
||||||
|
from .config import config
|
||||||
|
|
||||||
|
if config.enable_modelarts:
|
||||||
|
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
|
else:
|
||||||
|
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||||
|
]
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Local adapter"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
def get_device_id():
|
||||||
|
device_id = os.getenv('DEVICE_ID', '0')
|
||||||
|
return int(device_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_num():
|
||||||
|
device_num = os.getenv('RANK_SIZE', '1')
|
||||||
|
return int(device_num)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_id():
|
||||||
|
global_rank_id = os.getenv('RANK_ID', '0')
|
||||||
|
return int(global_rank_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_id():
|
||||||
|
return "Local Job"
|
|
@ -0,0 +1,116 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Moxing adapter for ModelArts"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import functools
|
||||||
|
from mindspore import context
|
||||||
|
from .config import config
|
||||||
|
|
||||||
|
_global_sync_count = 0
|
||||||
|
|
||||||
|
def get_device_id():
|
||||||
|
device_id = os.getenv('DEVICE_ID', '0')
|
||||||
|
return int(device_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_num():
|
||||||
|
device_num = os.getenv('RANK_SIZE', '1')
|
||||||
|
return int(device_num)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank_id():
|
||||||
|
global_rank_id = os.getenv('RANK_ID', '0')
|
||||||
|
return int(global_rank_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_job_id():
|
||||||
|
job_id = os.getenv('JOB_ID')
|
||||||
|
job_id = job_id if job_id != "" else "default"
|
||||||
|
return job_id
|
||||||
|
|
||||||
|
def sync_data(from_path, to_path):
|
||||||
|
"""
|
||||||
|
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||||
|
Upload data from local directory to remote obs in contrast.
|
||||||
|
"""
|
||||||
|
import moxing as mox
|
||||||
|
import time
|
||||||
|
global _global_sync_count
|
||||||
|
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||||
|
_global_sync_count += 1
|
||||||
|
|
||||||
|
# Each server contains 8 devices as most.
|
||||||
|
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||||
|
print("from path: ", from_path)
|
||||||
|
print("to path: ", to_path)
|
||||||
|
mox.file.copy_parallel(from_path, to_path)
|
||||||
|
print("===finish data synchronization===")
|
||||||
|
try:
|
||||||
|
os.mknod(sync_lock)
|
||||||
|
except IOError:
|
||||||
|
pass
|
||||||
|
print("===save flag===")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if os.path.exists(sync_lock):
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||||
|
|
||||||
|
|
||||||
|
def moxing_wrapper(pre_process=None, post_process=None):
|
||||||
|
"""
|
||||||
|
Moxing wrapper to download dataset and upload outputs.
|
||||||
|
"""
|
||||||
|
def wrapper(run_func):
|
||||||
|
@functools.wraps(run_func)
|
||||||
|
def wrapped_func(*args, **kwargs):
|
||||||
|
# Download data from data_url
|
||||||
|
if config.enable_modelarts:
|
||||||
|
if config.data_url:
|
||||||
|
sync_data(config.data_url, config.data_path)
|
||||||
|
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||||
|
if config.checkpoint_url:
|
||||||
|
sync_data(config.checkpoint_url, config.load_path)
|
||||||
|
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||||
|
if config.train_url:
|
||||||
|
sync_data(config.train_url, config.output_path)
|
||||||
|
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||||
|
|
||||||
|
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||||
|
config.device_num = get_device_num()
|
||||||
|
config.device_id = get_device_id()
|
||||||
|
if not os.path.exists(config.output_path):
|
||||||
|
os.makedirs(config.output_path)
|
||||||
|
|
||||||
|
if pre_process:
|
||||||
|
pre_process()
|
||||||
|
|
||||||
|
# Run the main function
|
||||||
|
run_func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Upload data to train_url
|
||||||
|
if config.enable_modelarts:
|
||||||
|
if post_process:
|
||||||
|
post_process()
|
||||||
|
|
||||||
|
if config.train_url:
|
||||||
|
print("Start to copy output directory")
|
||||||
|
sync_data(config.output_path, config.train_url)
|
||||||
|
return wrapped_func
|
||||||
|
return wrapper
|
Loading…
Reference in New Issue