!17764 tinydarknet can be used on ModelArts and warpctc init

From: @ZhengBina
Reviewed-by: @oacjiewen,@wuxuejian
Signed-off-by: @wuxuejian
This commit is contained in:
mindspore-ci-bot 2021-06-05 10:36:20 +08:00 committed by Gitee
commit ad30fff804
20 changed files with 848 additions and 257 deletions

View File

@ -79,7 +79,7 @@ After installing MindSpore via the official website, you can start training and
bash ./scripts/run_standalone_train.sh 0 bash ./scripts/run_standalone_train.sh 0
# run distributed training example # run distributed training example
bash ./scripts/run_distribute_train.sh rank_table.json bash ./scripts/run_distribute_train.sh /{path}/*.json
# run evaluation example # run evaluation example
python eval.py > eval.log 2>&1 & python eval.py > eval.log 2>&1 &
@ -87,12 +87,62 @@ After installing MindSpore via the official website, you can start training and
bash ./script/run_eval.sh bash ./script/run_eval.sh
``` ```
For distributed training, a hccl configuration file with JSON format needs to be created in advance. For distributed training, a hccl configuration file [RANK_TABLE_FILE] with JSON format needs to be created in advance.
Please follow the instructions in the link below: Please follow the instructions in the link below:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.> <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
- Running on ModelArts
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows.
- Training with 8 cards on ModelArts
```python
# (1) Upload the code folder to S3 bucket.
# (2) Click to "create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/tinydarknet" on the website UI interface.
# (4) Set the startup file to "/{path}/tinydarknet/train.py" on the website UI interface.
# (5) Perform a or b.
# a. setting parameters in /{path}/tinydarknet/imagenet_config.yaml.
# 1. Set "batch_size: 64" (not necessary)
# 2. Set "enable_modelarts: True"
# 3. Set "modelarts_dataset_unzip_name: {filename}", if the data is uploaded in the form of zip package.
# b. adding on the website UI interface.
# 1. Add "batch_size=64" (not necessary)
# 2. Add "enable_modelarts=True"
# 3. Add "modelarts_dataset_unzip_name={filename}", if the data is uploaded in the form of zip package.
# (6) Upload the dataset or the zip package of dataset to S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path).
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of 8 cards.
# (10) Create your job.
```
- evaluating with single card on ModelArts
```python
# (1) Upload the code folder to S3 bucket.
# (2) Click to "create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/tinydarknet" on the website UI interface.
# (4) Set the startup file to "/{path}/tinydarknet/eval.py" on the website UI interface.
# (5) Perform a or b.
# a. setting parameters in /{path}/tinydarknet/imagenet_config.yaml.
# 1. Set "enable_modelarts: True"
# 2. Set "checkpoint_path: {checkpoint_path}" ({checkpoint_path} indicates the path of the weight file to be evaluated relative to the file 'eval.py', and the weight file must be included in the code directory.)
# 3. Set "modelarts_dataset_unzip_name: {filename}", if the data is uploaded in the form of zip package.
# b. adding on the website UI interface.
# 1. Set "enable_modelarts=True"
# 2. Set "checkpoint_path={checkpoint_path}" ({checkpoint_path} indicates the path of the weight file to be evaluated relative to the file 'eval.py', and the weight file must be included in the code directory.)
# 3. Add "modelarts_dataset_unzip_name={filename}", if the data is uploaded in the form of zip package.
# (6) Upload the dataset or the zip package of dataset to S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path).
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of a single card.
# (10) Create your job.
```
For more details, please refer the specify script. For more details, please refer the specify script.
# [Script Description](#contents) # [Script Description](#contents)
@ -106,64 +156,69 @@ For more details, please refer the specify script.
├── README_CN.md // descriptions about Tiny-Darknet in Chinese ├── README_CN.md // descriptions about Tiny-Darknet in Chinese
├── ascend310_infer // application for 310 inference ├── ascend310_infer // application for 310 inference
├── scripts ├── scripts
├──run_standalone_train.sh // shell script for single on Ascend ├── run_standalone_train.sh // shell script for single on Ascend
├──run_distribute_train.sh // shell script for distributed on Ascend ├── run_distribute_train.sh // shell script for distributed on Ascend
├──run_eval.sh // shell script for evaluation on Ascend ├── run_eval.sh // shell script for evaluation on Ascend
├──run_infer_310.sh // shell script for inference on Ascend310 ├── run_infer_310.sh // shell script for inference on Ascend310
├── src ├── src
├─lr_scheduler //learning rate scheduler ├── lr_scheduler //learning rate scheduler
├─__init__.py // init ├── __init__.py // init
├─linear_warmup.py // linear_warmup ├── linear_warmup.py // linear_warmup
├─warmup_cosine_annealing_lr.py // warmup_cosine_annealing_lr ├── warmup_cosine_annealing_lr.py // warmup_cosine_annealing_lr
├─warmup_step_lr.py // warmup_step_lr ├── warmup_step_lr.py // warmup_step_lr
├──dataset.py // creating dataset ├── model_utils
├──CrossEntropySmooth.py // loss function ├── config.py // parsing parameter configuration file of "*.yaml"
├──tinydarknet.py // Tiny-Darknet architecture ├── device_adapter.py // local or ModelArts training
├──config.py // parameter configuration ├── local_adapter.py // get related environment variables in local training
├── train.py // training script └── moxing_adapter.py // get related environment variables in ModelArts training
├── eval.py // evaluation script ├── dataset.py // creating dataset
├── export.py // export checkpoint file into air/onnx ├── CrossEntropySmooth.py // loss function
├── mindspore_hub_conf.py // hub config ├── tinydarknet.py // Tiny-Darknet architecture
├── postprocess.py // postprocess script ├── train.py // training script
├── eval.py // evaluation script
├── export.py // export checkpoint file into air/onnx
├── imagenet_config.yaml // parameter configuration
├── mindspore_hub_conf.py // hub config
├── postprocess.py // postprocess script
``` ```
## [Script Parameters](#contents) ## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py Parameters for both training and evaluation can be set in `imagenet_config.yaml`
- config for Tiny-Darknet - config for Tiny-Darknet(only some parameters are listed)
```python ```python
'pre_trained': 'False' # whether training based on the pre-trained model pre_trained: False # whether training based on the pre-trained model
'num_classes': 1000 # the number of classes in the dataset num_classes: 1000 # the number of classes in the dataset
'lr_init': 0.1 # initial learning rate lr_init: 0.1 # initial learning rate
'batch_size': 128 # training batch_size batch_size: 128 # training batch_size
'epoch_size': 500 # total training epoch epoch_size: 500 # total training epoch
'momentum': 0.9 # momentum momentum: 0.9 # momentum
'weight_decay': 1e-4 # weight decay value weight_decay: 1e-4 # weight decay value
'image_height': 224 # image height used as input to the model image_height: 224 # image height used as input to the model
'image_width': 224 # image width used as input to the model image_width: 224 # image width used as input to the model
'data_path': './ImageNet_Original/train/' # absolute full path to the train datasets train_data_dir: './ImageNet_Original/train/' # absolute full path to the train datasets
'val_data_path': './ImageNet_Original/val/' # absolute full path to the evaluation datasets val_data_dir: './ImageNet_Original/val/' # absolute full path to the evaluation datasets
'device_target': 'Ascend' # device running the program device_target: 'Ascend' # device running the program
'keep_checkpoint_max': 10 # only keep the last keep_checkpoint_max checkpoint keep_checkpoint_max: 10 # only keep the last keep_checkpoint_max checkpoint
'checkpoint_path': '/train_tinydarknet.ckpt' # the absolute full path to save the checkpoint file checkpoint_path: '/train_tinydarknet.ckpt' # the absolute full path to save the checkpoint file
'onnx_filename': 'tinydarknet.onnx' # file name of the onnx model used in export.py onnx_filename: 'tinydarknet.onnx' # file name of the onnx model used in export.py
'air_filename': 'tinydarknet.air' # file name of the air model used in export.py air_filename: 'tinydarknet.air' # file name of the air model used in export.py
'lr_scheduler': 'exponential' # learning rate scheduler lr_scheduler: 'exponential' # learning rate scheduler
'lr_epochs': [70, 140, 210, 280] # epoch of lr changing lr_epochs: [70, 140, 210, 280] # epoch of lr changing
'lr_gamma': 0.3 # decrease lr by a factor of exponential lr_scheduler lr_gamma: 0.3 # decrease lr by a factor of exponential lr_scheduler
'eta_min': 0.0 # eta_min in cosine_annealing scheduler eta_min: 0.0 # eta_min in cosine_annealing scheduler
'T_max': 150 # T-max in cosine_annealing scheduler T_max: 150 # T-max in cosine_annealing scheduler
'warmup_epochs': 0 # warmup epoch warmup_epochs: 0 # warmup epoch
'is_dynamic_loss_scale': 0 # dynamic loss scale is_dynamic_loss_scale: 0 # dynamic loss scale
'loss_scale': 1024 # loss scale loss_scale: 1024 # loss scale
'label_smooth_factor': 0.1 # label_smooth_factor label_smooth_factor: 0.1 # label_smooth_factor
'use_label_smooth': True # label smooth use_label_smooth: True # label smooth
``` ```
For more configuration details, please refer the script config.py. For more configuration details, please refer the script `imagenet_config.yaml`.
## [Training Process](#contents) ## [Training Process](#contents)
@ -172,7 +227,7 @@ For more configuration details, please refer the script config.py.
- running on Ascend - running on Ascend
```python ```python
bash scripts/run_standalone_train.sh 0 bash scripts/run_standalone_train.sh [DEVICE_ID]
``` ```
The command above will run in the background, you can view the results through the file train.log. The command above will run in the background, you can view the results through the file train.log.
@ -199,7 +254,7 @@ For more configuration details, please refer the script config.py.
- running on Ascend - running on Ascend
```python ```python
bash ./scripts/run_distribute_train.sh rank_table.json bash ./scripts/run_distribute_train.sh [RANK_TABLE_FILE]
``` ```
The above shell script will run distribute training in the background. You can view the results through the file train_parallel[X]/log. The loss value will be achieved as follows: The above shell script will run distribute training in the background. You can view the results through the file train_parallel[X]/log. The loss value will be achieved as follows:
@ -252,7 +307,7 @@ For more configuration details, please refer the script config.py.
python export.py --dataset [DATASET] --file_name [FILE_NAME] --file_format [EXPORT_FORMAT] python export.py --dataset [DATASET] --file_name [FILE_NAME] --file_format [EXPORT_FORMAT]
``` ```
The parameter does not have the ckpt_file option. Please store the ckpt file according to the path of the parameter `checkpoint_path` in `config.py`. The parameter does not have the ckpt_file option. Please store the ckpt file according to the path of the parameter `checkpoint_path` in `imagenet_config.yaml`.
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"] `EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
### Infer on Ascend310 ### Infer on Ascend310

View File

@ -87,7 +87,7 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
bash ./scripts/run_standalone_train.sh 0 bash ./scripts/run_standalone_train.sh 0
# 分布式训练 # 分布式训练
bash ./scripts/run_distribute_train.sh rank_table.json bash ./scripts/run_distribute_train.sh /{path}/*.json
# 评估 # 评估
python eval.py > eval.log 2>&1 & python eval.py > eval.log 2>&1 &
@ -95,12 +95,61 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
bash ./script/run_eval.sh bash ./script/run_eval.sh
``` ```
进行并行训练时, 需要提前创建JSON格式的hccl配置文件。 进行并行训练时, 需要提前创建JSON格式的hccl配置文件 [RANK_TABLE_FILE]
请按照以下链接的指导进行设置: 请按照以下链接的指导进行设置:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.> <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
- 在ModelArts上运行
如果你想在modelarts上运行可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/)
- 在ModelArts上使用8卡训练
```python
# (1) 上传你的代码到 s3 桶上
# (2) 在ModelArts上创建训练任务
# (3) 选择代码目录 /{path}/tinydarknet
# (4) 选择启动文件 /{path}/tinydarknet/train.py
# (5) 执行a或b
# a. 在 /{path}/tinydarknet/imagenet_config.yaml 文件中设置参数
# 1. 设置 "batch_size: 64" (非必须)
# 2. 设置 "enable_modelarts: True"
# 3. 如果数据采用zip格式压缩包的形式上传,设置 "modelarts_dataset_unzip_name: {filename}"
# b. 在 网页上设置
# 1. 添加 "batch_size=64" (非必须)
# 2. 添加 "enable_modelarts=True"
# 3. 如果数据采用zip格式压缩包的形式上传,添加 "modelarts_dataset_unzip_name={filename}"
# (6) 上传你的 数据/数据zip压缩包 到 s3 桶上
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包)
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
# (9) 在网页上的’资源池选择‘项目下, 选择8卡规格的资源
# (10) 创建训练作业
```
- 在ModelArts上使用单卡验证
```python
# (1) 上传你的代码到 s3 桶上
# (2) 在ModelArts上创建训练任务
# (3) 选择代码目录 /{path}/tinydarknet
# (4) 选择启动文件 /{path}/tinydarknet/eval.py
# (5) 执行a或b
# a. 在 /{path}/tinydarknet 下的 imagenet_config.yaml 文件中设置参数
# 1. 设置 "enable_modelarts: True"
# 2. 设置 "checkpoint_path: {checkpoint_path}"({checkpoint_path}表示待评估的 权重文件 相对于 eval.py 的路径,权重文件须包含在代码目录下。)
# 3. 如果数据采用zip格式压缩包的形式上传,设置 "modelarts_dataset_unzip_name: {filename}"
# b. 在 网页上设置
# 1. 设置 "enable_modelarts=True"
# 2. 设置 "checkpoint_path={checkpoint_path}"({checkpoint_path}表示待评估的 权重文件 相对于 eval.py 的路径,权重文件须包含在代码目录下。)
# 3. 如果数据采用zip格式压缩包的形式上传,设置 "modelarts_dataset_unzip_name={filename}"
# (6) 上传你的 数据/数据zip压缩包 到 s3 桶上
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包)
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
# (9) 在网页上的’资源池选择‘项目下, 选择单卡规格的资源
# (10) 创建训练作业
```
更多的细节请参考具体的script文件 更多的细节请参考具体的script文件
# [脚本描述](#目录) # [脚本描述](#目录)
@ -110,68 +159,73 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
```bash ```bash
├── tinydarknet ├── tinydarknet
├── README.md // Tiny-Darknet英文说明 ├── README.md // Tiny-Darknet英文说明
├── README_CN.md // Tiny-Darknet中文说明 ├── README_CN.md // Tiny-Darknet中文说明
├── ascend310_infer // 用于310推理 ├── ascend310_infer // 用于310推理
├── scripts ├── scripts
├──run_standalone_train.sh // Ascend单卡训练shell脚本 ├── run_standalone_train.sh // Ascend单卡训练shell脚本
├──run_distribute_train.sh // Ascend分布式训练shell脚本 ├── run_distribute_train.sh // Ascend分布式训练shell脚本
├──run_eval.sh // Ascend评估shell脚本 ├── run_eval.sh // Ascend评估shell脚本
├──run_infer_310.sh // Ascend310推理shell脚本 └── run_infer_310.sh // Ascend310推理shell脚本
├── src ├── src
├─lr_scheduler //学习率策略 ├── lr_scheduler // 学习率策略
├─__init__.py // 初始化文件 ├── __init__.py // 初始化文件
├─linear_warmup.py // linear_warmup策略 ├── linear_warmup.py // linear_warmup策略
├─warmup_cosine_annealing_lr.py // warmup_cosine_annealing_lr策略 ├── warmup_cosine_annealing_lr.py // warmup_cosine_annealing_lr策略
├─warmup_step_lr.py // warmup_step_lr策略 └── warmup_step_lr.py // warmup_step_lr策略
├──dataset.py // 创建数据集 ├── model_utils
├──CrossEntropySmooth.py // 损失函数 ├── config.py // 解析 *.yaml 参数配置文件
├──tinydarknet.py // Tiny-Darknet网络结构 ├── device_adapter.py // 区分本地/ModelArts训练
├──config.py // 参数配置 ├── local_adapter.py // 本地训练获取相关环境变量
├── train.py // 训练脚本 └── moxing_adapter.py // ModelArts训练获取相关环境变量、交换数据
├── eval.py // 评估脚本 ├── dataset.py // 创建数据集
├── export.py // 导出checkpoint文件 ├── CrossEntropySmooth.py // 损失函数
├── mindspore_hub_conf.py // hub配置文件 └── tinydarknet.py // Tiny-Darknet网络结构
├── postprocess.py // 310推理后处理脚本 ├── train.py // 训练脚本
├── eval.py // 评估脚本
├── export.py // 导出checkpoint文件
├── imagenet_config.yaml // 参数配置
├── mindspore_hub_conf.py // hub配置文件
└── postprocess.py // 310推理后处理脚本
``` ```
## [脚本参数](#目录) ## [脚本参数](#目录)
训练和测试的参数可在 config.py 中进行设置 训练和测试的参数可在 `imagenet_config.yaml` 中进行设置
- Tiny-Darknet的配置文件 - Tiny-Darknet的配置文件(仅列出部分参数)
```python ```python
'pre_trained': 'False' # 是否载入预训练模型 pre_trained: False # 是否载入预训练模型
'num_classes': 1000 # 数据集中类的数量 num_classes: 1000 # 数据集中类的数量
'lr_init': 0.1 # 初始学习率 lr_init: 0.1 # 初始学习率
'batch_size': 128 # 训练的batch_size batch_size: 128 # 训练的batch_size
'epoch_size': 500 # 总共的训练epoch epoch_size: 500 # 总共的训练epoch
'momentum': 0.9 # 动量 momentum: 0.9 # 动量
'weight_decay': 1e-4 # 权重衰减率 weight_decay: 1e-4 # 权重衰减率
'image_height': 224 # 输入图像的高度 image_height: 224 # 输入图像的高度
'image_width': 224 # 输入图像的宽度 image_width: 224 # 输入图像的宽度
'data_path': './ImageNet_Original/train/' # 训练数据集的绝对路径 train_data_dir: './ImageNet_Original/train/' # 训练数据集的绝对路径
'val_data_path': './ImageNet_Original/val/' # 评估数据集的绝对路径 val_data_dir: './ImageNet_Original/val/' # 评估数据集的绝对路径
'device_target': 'Ascend' # 程序运行的设备 device_target: 'Ascend' # 程序运行的设备
'keep_checkpoint_max': 10 # 仅仅保持最新的keep_checkpoint_max个checkpoint文件 keep_checkpoint_max: 10 # 仅仅保持最新的keep_checkpoint_max个checkpoint文件
'checkpoint_path': '/train_tinydarknet.ckpt' # 保存checkpoint文件的绝对路径 checkpoint_path: '/train_tinydarknet.ckpt' # 保存checkpoint文件的绝对路径
'onnx_filename': 'tinydarknet.onnx' # 用于export.py 文件中的onnx模型的文件名 onnx_filename: 'tinydarknet.onnx' # 用于export.py 文件中的onnx模型的文件名
'air_filename': 'tinydarknet.air' # 用于export.py 文件中的air模型的文件名 air_filename: 'tinydarknet.air' # 用于export.py 文件中的air模型的文件名
'lr_scheduler': 'exponential' # 学习率策略 lr_scheduler: 'exponential' # 学习率策略
'lr_epochs': [70, 140, 210, 280] # 学习率进行变化的epoch数 lr_epochs: [70, 140, 210, 280] # 学习率进行变化的epoch数
'lr_gamma': 0.3 # lr_scheduler为exponential时的学习率衰减因子 lr_gamma: 0.3 # lr_scheduler为exponential时的学习率衰减因子
'eta_min': 0.0 # cosine_annealing策略中的eta_min eta_min: 0.0 # cosine_annealing策略中的eta_min
'T_max': 150 # cosine_annealing策略中的T-max T_max: 150 # cosine_annealing策略中的T-max
'warmup_epochs': 0 # 热启动的epoch数 warmup_epochs: 0 # 热启动的epoch数
'is_dynamic_loss_scale': 0 # 动态损失尺度 is_dynamic_loss_scale: 0 # 动态损失尺度
'loss_scale': 1024 # 损失尺度 loss_scale: 1024 # 损失尺度
'label_smooth_factor': 0.1 # 训练标签平滑因子 label_smooth_factor: 0.1 # 训练标签平滑因子
'use_label_smooth': True # 是否采用训练标签平滑 use_label_smooth: True # 是否采用训练标签平滑
``` ```
更多的细节, 请参考`config.py`. 更多的细节, 请参考`imagenet_config.yaml`.
## [训练过程](#目录) ## [训练过程](#目录)
@ -180,7 +234,7 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
- 在Ascend资源上运行 - 在Ascend资源上运行
```python ```python
bash ./scripts/run_standalone_train.sh 0 bash ./scripts/run_standalone_train.sh [DEVICE_ID]
``` ```
上述的命令将运行在后台中,可以通过 `train.log` 文件查看运行结果. 上述的命令将运行在后台中,可以通过 `train.log` 文件查看运行结果.
@ -207,7 +261,7 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
- 在Ascend资源上运行 - 在Ascend资源上运行
```python ```python
bash scripts/run_distribute_train.sh rank_table.json bash scripts/run_distribute_train.sh [RANK_TABLE_FILE]
``` ```
上述的脚本命令将在后台中进行分布式训练,可以通过`train_parallel[X]/log`文件查看运行结果. 训练的损失值将以如下的形式展示: 上述的脚本命令将在后台中进行分布式训练,可以通过`train_parallel[X]/log`文件查看运行结果. 训练的损失值将以如下的形式展示:
@ -229,7 +283,7 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
- 在Ascend资源上进行评估: - 在Ascend资源上进行评估:
在运行如下命令前,请确认用于评估的checkpoint文件的路径.请将checkpoint路径设置为绝对路径,例如:"/username/imagenet/train_tinydarknet.ckpt" 在运行如下命令前,请确认用于评估的checkpoint文件的路径.checkpoint文件须包含在tinydarknet文件夹内.请将checkpoint路径设置为相对于 eval.py文件 的路径,例如:"./ckpts/train_tinydarknet.ckpt"(ckpts 与 eval.py 同级).
```python ```python
python eval.py > eval.log 2>&1 & python eval.py > eval.log 2>&1 &
@ -259,7 +313,7 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
python export.py --dataset [DATASET] --file_name [FILE_NAME] --file_format [EXPORT_FORMAT] python export.py --dataset [DATASET] --file_name [FILE_NAME] --file_format [EXPORT_FORMAT]
``` ```
参数没有ckpt_file选项ckpt文件请按照`config.py`中参数`checkpoint_path`的路径存放。 参数没有ckpt_file选项ckpt文件请按照`imagenet_config.yaml`中参数`checkpoint_path`的路径存放。
`EXPORT_FORMAT` 可选 ["AIR", "MINDIR"]. `EXPORT_FORMAT` 可选 ["AIR", "MINDIR"].
### 在Ascend310执行推理 ### 在Ascend310执行推理

View File

@ -16,34 +16,82 @@
##############test tinydarknet example on cifar10################# ##############test tinydarknet example on cifar10#################
python eval.py python eval.py
""" """
import argparse import os
import time
from mindspore import context from mindspore import context
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed from mindspore.common import set_seed
from src.config import imagenet_cfg
from src.dataset import create_dataset_imagenet from src.dataset import create_dataset_imagenet
from src.tinydarknet import TinyDarkNet from src.tinydarknet import TinyDarkNet
from src.CrossEntropySmooth import CrossEntropySmooth from src.CrossEntropySmooth import CrossEntropySmooth
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_num
set_seed(1) set_seed(1)
parser = argparse.ArgumentParser(description='tinydarknet') def modelarts_pre_process():
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'], '''modelarts pre process function.'''
help='dataset name.') def unzip(zip_file, save_dir):
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') import zipfile
args_opt = parser.parse_args() s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("Unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if __name__ == '__main__': if config.modelarts_dataset_unzip_name:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
if args_opt.dataset_name == "imagenet": sync_lock = "/tmp/unzip_sync.lock"
cfg = imagenet_cfg
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False) # Each server contains 8 devices as most.
if config.device_id % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(config.device_id, zip_file_1, save_dir_1))
config.checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.checkpoint_path)
if not os.path.exists(config.checkpoint_path):
raise ValueError("Check parameter 'checkpoint_path'. for more details, you can see README.md")
config.val_data_dir = config.data_path
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
if config.dataset_name == "imagenet":
cfg = config
dataset = create_dataset_imagenet(cfg.val_data_dir, 1, False)
if not cfg.use_label_smooth: if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0 cfg.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean", loss = CrossEntropySmooth(sparse=True, reduction="mean",
@ -52,20 +100,19 @@ if __name__ == '__main__':
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'}) model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
else: else:
raise ValueError("dataset is not support.") raise ValueError("Dataset is not support.")
device_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
if config.device_target == "Ascend":
if args_opt.checkpoint_path is not None: context.set_context(device_id=config.device_id)
param_dict = load_checkpoint(args_opt.checkpoint_path) param_dict = load_checkpoint(cfg.checkpoint_path)
print("load checkpoint from [{}].".format(args_opt.checkpoint_path)) print("Load checkpoint from [{}].".format(cfg.checkpoint_path))
else:
param_dict = load_checkpoint(cfg.checkpoint_path)
print("load checkpoint from [{}].".format(cfg.checkpoint_path))
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
net.set_train(False) net.set_train(False)
acc = model.eval(dataset) acc = model.eval(dataset)
print("accuracy: ", acc) print("accuracy: ", acc)
if __name__ == '__main__':
run_eval()

View File

@ -23,7 +23,7 @@ import mindspore as ms
from mindspore import Tensor from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import imagenet_cfg from src.model_utils.config import config as imagenet_cfg
from src.tinydarknet import TinyDarkNet from src.tinydarknet import TinyDarkNet
if __name__ == '__main__': if __name__ == '__main__':
@ -38,7 +38,7 @@ if __name__ == '__main__':
if args_opt.dataset_name == 'imagenet': if args_opt.dataset_name == 'imagenet':
cfg = imagenet_cfg cfg = imagenet_cfg
else: else:
raise ValueError("dataset is not support.") raise ValueError("Dataset is not support.")
net = TinyDarkNet(num_classes=cfg.num_classes) net = TinyDarkNet(num_classes=cfg.num_classes)

View File

@ -0,0 +1,57 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
enable_profiling: False
modelarts_dataset_unzip_name: ''
# ==============================================================================
#train-eval-export related
dataset_name: imagenet
ckpt_save_dir: checkpoints
pre_trained: False
device_id: 0
num_classes: 1000
lr_init: 0.1
batch_size: 128
epoch_size: 500
momentum: 0.9
weight_decay: 0.0001
image_height: 224
image_width: 224
train_data_dir: './dataset/imagenet_original/train/'
val_data_dir: './dataset/imagenet_original/val/'
keep_checkpoint_max: 1
checkpoint_path: './scripts/train_parallel4/ckpt_4/train_tinydarknet_imagenet-300_1251.ckpt'
onnx_filename: 'tinydarknet.onnx'
air_filename: 'tinydarknet.air'
# optimizer and lr related
lr_scheduler: 'exponential'
lr_epochs: [70, 140, 210, 280]
lr_gamma: 0.3
eta_min: 0.0
T_max: 150
warmup_epochs: 0
# loss related
is_dynamic_loss_scale: False
loss_scale: 1024
label_smooth_factor: 0.1
use_label_smooth: True
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
enable_profiling: 'Whether enable profiling while training, default: False'

View File

@ -28,7 +28,6 @@ then
exit 1 exit 1
fi fi
dataset_type='imagenet' dataset_type='imagenet'
if [ $# == 2 ] if [ $# == 2 ]
then then
@ -56,11 +55,12 @@ do
export RANK_ID=$((rank_start + i)) export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i rm -rf ./train_parallel$i
mkdir ./train_parallel$i mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i cp -r ../src ./train_parallel$i
cp ./train.py ./train_parallel$i cp ../train.py ./train_parallel$i
cp ../*.yaml ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type" echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
cd ./train_parallel$i || exit cd ./train_parallel$i || exit
env > env.log env > env.log
python train.py --device_id=$i --dataset_name=$dataset_type> log 2>&1 & python train.py --dataset_name=$dataset_type > log 2>&1 &
cd .. cd ..
done done

View File

@ -14,6 +14,16 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
abs_path=$(readlink -f "$0")
cur_path=$(dirname $abs_path)
cd $cur_path
rm -rf ./eval rm -rf ./eval
mkdir ./eval mkdir ./eval
python ./eval.py > ./eval/eval.log 2>&1 & cp -r ../src ./eval
cp ../eval.py ./eval
cp ../*.yaml ./eval
cd ./eval || exit
env >env.log
python ./eval.py > ./eval.log 2>&1 &
cd ..

View File

@ -14,10 +14,48 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
echo "$1 $2 $3"
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: bash run_distribute_train.sh [DEVICE_ID] [TRAIN_DATA_DIR] [cifar10|imagenet]"
exit 1
fi
expr $1 + 6 &>/dev/null
if [ $? != 0 ]
then
echo "error:DEVICE_ID=$1 is not a integer"
exit 1
fi
if [ ! -d $2 ]
then
echo "error:TRAIN_DATA_DIR=$2 is not a folder"
exit 1
fi
train_data_dir=$2
dataset_type='imagenet'
if [ $# == 3 ]
then
if [ $3 != "cifar10" ] && [ $3 != "imagenet" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet"
exit 1
fi
dataset_type=$3
fi
export DEVICE_ID=$1 export DEVICE_ID=$1
export RANK_ID=0
export DEVICE_NUM=1
export RANK_SIZE=1
rm -rf ./train_single rm -rf ./train_single
mkdir ./train_single mkdir ./train_single
cp -r ./src ./train_single cp -r ../src ./train_single
cp ./train.py ./train_single cp ../train.py ./train_single
cp ../*.yaml ./train_single
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
cd ./train_single || exit cd ./train_single || exit
python ./train.py --device_id=$DEVICE_ID > ./train.log 2>&1 & python ./train.py --dataset_name=$dataset_type --train_data_dir=$train_data_dir> ./train.log 2>&1 &

View File

@ -1,52 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict
# Settings for the ImageNet run, exposed with attribute access through EasyDict.
# Consumers import it as `from src.config import imagenet_cfg` (train.py and src/dataset.py).
imagenet_cfg = edict({
'name': 'imagenet',
# when True, train.py loads weights from 'checkpoint_path' before training
'pre_trained': False,
'num_classes': 1000,
'lr_init': 0.1,
'batch_size': 128,
# passed to model.train() as the number of epochs
'epoch_size': 500,
'momentum': 0.9,
'weight_decay': 1e-4,
'image_height': 224,
'image_width': 224,
# train.py passes 'data_path' to create_dataset_imagenet()
'data_path': './dataset/imagenet_original/train/',
'val_data_path': './dataset/imagenet_original/val/',
# train.py raises ValueError("Unsupported platform.") for any other target
'device_target': 'Ascend',
# forwarded to CheckpointConfig(keep_checkpoint_max=...)
'keep_checkpoint_max': 1,
'checkpoint_path': './scripts/train_parallel4/ckpt_4/train_tinydarknet_imagenet-300_1251.ckpt',
'onnx_filename': 'tinydarknet.onnx',
'air_filename': 'tinydarknet.air',
# optimizer and lr related
'lr_scheduler': 'exponential',
'lr_epochs': [70, 140, 210, 280],
'lr_gamma': 0.3,
'eta_min': 0.0,
'T_max': 150,
'warmup_epochs': 0,
# loss related
# when True, train.py overrides 'loss_scale' with 1 and uses DynamicLossScaleManager;
# otherwise FixedLossScaleManager is built from 'loss_scale'
'is_dynamic_loss_scale': False,
'loss_scale': 1024,
'label_smooth_factor': 0.1,
# when False, train.py resets 'label_smooth_factor' to 0.0
'use_label_smooth': True,
})

View File

@ -21,7 +21,7 @@ import mindspore.common.dtype as mstype
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision import mindspore.dataset.vision.c_transforms as vision
from src.config import imagenet_cfg from src.model_utils.config import config as imagenet_cfg
def create_dataset_imagenet(dataset_path, repeat_num=1, training=True, def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
num_parallel_workers=None, shuffle=None): num_parallel_workers=None, shuffle=None):

View File

@ -0,0 +1,132 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
_config = "imagenet_config.yaml"
class Config:
    """
    Attribute-style view of a configuration dictionary.

    Every key of the input mapping becomes an instance attribute. Nested
    dictionaries, including dictionaries found directly inside a list or tuple
    value, are wrapped in ``Config`` too. List and tuple values are both
    stored as lists.
    """

    def __init__(self, cfg_dict):
        for key, value in cfg_dict.items():
            if isinstance(value, dict):
                converted = Config(value)
            elif isinstance(value, (list, tuple)):
                converted = [Config(entry) if isinstance(entry, dict) else entry for entry in value]
            else:
                converted = value
            setattr(self, key, converted)

    def __str__(self):
        # Pretty-print the attribute dictionary rather than the default object repr.
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path=_config):
    """
    Parse command line arguments to the configuration according to the default yaml.

    One ``--<key>`` option is registered for every scalar entry of ``cfg``,
    with the yaml value as its default. Entries whose value is a list or a
    dict are skipped and can only be set through the yaml file.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Allowed values per key, forwarded to argparse ``choices``.
        cfg_path: Path to the default yaml config.

    Returns:
        The ``argparse.Namespace`` obtained from parsing the command line.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item, default in cfg.items():
        if isinstance(default, (list, dict)):
            continue
        if isinstance(default, bool):
            # bool("False") is True, so booleans are parsed as Python literals instead.
            arg_type = ast.literal_eval
        elif default is None:
            # A yaml null has type NoneType, which cannot convert a CLI string;
            # accept the override as plain text instead of failing.
            arg_type = str
        else:
            arg_type = type(default)
        help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
        parser.add_argument("--" + item, type=arg_type, default=default, choices=choices.get(item),
                            help=help_description)
    return parser.parse_args()
def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    The file may hold one to three yaml documents, in this order: the
    configuration itself, the help description per key, and the allowed
    choices per key.

    Args:
        yaml_path: Path to the yaml config.

    Returns:
        Tuple ``(cfg, cfg_helper, cfg_choices)``. ``cfg_helper`` and
        ``cfg_choices`` are empty dicts when their document is absent.

    Raises:
        ValueError: If the content is not valid yaml, or the file does not
            hold between one and three documents.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))
        except yaml.YAMLError as err:
            # Only yaml syntax/construction problems are translated; the original cause is kept.
            raise ValueError("Failed to parse yaml") from err
    if not cfgs:
        raise ValueError("No yaml document found in {}".format(yaml_path))
    if len(cfgs) > 3:
        raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
    # Pad the missing optional documents with empty dicts.
    cfg, cfg_helper, cfg_choices = cfgs + [{}] * (3 - len(cfgs))
    print(cfg_helper)
    return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
    """
    Overlay the parsed command line arguments on the yaml configuration.

    Args:
        args: Command line arguments.
        cfg: Base configuration; it is updated in place.

    Returns:
        The same ``cfg`` object, with every attribute of ``args`` written into it.
    """
    cfg.update(vars(args))
    return cfg
def get_config():
    """
    Build the final configuration from the yaml file and the command line.

    The yaml path is taken from ``--config_path`` and defaults to the config
    file located two directories above this module. The yaml values are then
    used as defaults for the command line options, and the parsed options are
    merged back over them.

    Returns:
        A ``Config`` holding the merged settings.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    default_yaml = os.path.join(module_dir, "../../{}".format(_config))
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    parser.add_argument("--config_path", type=str, default=default_yaml,
                        help="Config file path")
    # First pass: only --config_path is known here; every other option is ignored for now.
    known_args, _ = parser.parse_known_args()
    yaml_path = known_args.config_path
    default, helper, choices = parse_yaml(yaml_path)
    pprint(default)
    cli_args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=yaml_path)
    return Config(merge(cli_args, default))
# Built once at import time: importing this module parses sys.argv and reads the yaml file.
config = get_config()
# Running this module directly prints the resolved configuration.
if __name__ == '__main__':
print(config)

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
# Importing `config` builds the configuration from the command line and the yaml file.
from src.model_utils.config import config
# Choose the adapter implementation once, at import time, from the enable_modelarts setting.
if config.enable_modelarts:
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
# Re-export the selected helpers so that callers do not depend on which adapter is active.
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int (0 when it is unset)."""
    return int(os.environ.get('DEVICE_ID', '0'))
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (1 when it is unset)."""
    return int(os.environ.get('RANK_SIZE', '1'))
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (0 when it is unset)."""
    return int(os.environ.get('RANK_ID', '0'))
def get_job_id():
    """Return the fixed job identifier used for runs outside ModelArts."""
    local_job_name = "Local Job"
    return local_job_name

View File

@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from src.model_utils.config import config
_global_sync_count = 0
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int (0 when it is unset)."""
    return int(os.environ.get('DEVICE_ID', '0'))
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (1 when it is unset)."""
    return int(os.environ.get('RANK_SIZE', '1'))
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (0 when it is unset)."""
    return int(os.environ.get('RANK_ID', '0'))
def get_job_id():
    """
    Return the job identifier from the ``JOB_ID`` environment variable.

    Returns:
        The value of ``JOB_ID``, or ``"default"`` when the variable is unset or empty.
    """
    # os.getenv yields None for an unset variable; comparing only against ""
    # would let that None through instead of applying the fallback.
    return os.getenv('JOB_ID') or "default"
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
# Imported inside the function rather than at module level
# (presumably so the module can be imported where moxing is not installed -- confirm).
import moxing as mox
import time
global _global_sync_count
# Each call gets its own lock file: the module-level counter is appended to the path, then incremented.
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices at most.
# The copy is attempted only when device_id is a multiple of min(device_num, 8)
# and the lock file for this call does not exist yet.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
# Create the lock file; an IOError from mknod is ignored.
# NOTE(review): indentation is not visible in this view; presumably the lock is created only
# by the process that performed the copy, so that it signals completion -- confirm.
try:
os.mknod(sync_lock)
# print("os.mknod({}) success".format(sync_lock))
except IOError:
pass
print("===save flag===")
# Poll once per second until the lock file exists.
# NOTE(review): nothing here bounds the wait; if the lock file is never created the loop does not end.
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
# Decorator factory: pre_process / post_process are optional zero-argument callables.
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download inputs before the run: each non-empty url is synced to its local path
# (data_url -> data_path, checkpoint_url -> load_path, train_url -> output_path).
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
# Graphs are saved under <output_path>/<rank id>; device_num and device_id are
# written back into the shared config object.
# NOTE(review): indentation is not visible in this view; presumably this setup and the
# pre_process call below run only when enable_modelarts is set -- confirm.
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
# `profiler` is bound only when enable_profiling is set; the same flag guards its later use.
if config.enable_profiling:
profiler = Profiler()
# The wrapped function's return value is not captured, and wrapped_func has no return statement.
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload outputs after the run: output_path is synced back to train_url.
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -17,11 +17,11 @@
python train.py python train.py
""" """
import os import os
import argparse import time
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.communication.management import init, get_rank from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
@ -30,14 +30,15 @@ from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed from mindspore.common import set_seed
from src.config import imagenet_cfg
from src.dataset import create_dataset_imagenet from src.dataset import create_dataset_imagenet
from src.tinydarknet import TinyDarkNet from src.tinydarknet import TinyDarkNet
from src.CrossEntropySmooth import CrossEntropySmooth from src.CrossEntropySmooth import CrossEntropySmooth
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
set_seed(1) set_seed(1)
def lr_steps_imagenet(_cfg, steps_per_epoch): def lr_steps_imagenet(_cfg, steps_per_epoch):
"""lr step for imagenet""" """lr step for imagenet"""
from src.lr_scheduler.warmup_step_lr import warmup_step_lr from src.lr_scheduler.warmup_step_lr import warmup_step_lr
@ -62,54 +63,104 @@ def lr_steps_imagenet(_cfg, steps_per_epoch):
return _lr return _lr
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("Unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if __name__ == '__main__': if config.modelarts_dataset_unzip_name:
parser = argparse.ArgumentParser(description='Classification') zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'], save_dir_1 = os.path.join(config.data_path)
help='dataset name.')
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
args_opt = parser.parse_args()
if args_opt.dataset_name == "imagenet": sync_lock = "/tmp/unzip_sync.lock"
cfg = imagenet_cfg
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
config.ckpt_save_dir = os.path.join(config.output_path, config.ckpt_save_dir)
config.train_data_dir = config.data_path
config.checkpoint_path = config.load_path
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
if config.dataset_name == "imagenet":
pass
elif config.dataset_name == "cifar10":
raise ValueError("Unsupported dataset: 'cifar10'.")
else: else:
raise ValueError("Unsupported dataset.") raise ValueError("Unsupported dataset.")
# set context # set context
device_target = cfg.device_target device_target = config.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
device_num = int(os.environ.get("DEVICE_NUM", 1)) device_num = get_device_num()
rank = 0 rank = 0
if device_target == "Ascend": if device_target == "Ascend":
context.set_context(device_id=args_opt.device_id) context.set_context(device_id=get_device_id())
if device_num > 1: if device_num > 1:
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True) gradients_mean=True)
init() init()
rank = get_rank() rank = get_rank_id()
else: else:
raise ValueError("Unsupported platform.") raise ValueError("Unsupported platform.")
if args_opt.dataset_name == "imagenet": if config.dataset_name == "imagenet":
dataset = create_dataset_imagenet(cfg.data_path, 1) dataset = create_dataset_imagenet(config.train_data_dir, 1)
else: else:
raise ValueError("Unsupported dataset.") raise ValueError("Unsupported dataset.")
batch_num = dataset.get_dataset_size() batch_num = dataset.get_dataset_size()
net = TinyDarkNet(num_classes=cfg.num_classes) net = TinyDarkNet(num_classes=config.num_classes)
# Continue training if set pre_trained to be True # Continue training if set pre_trained to be True
if cfg.pre_trained: if config.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path) param_dict = load_checkpoint(config.checkpoint_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
loss_scale_manager = None loss_scale_manager = None
if args_opt.dataset_name == 'imagenet': if config.dataset_name == 'imagenet':
lr = lr_steps_imagenet(cfg, batch_num) lr = lr_steps_imagenet(config, batch_num)
def get_param_groups(network): def get_param_groups(network):
""" get param groups """ """ get param groups """
@ -132,32 +183,35 @@ if __name__ == '__main__':
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
if cfg.is_dynamic_loss_scale: if config.is_dynamic_loss_scale:
cfg.loss_scale = 1 config.loss_scale = 1
opt = Momentum(params=get_param_groups(net), opt = Momentum(params=get_param_groups(net),
learning_rate=Tensor(lr), learning_rate=Tensor(lr),
momentum=cfg.momentum, momentum=config.momentum,
weight_decay=cfg.weight_decay, weight_decay=config.weight_decay,
loss_scale=cfg.loss_scale) loss_scale=config.loss_scale)
if not cfg.use_label_smooth: if not config.use_label_smooth:
cfg.label_smooth_factor = 0.0 config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean", loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)
if cfg.is_dynamic_loss_scale: if config.is_dynamic_loss_scale:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000) loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else: else:
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False) loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O3", loss_scale_manager=loss_scale_manager) amp_level="O3", loss_scale_manager=loss_scale_manager)
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 50, keep_checkpoint_max=cfg.keep_checkpoint_max) config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 50, keep_checkpoint_max=config.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num) time_cb = TimeMonitor(data_size=batch_num)
ckpt_save_dir = "./ckpt_" + str(rank) + "/" ckpt_save_dir = os.path.join(config.ckpt_save_dir, str(rank))
ckpoint_cb = ModelCheckpoint(prefix="train_tinydarknet_" + args_opt.dataset_name, directory=ckpt_save_dir, ckpoint_cb = ModelCheckpoint(prefix="train_tinydarknet_" + config.dataset_name, directory=ckpt_save_dir,
config=config_ck) config=config_ck)
loss_cb = LossMonitor() loss_cb = LossMonitor()
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) model.train(config.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success") print("train success")
if __name__ == '__main__':
run_train()

View File

@ -119,7 +119,7 @@ The dataset is self-generated using a third-party library called [captcha](https
- running on ModelArts - running on ModelArts
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows
- 在ModelArt上使用8卡训练 - Training with 8 cards on ModelArts
```python ```python
# (1) Upload the code folder to S3 bucket. # (1) Upload the code folder to S3 bucket.
@ -138,10 +138,11 @@ The dataset is self-generated using a third-party library called [captcha](https
# (6) Upload the dataset or the zip package of dataset to S3 bucket. # (6) Upload the dataset or the zip package of dataset to S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path). # (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path).
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface. # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Create your job. # (9) Under the item "resource pool selection", select the specification of 8 cards.
# (10) Create your job.
``` ```
- 在ModelArts上使用单卡验证 - evaluating with single card on ModelArts
```python ```python
# (1) Upload the code folder to S3 bucket. # (1) Upload the code folder to S3 bucket.
@ -160,7 +161,8 @@ The dataset is self-generated using a third-party library called [captcha](https
# (6) Upload the dataset or the zip package of dataset to S3 bucket. # (6) Upload the dataset or the zip package of dataset to S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path). # (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path).
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface. # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Create your job. # (9) Under the item "resource pool selection", select the specification of a single card.
# (10) Create your job.
``` ```
## [Script Description](#contents) ## [Script Description](#contents)

View File

@ -122,7 +122,7 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
``` ```
- 在ModelArts上运行 - 在ModelArts上运行
如果你想在modelarts上运行可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/)) 如果你想在modelarts上运行可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/)
- 在ModelArt上使用8卡训练 - 在ModelArt上使用8卡训练
```python ```python
@ -142,7 +142,8 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
# (6) 上传你的 数据/数据zip压缩包 到 s3 桶上 # (6) 上传你的 数据/数据zip压缩包 到 s3 桶上
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包 # (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径” # (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
# (9) 创建训练作业 # (9) 在网页上的’资源池选择‘项目下, 选择8卡规格的资源
# (10) 创建训练作业
``` ```
- 在ModelArts上使用单卡验证 - 在ModelArts上使用单卡验证
@ -164,7 +165,8 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
# (6) 上传你的 数据/数据zip压缩包 到 s3 桶上 # (6) 上传你的 数据/数据zip压缩包 到 s3 桶上
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包 # (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径” # (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
# (9) 创建训练作业 # (9) 在网页上的’资源池选择‘项目下, 选择单卡规格的资源
# (10) 创建训练作业
``` ```
## 脚本说明 ## 脚本说明

View File

@ -42,15 +42,15 @@ def modelarts_pre_process():
fz = zipfile.ZipFile(zip_file, 'r') fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist()) data_num = len(fz.namelist())
print("Extract Start...") print("Extract Start...")
print("unzip file num: {}".format(data_num)) print("Unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1 data_print = int(data_num / 100) if data_num > 100 else 1
i = 0 i = 0
for file in fz.namelist(): for file in fz.namelist():
if i % data_print == 0: if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True) print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1 i += 1
fz.extract(file, save_dir) fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60), print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60))) int(int(time.time() - s_time) % 60)))
print("Extract Done.") print("Extract Done.")
else: else:

View File

@ -27,7 +27,7 @@ if config.device_target == "Ascend":
context.set_context(device_id=get_device_id()) context.set_context(device_id=get_device_id())
if config.file_format == "AIR" and config.device_target != "Ascend": if config.file_format == "AIR" and config.device_target != "Ascend":
raise ValueError("export AIR must on Ascend") raise ValueError("Export AIR must on Ascend")
if __name__ == "__main__": if __name__ == "__main__":
input_size = m.ceil(config.captcha_height / 64) * 64 * 3 input_size = m.ceil(config.captcha_height / 64) * 64 * 3

View File

@ -49,15 +49,15 @@ def modelarts_pre_process():
fz = zipfile.ZipFile(zip_file, 'r') fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist()) data_num = len(fz.namelist())
print("Extract Start...") print("Extract Start...")
print("unzip file num: {}".format(data_num)) print("Unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1 data_print = int(data_num / 100) if data_num > 100 else 1
i = 0 i = 0
for file in fz.namelist(): for file in fz.namelist():
if i % data_print == 0: if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True) print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1 i += 1
fz.extract(file, save_dir) fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60), print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60))) int(int(time.time() - s_time) % 60)))
print("Extract Done.") print("Extract Done.")
else: else:
@ -88,7 +88,7 @@ def modelarts_pre_process():
time.sleep(1) time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1)) print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
config.save_checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.save_checkpoint_path) config.save_checkpoint_path = os.path.join(config.output_path, config.save_checkpoint_path)
@moxing_wrapper(pre_process=modelarts_pre_process) @moxing_wrapper(pre_process=modelarts_pre_process)
@ -100,16 +100,22 @@ def train():
lr_scale = 1 lr_scale = 1
if config.run_distribute: if config.run_distribute:
if config.device_target == 'Ascend': if config.device_target == 'Ascend':
device_num = int(os.environ.get("RANK_SIZE")) device_num = get_device_num()
rank = int(os.environ.get("RANK_ID")) rank = get_rank_id()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else: else:
init()
device_num = get_group_size() device_num = get_group_size()
rank = get_rank() rank = get_rank()
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True) gradients_mean=True)
init()
else: else:
device_num = 1 device_num = 1
rank = 0 rank = 0
@ -149,7 +155,7 @@ def train():
if config.save_checkpoint: if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max) keep_checkpoint_max=config.keep_checkpoint_max)
save_ckpt_path = config.save_checkpoint_path save_ckpt_path = os.path.join(config.save_checkpoint_path, str(rank))
ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=save_ckpt_path, config=config_ck) ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=save_ckpt_path, config=config_ck)
callbacks.append(ckpt_cb) callbacks.append(ckpt_cb)
model.train(config.epoch_size, dataset, callbacks=callbacks) model.train(config.epoch_size, dataset, callbacks=callbacks)