forked from mindspore-Ecosystem/mindspore
!17764 tinydarknet can been used on ModelArts and warpctc init
From: @ZhengBina Reviewed-by: @oacjiewen,@wuxuejian Signed-off-by: @wuxuejian
This commit is contained in:
commit
ad30fff804
|
@ -79,7 +79,7 @@ After installing MindSpore via the official website, you can start training and
|
|||
bash ./scripts/run_standalone_train.sh 0
|
||||
|
||||
# run distributed training example
|
||||
bash ./scripts/run_distribute_train.sh rank_table.json
|
||||
bash ./scripts/run_distribute_train.sh /{path}/*.json
|
||||
|
||||
# run evaluation example
|
||||
python eval.py > eval.log 2>&1 &
|
||||
|
@ -87,12 +87,62 @@ After installing MindSpore via the official website, you can start training and
|
|||
bash ./script/run_eval.sh
|
||||
```
|
||||
|
||||
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
|
||||
For distributed training, a hccl configuration file [RANK_TABLE_FILE] with JSON format needs to be created in advance.
|
||||
|
||||
Please follow the instructions in the link below:
|
||||
|
||||
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>
|
||||
|
||||
- Running on ModelArts
|
||||
|
||||
If you want to run on ModelArts, please check the official [ModelArts documentation](https://support.huaweicloud.com/modelarts/); you can then start training as follows.
|
||||
|
||||
- Training with 8 cards on ModelArts
|
||||
|
||||
```python
|
||||
# (1) Upload the code folder to S3 bucket.
|
||||
# (2) Click "create training task" on the website UI interface.
|
||||
# (3) Set the code directory to "/{path}/tinydarknet" on the website UI interface.
|
||||
# (4) Set the startup file to "/{path}/tinydarknet/train.py" on the website UI interface.
|
||||
# (5) Perform a or b.
|
||||
# a. setting parameters in /{path}/tinydarknet/imagenet_config.yaml.
|
||||
# 1. Set "batch_size: 64" (optional)
|
||||
# 2. Set "enable_modelarts: True"
|
||||
# 3. Set "modelarts_dataset_unzip_name: {filename}" if the data is uploaded as a zip package.
|
||||
# b. adding on the website UI interface.
|
||||
# 1. Add "batch_size=64" (optional)
|
||||
# 2. Add "enable_modelarts=True"
|
||||
# 3. Add "modelarts_dataset_unzip_name={filename}" if the data is uploaded as a zip package.
|
||||
# (6) Upload the dataset or the zip package of dataset to S3 bucket.
|
||||
# (7) Select the "data storage location" on the website UI interface and set the "Dataset path" (this path must contain only the dataset or its zip package).
|
||||
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (9) Under the item "resource pool selection", select the specification of 8 cards.
|
||||
# (10) Create your job.
|
||||
```
|
||||
|
||||
- Evaluating with a single card on ModelArts
|
||||
|
||||
```python
|
||||
# (1) Upload the code folder to S3 bucket.
|
||||
# (2) Click "create training task" on the website UI interface.
|
||||
# (3) Set the code directory to "/{path}/tinydarknet" on the website UI interface.
|
||||
# (4) Set the startup file to "/{path}/tinydarknet/eval.py" on the website UI interface.
|
||||
# (5) Perform a or b.
|
||||
# a. setting parameters in /{path}/tinydarknet/imagenet_config.yaml.
|
||||
# 1. Set "enable_modelarts: True"
|
||||
# 2. Set "checkpoint_path: {checkpoint_path}" ({checkpoint_path} is the path of the weight file to be evaluated, relative to the file 'eval.py'; the weight file must be included in the code directory.)
|
||||
# 3. Set "modelarts_dataset_unzip_name: {filename}" if the data is uploaded as a zip package.
|
||||
# b. adding on the website UI interface.
|
||||
# 1. Set "enable_modelarts=True"
|
||||
# 2. Set "checkpoint_path={checkpoint_path}" ({checkpoint_path} is the path of the weight file to be evaluated, relative to the file 'eval.py'; the weight file must be included in the code directory.)
|
||||
# 3. Add "modelarts_dataset_unzip_name={filename}" if the data is uploaded as a zip package.
|
||||
# (6) Upload the dataset or the zip package of dataset to S3 bucket.
|
||||
# (7) Select the "data storage location" on the website UI interface and set the "Dataset path" (this path must contain only the dataset or its zip package).
|
||||
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (9) Under the item "resource pool selection", select the specification of a single card.
|
||||
# (10) Create your job.
|
||||
```
|
||||
|
||||
For more details, please refer to the specific script.
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
@ -111,18 +161,23 @@ For more details, please refer the specify script.
|
|||
├── run_eval.sh // shell script for evaluation on Ascend
|
||||
├── run_infer_310.sh // shell script for inference on Ascend310
|
||||
├── src
|
||||
├─lr_scheduler //learning rate scheduler
|
||||
├─__init__.py // init
|
||||
├─linear_warmup.py // linear_warmup
|
||||
├─warmup_cosine_annealing_lr.py // warmup_cosine_annealing_lr
|
||||
├─warmup_step_lr.py // warmup_step_lr
|
||||
├── lr_scheduler //learning rate scheduler
|
||||
├── __init__.py // init
|
||||
├── linear_warmup.py // linear_warmup
|
||||
├── warmup_cosine_annealing_lr.py // warmup_cosine_annealing_lr
|
||||
├── warmup_step_lr.py // warmup_step_lr
|
||||
├── model_utils
|
||||
├── config.py // parsing parameter configuration file of "*.yaml"
|
||||
├── device_adapter.py // local or ModelArts training
|
||||
├── local_adapter.py // get related environment variables in local training
|
||||
└── moxing_adapter.py // get related environment variables in ModelArts training
|
||||
├── dataset.py // creating dataset
|
||||
├── CrossEntropySmooth.py // loss function
|
||||
├── tinydarknet.py // Tiny-Darknet architecture
|
||||
├──config.py // parameter configuration
|
||||
├── train.py // training script
|
||||
├── eval.py // evaluation script
|
||||
├── export.py // export checkpoint file into air/onnx
|
||||
├── imagenet_config.yaml // parameter configuration
|
||||
├── mindspore_hub_conf.py // hub config
|
||||
├── postprocess.py // postprocess script
|
||||
|
||||
|
@ -130,40 +185,40 @@ For more details, please refer the specify script.
|
|||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py
|
||||
Parameters for both training and evaluation can be set in `imagenet_config.yaml`
|
||||
|
||||
- config for Tiny-Darknet
|
||||
- config for Tiny-Darknet(only some parameters are listed)
|
||||
|
||||
```python
|
||||
'pre_trained': 'False' # whether training based on the pre-trained model
|
||||
'num_classes': 1000 # the number of classes in the dataset
|
||||
'lr_init': 0.1 # initial learning rate
|
||||
'batch_size': 128 # training batch_size
|
||||
'epoch_size': 500 # total training epoch
|
||||
'momentum': 0.9 # momentum
|
||||
'weight_decay': 1e-4 # weight decay value
|
||||
'image_height': 224 # image height used as input to the model
|
||||
'image_width': 224 # image width used as input to the model
|
||||
'data_path': './ImageNet_Original/train/' # absolute full path to the train datasets
|
||||
'val_data_path': './ImageNet_Original/val/' # absolute full path to the evaluation datasets
|
||||
'device_target': 'Ascend' # device running the program
|
||||
'keep_checkpoint_max': 10 # only keep the last keep_checkpoint_max checkpoint
|
||||
'checkpoint_path': '/train_tinydarknet.ckpt' # the absolute full path to save the checkpoint file
|
||||
'onnx_filename': 'tinydarknet.onnx' # file name of the onnx model used in export.py
|
||||
'air_filename': 'tinydarknet.air' # file name of the air model used in export.py
|
||||
'lr_scheduler': 'exponential' # learning rate scheduler
|
||||
'lr_epochs': [70, 140, 210, 280] # epoch of lr changing
|
||||
'lr_gamma': 0.3 # decrease lr by a factor of exponential lr_scheduler
|
||||
'eta_min': 0.0 # eta_min in cosine_annealing scheduler
|
||||
'T_max': 150 # T-max in cosine_annealing scheduler
|
||||
'warmup_epochs': 0 # warmup epoch
|
||||
'is_dynamic_loss_scale': 0 # dynamic loss scale
|
||||
'loss_scale': 1024 # loss scale
|
||||
'label_smooth_factor': 0.1 # label_smooth_factor
|
||||
'use_label_smooth': True # label smooth
|
||||
pre_trained: False # whether training based on the pre-trained model
|
||||
num_classes: 1000 # the number of classes in the dataset
|
||||
lr_init: 0.1 # initial learning rate
|
||||
batch_size: 128 # training batch_size
|
||||
epoch_size: 500 # total training epoch
|
||||
momentum: 0.9 # momentum
|
||||
weight_decay: 1e-4 # weight decay value
|
||||
image_height: 224 # image height used as input to the model
|
||||
image_width: 224 # image width used as input to the model
|
||||
train_data_dir: './ImageNet_Original/train/' # absolute full path to the train datasets
|
||||
val_data_dir: './ImageNet_Original/val/' # absolute full path to the evaluation datasets
|
||||
device_target: 'Ascend' # device running the program
|
||||
keep_checkpoint_max: 10 # only keep the last keep_checkpoint_max checkpoint
|
||||
checkpoint_path: '/train_tinydarknet.ckpt' # the absolute full path to save the checkpoint file
|
||||
onnx_filename: 'tinydarknet.onnx' # file name of the onnx model used in export.py
|
||||
air_filename: 'tinydarknet.air' # file name of the air model used in export.py
|
||||
lr_scheduler: 'exponential' # learning rate scheduler
|
||||
lr_epochs: [70, 140, 210, 280] # epoch of lr changing
|
||||
lr_gamma: 0.3 # decrease lr by a factor of exponential lr_scheduler
|
||||
eta_min: 0.0 # eta_min in cosine_annealing scheduler
|
||||
T_max: 150 # T-max in cosine_annealing scheduler
|
||||
warmup_epochs: 0 # warmup epoch
|
||||
is_dynamic_loss_scale: 0 # dynamic loss scale
|
||||
loss_scale: 1024 # loss scale
|
||||
label_smooth_factor: 0.1 # label_smooth_factor
|
||||
use_label_smooth: True # label smooth
|
||||
```
|
||||
|
||||
For more configuration details, please refer the script config.py.
|
||||
For more configuration details, please refer to the file `imagenet_config.yaml`.
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
|
@ -172,7 +227,7 @@ For more configuration details, please refer the script config.py.
|
|||
- running on Ascend:
|
||||
|
||||
```python
|
||||
bash scripts/run_standalone_train.sh 0
|
||||
bash scripts/run_standalone_train.sh [DEVICE_ID]
|
||||
```
|
||||
|
||||
The command above will run in the background, you can view the results through the file train.log.
|
||||
|
@ -199,7 +254,7 @@ For more configuration details, please refer the script config.py.
|
|||
- running on Ascend:
|
||||
|
||||
```python
|
||||
bash ./scripts/run_distribute_train.sh rank_table.json
|
||||
bash ./scripts/run_distribute_train.sh [RANK_TABLE_FILE]
|
||||
```
|
||||
|
||||
The above shell script will run distribute training in the background. You can view the results through the file train_parallel[X]/log. The loss value will be achieved as follows:
|
||||
|
@ -252,7 +307,7 @@ For more configuration details, please refer the script config.py.
|
|||
python export.py --dataset [DATASET] --file_name [FILE_NAME] --file_format [EXPORT_FORMAT]
|
||||
```
|
||||
|
||||
The parameter does not have the ckpt_file option. Please store the ckpt file according to the path of the parameter `checkpoint_path` in `config.py`.
|
||||
There is no `ckpt_file` option. Please store the ckpt file at the path given by the parameter `checkpoint_path` in `imagenet_config.yaml`.
|
||||
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
|
||||
|
||||
### Infer on Ascend310
|
||||
|
|
|
@ -87,7 +87,7 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
|
|||
bash ./scripts/run_standalone_train.sh 0
|
||||
|
||||
# 分布式训练
|
||||
bash ./scripts/run_distribute_train.sh rank_table.json
|
||||
bash ./scripts/run_distribute_train.sh /{path}/*.json
|
||||
|
||||
# 评估
|
||||
python eval.py > eval.log 2>&1 &
|
||||
|
@ -95,12 +95,61 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
|
|||
bash ./script/run_eval.sh
|
||||
```
|
||||
|
||||
进行并行训练时, 需要提前创建JSON格式的hccl配置文件。
|
||||
进行并行训练时, 需要提前创建JSON格式的hccl配置文件 [RANK_TABLE_FILE]。
|
||||
|
||||
请按照以下链接的指导进行设置:
|
||||
|
||||
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>
|
||||
|
||||
- 在ModelArts上运行
|
||||
如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
- 在ModelArts上使用8卡训练
|
||||
|
||||
```python
|
||||
# (1) 上传你的代码到 s3 桶上
|
||||
# (2) 在ModelArts上创建训练任务
|
||||
# (3) 选择代码目录 /{path}/tinydarknet
|
||||
# (4) 选择启动文件 /{path}/tinydarknet/train.py
|
||||
# (5) 执行a或b
|
||||
# a. 在 /{path}/tinydarknet/imagenet_config.yaml 文件中设置参数
|
||||
# 1. 设置 “batch_size: 64” (非必须)
|
||||
# 2. 设置 “enable_modelarts: True”
|
||||
# 3. 如果数据采用zip格式压缩包的形式上传,设置 “modelarts_dataset_unzip_name: {filename}”
|
||||
# b. 在 网页上设置
|
||||
# 1. 添加 “batch_size=64” (非必须)
|
||||
# 2. 添加 “enable_modelarts=True”
|
||||
# 3. 如果数据采用zip格式压缩包的形式上传,添加 “modelarts_dataset_unzip_name={filename}”
|
||||
# (6) 上传你的 数据/数据zip压缩包 到 s3 桶上
|
||||
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包)
|
||||
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
|
||||
# (9) 在网页上的‘资源池选择’项目下, 选择8卡规格的资源
|
||||
# (10) 创建训练作业
|
||||
```
|
||||
|
||||
- 在ModelArts上使用单卡验证
|
||||
|
||||
```python
|
||||
# (1) 上传你的代码到 s3 桶上
|
||||
# (2) 在ModelArts上创建训练任务
|
||||
# (3) 选择代码目录 /{path}/tinydarknet
|
||||
# (4) 选择启动文件 /{path}/tinydarknet/eval.py
|
||||
# (5) 执行a或b
|
||||
# a. 在 /{path}/tinydarknet 下的imagenet_config.yaml 文件中设置参数
|
||||
# 1. 设置 “enable_modelarts: True”
|
||||
# 2. 设置 “checkpoint_path: {checkpoint_path}”({checkpoint_path}表示待评估的 权重文件 相对于 eval.py 的路径,权重文件须包含在代码目录下。)
|
||||
# 3. 如果数据采用zip格式压缩包的形式上传,设置 “modelarts_dataset_unzip_name: {filename}”
|
||||
# b. 在 网页上设置
|
||||
# 1. 设置 “enable_modelarts=True”
|
||||
# 2. 设置 “checkpoint_path={checkpoint_path}”({checkpoint_path}表示待评估的 权重文件 相对于 eval.py 的路径,权重文件须包含在代码目录下。)
|
||||
# 3. 如果数据采用zip格式压缩包的形式上传,设置 “modelarts_dataset_unzip_name={filename}”
|
||||
# (6) 上传你的 数据/数据zip压缩包 到 s3 桶上
|
||||
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包)
|
||||
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
|
||||
# (9) 在网页上的‘资源池选择’项目下, 选择单卡规格的资源
|
||||
# (10) 创建训练作业
|
||||
```
|
||||
|
||||
更多的细节请参考具体的script文件
|
||||
|
||||
# [脚本描述](#目录)
|
||||
|
@ -117,61 +166,66 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
|
|||
├── run_standalone_train.sh // Ascend单卡训练shell脚本
|
||||
├── run_distribute_train.sh // Ascend分布式训练shell脚本
|
||||
├── run_eval.sh // Ascend评估shell脚本
|
||||
├──run_infer_310.sh // Ascend310推理shell脚本
|
||||
└── run_infer_310.sh // Ascend310推理shell脚本
|
||||
├── src
|
||||
├─lr_scheduler //学习率策略
|
||||
├─__init__.py // 初始化文件
|
||||
├─linear_warmup.py // linear_warmup策略
|
||||
├─warmup_cosine_annealing_lr.py // warmup_cosine_annealing_lr策略
|
||||
├─warmup_step_lr.py // warmup_step_lr策略
|
||||
├── lr_scheduler // 学习率策略
|
||||
├── __init__.py // 初始化文件
|
||||
├── linear_warmup.py // linear_warmup策略
|
||||
├── warmup_cosine_annealing_lr.py // warmup_cosine_annealing_lr策略
|
||||
└── warmup_step_lr.py // warmup_step_lr策略
|
||||
├── model_utils
|
||||
├── config.py // 解析 *.yaml 参数配置文件
|
||||
├── device_adapter.py // 区分本地/ModelArts训练
|
||||
├── local_adapter.py // 本地训练获取相关环境变量
|
||||
└── moxing_adapter.py // ModelArts训练获取相关环境变量、交换数据
|
||||
├── dataset.py // 创建数据集
|
||||
├── CrossEntropySmooth.py // 损失函数
|
||||
├──tinydarknet.py // Tiny-Darknet网络结构
|
||||
├──config.py // 参数配置
|
||||
└── tinydarknet.py // Tiny-Darknet网络结构
|
||||
├── train.py // 训练脚本
|
||||
├── eval.py // 评估脚本
|
||||
├── export.py // 导出checkpoint文件
|
||||
├── imagenet_config.yaml // 参数配置
|
||||
├── mindspore_hub_conf.py // hub配置文件
|
||||
├── postprocess.py // 310推理后处理脚本
|
||||
└── postprocess.py // 310推理后处理脚本
|
||||
|
||||
```
|
||||
|
||||
## [脚本参数](#目录)
|
||||
|
||||
训练和测试的参数可在 config.py 中进行设置
|
||||
训练和测试的参数可在 `imagenet_config.yaml` 中进行设置
|
||||
|
||||
- Tiny-Darknet的配置文件
|
||||
- Tiny-Darknet的配置文件(仅列出部分参数)
|
||||
|
||||
```python
|
||||
'pre_trained': 'False' # 是否载入预训练模型
|
||||
'num_classes': 1000 # 数据集中类的数量
|
||||
'lr_init': 0.1 # 初始学习率
|
||||
'batch_size': 128 # 训练的batch_size
|
||||
'epoch_size': 500 # 总共的训练epoch
|
||||
'momentum': 0.9 # 动量
|
||||
'weight_decay': 1e-4 # 权重衰减率
|
||||
'image_height': 224 # 输入图像的高度
|
||||
'image_width': 224 # 输入图像的宽度
|
||||
'data_path': './ImageNet_Original/train/' # 训练数据集的绝对路径
|
||||
'val_data_path': './ImageNet_Original/val/' # 评估数据集的绝对路径
|
||||
'device_target': 'Ascend' # 程序运行的设备
|
||||
'keep_checkpoint_max': 10 # 仅仅保持最新的keep_checkpoint_max个checkpoint文件
|
||||
'checkpoint_path': '/train_tinydarknet.ckpt' # 保存checkpoint文件的绝对路径
|
||||
'onnx_filename': 'tinydarknet.onnx' # 用于export.py 文件中的onnx模型的文件名
|
||||
'air_filename': 'tinydarknet.air' # 用于export.py 文件中的air模型的文件名
|
||||
'lr_scheduler': 'exponential' # 学习率策略
|
||||
'lr_epochs': [70, 140, 210, 280] # 学习率进行变化的epoch数
|
||||
'lr_gamma': 0.3 # lr_scheduler为exponential时的学习率衰减因子
|
||||
'eta_min': 0.0 # cosine_annealing策略中的eta_min
|
||||
'T_max': 150 # cosine_annealing策略中的T-max
|
||||
'warmup_epochs': 0 # 热启动的epoch数
|
||||
'is_dynamic_loss_scale': 0 # 动态损失尺度
|
||||
'loss_scale': 1024 # 损失尺度
|
||||
'label_smooth_factor': 0.1 # 训练标签平滑因子
|
||||
'use_label_smooth': True # 是否采用训练标签平滑
|
||||
pre_trained: False # 是否载入预训练模型
|
||||
num_classes: 1000 # 数据集中类的数量
|
||||
lr_init: 0.1 # 初始学习率
|
||||
batch_size: 128 # 训练的batch_size
|
||||
epoch_size: 500 # 总共的训练epoch
|
||||
momentum: 0.9 # 动量
|
||||
weight_decay: 1e-4 # 权重衰减率
|
||||
image_height: 224 # 输入图像的高度
|
||||
image_width: 224 # 输入图像的宽度
|
||||
train_data_dir: './ImageNet_Original/train/' # 训练数据集的绝对路径
|
||||
val_data_dir: './ImageNet_Original/val/' # 评估数据集的绝对路径
|
||||
device_target: 'Ascend' # 程序运行的设备
|
||||
keep_checkpoint_max: 10 # 仅仅保持最新的keep_checkpoint_max个checkpoint文件
|
||||
checkpoint_path: '/train_tinydarknet.ckpt' # 保存checkpoint文件的绝对路径
|
||||
onnx_filename: 'tinydarknet.onnx' # 用于export.py 文件中的onnx模型的文件名
|
||||
air_filename: 'tinydarknet.air' # 用于export.py 文件中的air模型的文件名
|
||||
lr_scheduler: 'exponential' # 学习率策略
|
||||
lr_epochs: [70, 140, 210, 280] # 学习率进行变化的epoch数
|
||||
lr_gamma: 0.3 # lr_scheduler为exponential时的学习率衰减因子
|
||||
eta_min: 0.0 # cosine_annealing策略中的eta_min
|
||||
T_max: 150 # cosine_annealing策略中的T-max
|
||||
warmup_epochs: 0 # 热启动的epoch数
|
||||
is_dynamic_loss_scale: 0 # 动态损失尺度
|
||||
loss_scale: 1024 # 损失尺度
|
||||
label_smooth_factor: 0.1 # 训练标签平滑因子
|
||||
use_label_smooth: True # 是否采用训练标签平滑
|
||||
```
|
||||
|
||||
更多的细节, 请参考`config.py`.
|
||||
更多的细节, 请参考`imagenet_config.yaml`.
|
||||
|
||||
## [训练过程](#目录)
|
||||
|
||||
|
@ -180,7 +234,7 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
|
|||
- 在Ascend资源上运行:
|
||||
|
||||
```python
|
||||
bash ./scripts/run_standalone_train.sh 0
|
||||
bash ./scripts/run_standalone_train.sh [DEVICE_ID]
|
||||
```
|
||||
|
||||
上述的命令将运行在后台中,可以通过 `train.log` 文件查看运行结果.
|
||||
|
@ -207,7 +261,7 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
|
|||
- 在Ascend资源上运行:
|
||||
|
||||
```python
|
||||
bash scripts/run_distribute_train.sh rank_table.json
|
||||
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE]
|
||||
```
|
||||
|
||||
上述的脚本命令将在后台中进行分布式训练,可以通过`train_parallel[X]/log`文件查看运行结果. 训练的损失值将以如下的形式展示:
|
||||
|
@ -229,7 +283,7 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
|
|||
|
||||
- 在Ascend资源上进行评估:
|
||||
|
||||
在运行如下命令前,请确认用于评估的checkpoint文件的路径.请将checkpoint路径设置为绝对路径,例如:"/username/imagenet/train_tinydarknet.ckpt"
|
||||
在运行如下命令前,请确认用于评估的checkpoint文件的路径.checkpoint文件须包含在tinydarknet文件夹内.请将checkpoint路径设置为相对于 eval.py文件 的路径,例如:"./ckpts/train_tinydarknet.ckpt"(ckpts 与 eval.py 同级).
|
||||
|
||||
```python
|
||||
python eval.py > eval.log 2>&1 &
|
||||
|
@ -259,7 +313,7 @@ Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的
|
|||
python export.py --dataset [DATASET] --file_name [FILE_NAME] --file_format [EXPORT_FORMAT]
|
||||
```
|
||||
|
||||
参数没有ckpt_file选项,ckpt文件请按照`config.py`中参数`checkpoint_path`的路径存放。
|
||||
参数没有ckpt_file选项,ckpt文件请按照`imagenet_config.yaml`中参数`checkpoint_path`的路径存放。
|
||||
`EXPORT_FORMAT` 可选 ["AIR", "MINDIR"].
|
||||
|
||||
### 在Ascend310执行推理
|
||||
|
|
|
@ -16,34 +16,82 @@
|
|||
##############test tinydarknet example on cifar10#################
|
||||
python eval.py
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from mindspore import context
|
||||
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import imagenet_cfg
|
||||
from src.dataset import create_dataset_imagenet
|
||||
|
||||
from src.tinydarknet import TinyDarkNet
|
||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_num
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='tinydarknet')
|
||||
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'],
|
||||
help='dataset name.')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
args_opt = parser.parse_args()
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("Unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
if config.modelarts_dataset_unzip_name:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
if args_opt.dataset_name == "imagenet":
|
||||
cfg = imagenet_cfg
|
||||
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if config.device_id % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(config.device_id, zip_file_1, save_dir_1))
|
||||
config.checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.checkpoint_path)
|
||||
if not os.path.exists(config.checkpoint_path):
|
||||
raise ValueError("Check parameter 'checkpoint_path'. for more details, you can see README.md")
|
||||
config.val_data_dir = config.data_path
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_eval():
|
||||
if config.dataset_name == "imagenet":
|
||||
cfg = config
|
||||
dataset = create_dataset_imagenet(cfg.val_data_dir, 1, False)
|
||||
if not cfg.use_label_smooth:
|
||||
cfg.label_smooth_factor = 0.0
|
||||
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||
|
@ -52,20 +100,19 @@ if __name__ == '__main__':
|
|||
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||
|
||||
else:
|
||||
raise ValueError("dataset is not support.")
|
||||
raise ValueError("Dataset is not support.")
|
||||
|
||||
device_target = cfg.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
|
||||
|
||||
if args_opt.checkpoint_path is not None:
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
|
||||
else:
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=config.device_id)
|
||||
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||
print("load checkpoint from [{}].".format(cfg.checkpoint_path))
|
||||
print("Load checkpoint from [{}].".format(cfg.checkpoint_path))
|
||||
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
acc = model.eval(dataset)
|
||||
print("accuracy: ", acc)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_eval()
|
||||
|
|
|
@ -23,7 +23,7 @@ import mindspore as ms
|
|||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import imagenet_cfg
|
||||
from src.model_utils.config import config as imagenet_cfg
|
||||
from src.tinydarknet import TinyDarkNet
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -38,7 +38,7 @@ if __name__ == '__main__':
|
|||
if args_opt.dataset_name == 'imagenet':
|
||||
cfg = imagenet_cfg
|
||||
else:
|
||||
raise ValueError("dataset is not support.")
|
||||
raise ValueError("Dataset is not support.")
|
||||
|
||||
net = TinyDarkNet(num_classes=cfg.num_classes)
|
||||
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
modelarts_dataset_unzip_name: ''
|
||||
# ==============================================================================
|
||||
#train-eval-export related
|
||||
dataset_name: imagenet
|
||||
ckpt_save_dir: checkpoints
|
||||
pre_trained: False
|
||||
device_id: 0
|
||||
num_classes: 1000
|
||||
lr_init: 0.1
|
||||
batch_size: 128
|
||||
epoch_size: 500
|
||||
momentum: 0.9
|
||||
weight_decay: 0.0001
|
||||
image_height: 224
|
||||
image_width: 224
|
||||
train_data_dir: './dataset/imagenet_original/train/'
|
||||
val_data_dir: './dataset/imagenet_original/val/'
|
||||
keep_checkpoint_max: 1
|
||||
checkpoint_path: './scripts/train_parallel4/ckpt_4/train_tinydarknet_imagenet-300_1251.ckpt'
|
||||
onnx_filename: 'tinydarknet.onnx'
|
||||
air_filename: 'tinydarknet.air'
|
||||
# optimizer and lr related
|
||||
lr_scheduler: 'exponential'
|
||||
lr_epochs: [70, 140, 210, 280]
|
||||
lr_gamma: 0.3
|
||||
eta_min: 0.0
|
||||
T_max: 150
|
||||
warmup_epochs: 0
|
||||
# loss related
|
||||
is_dynamic_loss_scale: False
|
||||
loss_scale: 1024
|
||||
label_smooth_factor: 0.1
|
||||
use_label_smooth: True
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
|
@ -28,7 +28,6 @@ then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
dataset_type='imagenet'
|
||||
if [ $# == 2 ]
|
||||
then
|
||||
|
@ -56,11 +55,12 @@ do
|
|||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ./train.py ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cp ../train.py ./train_parallel$i
|
||||
cp ../*.yaml ./train_parallel$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
|
||||
cd ./train_parallel$i || exit
|
||||
env > env.log
|
||||
python train.py --device_id=$i --dataset_name=$dataset_type> log 2>&1 &
|
||||
python train.py --dataset_name=$dataset_type > log 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -14,6 +14,16 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
abs_path=$(readlink -f "$0")
|
||||
cur_path=$(dirname $abs_path)
|
||||
cd $cur_path
|
||||
|
||||
rm -rf ./eval
|
||||
mkdir ./eval
|
||||
python ./eval.py > ./eval/eval.log 2>&1 &
|
||||
cp -r ../src ./eval
|
||||
cp ../eval.py ./eval
|
||||
cp ../*.yaml ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
python ./eval.py > ./eval.log 2>&1 &
|
||||
cd ..
|
||||
|
|
|
@ -14,10 +14,48 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "$1 $2 $3"
|
||||
|
||||
if [ $# != 2 ] && [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: bash run_distribute_train.sh [DEVICE_ID] [TRAIN_DATA_DIR] [cifar10|imagenet]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
expr $1 + 6 &>/dev/null
|
||||
if [ $? != 0 ]
|
||||
then
|
||||
echo "error:DEVICE_ID=$1 is not a integer"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $2 ]
|
||||
then
|
||||
echo "error:TRAIN_DATA_DIR=$2 is not a folder"
|
||||
exit 1
|
||||
fi
|
||||
train_data_dir=$2
|
||||
|
||||
dataset_type='imagenet'
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
if [ $3 != "cifar10" ] && [ $3 != "imagenet" ]
|
||||
then
|
||||
echo "error: the selected dataset is neither cifar10 nor imagenet"
|
||||
exit 1
|
||||
fi
|
||||
dataset_type=$3
|
||||
fi
|
||||
|
||||
export DEVICE_ID=$1
|
||||
export RANK_ID=0
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=1
|
||||
rm -rf ./train_single
|
||||
mkdir ./train_single
|
||||
cp -r ./src ./train_single
|
||||
cp ./train.py ./train_single
|
||||
cp -r ../src ./train_single
|
||||
cp ../train.py ./train_single
|
||||
cp ../*.yaml ./train_single
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
|
||||
cd ./train_single || exit
|
||||
python ./train.py --device_id=$DEVICE_ID > ./train.log 2>&1 &
|
||||
python ./train.py --dataset_name=$dataset_type --train_data_dir=$train_data_dir> ./train.log 2>&1 &
|
||||
|
|
|
@ -1,52 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in main.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
imagenet_cfg = edict({
|
||||
'name': 'imagenet',
|
||||
'pre_trained': False,
|
||||
'num_classes': 1000,
|
||||
'lr_init': 0.1,
|
||||
'batch_size': 128,
|
||||
'epoch_size': 500,
|
||||
'momentum': 0.9,
|
||||
'weight_decay': 1e-4,
|
||||
'image_height': 224,
|
||||
'image_width': 224,
|
||||
'data_path': './dataset/imagenet_original/train/',
|
||||
'val_data_path': './dataset/imagenet_original/val/',
|
||||
'device_target': 'Ascend',
|
||||
'keep_checkpoint_max': 1,
|
||||
'checkpoint_path': './scripts/train_parallel4/ckpt_4/train_tinydarknet_imagenet-300_1251.ckpt',
|
||||
'onnx_filename': 'tinydarknet.onnx',
|
||||
'air_filename': 'tinydarknet.air',
|
||||
|
||||
# optimizer and lr related
|
||||
'lr_scheduler': 'exponential',
|
||||
'lr_epochs': [70, 140, 210, 280],
|
||||
'lr_gamma': 0.3,
|
||||
'eta_min': 0.0,
|
||||
'T_max': 150,
|
||||
'warmup_epochs': 0,
|
||||
|
||||
# loss related
|
||||
'is_dynamic_loss_scale': False,
|
||||
'loss_scale': 1024,
|
||||
'label_smooth_factor': 0.1,
|
||||
'use_label_smooth': True,
|
||||
})
|
|
@ -21,7 +21,7 @@ import mindspore.common.dtype as mstype
|
|||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.dataset.vision.c_transforms as vision
|
||||
from src.config import imagenet_cfg
|
||||
from src.model_utils.config import config as imagenet_cfg
|
||||
|
||||
def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
|
||||
num_parallel_workers=None, shuffle=None):
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
_config = "imagenet_config.yaml"
|
||||
|
||||
class Config:
    """
    Attribute-style view over a configuration dictionary.

    Every key of the source dictionary becomes an instance attribute. Nested
    dictionaries are wrapped in ``Config``; lists and tuples are stored as
    lists whose dictionary elements are wrapped the same way.
    """

    def __init__(self, cfg_dict):
        def _wrap(value):
            # Only dictionaries are converted; any other value is kept as-is.
            return Config(value) if isinstance(value, dict) else value

        for key, value in cfg_dict.items():
            if isinstance(value, (list, tuple)):
                # Sequences are always stored as a list, converted one level deep.
                converted = [_wrap(element) for element in value]
            else:
                converted = _wrap(value)
            setattr(self, key, converted)

    def __str__(self):
        # Pretty-print the attributes exactly as a plain dictionary would be.
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path=_config):
    """
    Parse command line arguments to the configuration according to the default yaml.

    One ``--<key>`` option is registered for every scalar entry of ``cfg``
    (entries holding a list or a dict are skipped), with the yaml value as
    its default.

    Args:
        parser: Parent parser whose arguments are inherited.
        cfg: Base configuration (mapping loaded from the default yaml).
        helper: Optional mapping from key to help description.
        choices: Optional mapping from key to the allowed values.
        cfg_path: Path to the default yaml config, quoted in the fallback help text.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    fallback_help = "Please reference to {}".format(cfg_path)
    for item, default in cfg.items():
        if isinstance(default, (list, dict)):
            continue
        if isinstance(default, bool):
            # bool("False") is True, so booleans are parsed as Python literals.
            arg_type = ast.literal_eval
        elif default is None:
            # A yaml null has type NoneType, which cannot convert a command
            # line string; accept the override as plain text instead.
            arg_type = str
        else:
            arg_type = type(default)
        parser.add_argument("--" + item, type=arg_type, default=default,
                            choices=choices.get(item),
                            help=helper.get(item, fallback_help))
    args = parser.parse_args()
    return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    The file may hold up to three yaml documents, in this order: the
    configuration itself, the help description of each key, and the allowed
    choices of each key. Missing trailing documents default to empty dicts.

    Args:
        yaml_path: Path to the yaml config.

    Returns:
        Tuple ``(cfg, cfg_helper, cfg_choices)``.

    Raises:
        ValueError: If the content is not valid yaml, or if the file holds no
            document or more than three documents.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))
        except yaml.YAMLError as err:
            # Only genuine yaml errors are translated; the original cause is
            # kept in the exception chain.
            raise ValueError("Failed to parse yaml") from err

    # Validate the document count outside the try block so that these
    # messages reach the caller unchanged.
    if not cfgs:
        raise ValueError("No document found in config yaml")
    if len(cfgs) > 3:
        raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")

    # Pad with empty helper/choices documents when they are not provided.
    cfg, cfg_helper, cfg_choices = (cfgs + [{}, {}])[:3]
    print(cfg_helper)
    return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
    """
    Overlay the command line arguments on the base configuration.

    Args:
        args: Command line arguments (an object supporting ``vars()``).
        cfg: Base configuration; it is updated in place.

    Returns:
        The same ``cfg`` object, with every attribute of ``args`` written into it.
    """
    for key, value in vars(args).items():
        cfg[key] = value
    return cfg
|
||||
|
||||
|
||||
def get_config():
    """
    Build the final configuration from the yaml file and the command line.

    ``--config_path`` is resolved first, ignoring all other options, so that
    the yaml file it names can supply the defaults for every remaining option.
    The command line is then parsed in full and merged over the yaml values.

    Returns:
        Config: The merged configuration.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    default_yaml = os.path.join(module_dir, "../../{}".format(_config))

    base_parser = argparse.ArgumentParser(description="default name", add_help=False)
    base_parser.add_argument("--config_path", type=str, default=default_yaml,
                             help="Config file path")
    known_args, _ = base_parser.parse_known_args()
    yaml_path = known_args.config_path

    default, helper, choices = parse_yaml(yaml_path)
    pprint(default)
    cli_args = parse_cli_to_yaml(parser=base_parser, cfg=default, helper=helper,
                                 choices=choices, cfg_path=yaml_path)
    return Config(merge(cli_args, default))
|
||||
|
||||
config = get_config()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(config)
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from src.model_utils.config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('DEVICE_ID', '0'))
|
||||
|
||||
|
||||
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (1 when unset)."""
    return int(os.environ.get('RANK_SIZE', '1'))
|
||||
|
||||
|
||||
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('RANK_ID', '0'))
|
||||
|
||||
|
||||
def get_job_id():
    """Return the fixed job identifier used for local (non-ModelArts) runs."""
    local_job_name = "Local Job"
    return local_job_name
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from mindspore.profiler import Profiler
|
||||
from src.model_utils.config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('DEVICE_ID', '0'))
|
||||
|
||||
|
||||
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (1 when unset)."""
    return int(os.environ.get('RANK_SIZE', '1'))
|
||||
|
||||
|
||||
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('RANK_ID', '0'))
|
||||
|
||||
|
||||
def get_job_id():
    """
    Return the job id read from the ``JOB_ID`` environment variable.

    Returns:
        str: The value of ``JOB_ID``, or ``"default"`` when the variable is
        unset or empty.
    """
    # os.getenv yields None for an unset variable; treat that like an empty
    # value so that callers always receive a string.
    return os.getenv('JOB_ID') or "default"
|
||||
|
||||
def sync_data(from_path, to_path):
    """
    Copy data between remote obs and a local directory.

    Downloads when ``from_path`` is the remote url and ``to_path`` the local
    path, and uploads in the opposite case. Each call uses its own numbered
    lock file under ``/tmp``. The copy is done by a process whose device id is
    a multiple of ``min(device_num, 8)`` and only if the lock file does not
    exist yet. Every caller then waits until the lock file exists.
    """
    import moxing as mox
    import time
    global _global_sync_count
    lock_file = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices at most.
    device_id = get_device_id()
    if device_id % min(get_device_num(), 8) == 0 and not os.path.exists(lock_file):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(lock_file)
        except IOError:
            # NOTE(review): a failure to create the lock file is ignored here;
            # if the file never appears, the wait loop below does not terminate.
            pass
        print("===save flag===")

    # Block until the lock file signals that the copy has finished.
    while not os.path.exists(lock_file):
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.

    When ``config.enable_modelarts`` is set, the decorated function is preceded
    by the download of the dataset, the checkpoint and the workspace (each only
    if its url is configured) and followed by the upload of the output
    directory to ``config.train_url``. When ``config.enable_profiling`` is set,
    a ``Profiler`` is created before the call and analysed after it.

    Args:
        pre_process: Optional callable run after the downloads and before the
            decorated function (ModelArts only).
        post_process: Optional callable run after the decorated function and
            before the upload (ModelArts only).

    Returns:
        A decorator. The decorated function returns whatever the original
        function returns.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    # exist_ok: several device processes may reach this point
                    # at the same time and race to create the directory.
                    os.makedirs(config.output_path, exist_ok=True)

                if pre_process:
                    pre_process()

            if config.enable_profiling:
                profiler = Profiler()

            # Keep the result so that the decorator does not hide it from callers.
            result = run_func(*args, **kwargs)

            if config.enable_profiling:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
            return result
        return wrapped_func
    return wrapper
|
|
@ -17,11 +17,11 @@
|
|||
python train.py
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
|
||||
|
@ -30,14 +30,15 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import imagenet_cfg
|
||||
from src.dataset import create_dataset_imagenet
|
||||
from src.tinydarknet import TinyDarkNet
|
||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def lr_steps_imagenet(_cfg, steps_per_epoch):
|
||||
"""lr step for imagenet"""
|
||||
from src.lr_scheduler.warmup_step_lr import warmup_step_lr
|
||||
|
@ -62,54 +63,104 @@ def lr_steps_imagenet(_cfg, steps_per_epoch):
|
|||
|
||||
return _lr
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("Unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Classification')
|
||||
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'],
|
||||
help='dataset name.')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
|
||||
args_opt = parser.parse_args()
|
||||
if config.modelarts_dataset_unzip_name:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
if args_opt.dataset_name == "imagenet":
|
||||
cfg = imagenet_cfg
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
config.ckpt_save_dir = os.path.join(config.output_path, config.ckpt_save_dir)
|
||||
config.train_data_dir = config.data_path
|
||||
config.checkpoint_path = config.load_path
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_train():
|
||||
if config.dataset_name == "imagenet":
|
||||
pass
|
||||
elif config.dataset_name == "cifar10":
|
||||
raise ValueError("Unsupported dataset: 'cifar10'.")
|
||||
else:
|
||||
raise ValueError("Unsupported dataset.")
|
||||
|
||||
# set context
|
||||
device_target = cfg.device_target
|
||||
device_target = config.device_target
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
|
||||
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
device_num = get_device_num()
|
||||
|
||||
rank = 0
|
||||
if device_target == "Ascend":
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
if device_num > 1:
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
rank = get_rank()
|
||||
rank = get_rank_id()
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
if args_opt.dataset_name == "imagenet":
|
||||
dataset = create_dataset_imagenet(cfg.data_path, 1)
|
||||
if config.dataset_name == "imagenet":
|
||||
dataset = create_dataset_imagenet(config.train_data_dir, 1)
|
||||
else:
|
||||
raise ValueError("Unsupported dataset.")
|
||||
|
||||
batch_num = dataset.get_dataset_size()
|
||||
|
||||
net = TinyDarkNet(num_classes=cfg.num_classes)
|
||||
net = TinyDarkNet(num_classes=config.num_classes)
|
||||
# Continue training if set pre_trained to be True
|
||||
if cfg.pre_trained:
|
||||
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||
if config.pre_trained:
|
||||
param_dict = load_checkpoint(config.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
loss_scale_manager = None
|
||||
if args_opt.dataset_name == 'imagenet':
|
||||
lr = lr_steps_imagenet(cfg, batch_num)
|
||||
if config.dataset_name == 'imagenet':
|
||||
lr = lr_steps_imagenet(config, batch_num)
|
||||
|
||||
def get_param_groups(network):
|
||||
""" get param groups """
|
||||
|
@ -132,32 +183,35 @@ if __name__ == '__main__':
|
|||
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
|
||||
|
||||
|
||||
if cfg.is_dynamic_loss_scale:
|
||||
cfg.loss_scale = 1
|
||||
if config.is_dynamic_loss_scale:
|
||||
config.loss_scale = 1
|
||||
|
||||
opt = Momentum(params=get_param_groups(net),
|
||||
learning_rate=Tensor(lr),
|
||||
momentum=cfg.momentum,
|
||||
weight_decay=cfg.weight_decay,
|
||||
loss_scale=cfg.loss_scale)
|
||||
if not cfg.use_label_smooth:
|
||||
cfg.label_smooth_factor = 0.0
|
||||
momentum=config.momentum,
|
||||
weight_decay=config.weight_decay,
|
||||
loss_scale=config.loss_scale)
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
|
||||
smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)
|
||||
|
||||
if cfg.is_dynamic_loss_scale:
|
||||
if config.is_dynamic_loss_scale:
|
||||
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
|
||||
else:
|
||||
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
|
||||
loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
|
||||
amp_level="O3", loss_scale_manager=loss_scale_manager)
|
||||
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 50, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 50, keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
time_cb = TimeMonitor(data_size=batch_num)
|
||||
ckpt_save_dir = "./ckpt_" + str(rank) + "/"
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_tinydarknet_" + args_opt.dataset_name, directory=ckpt_save_dir,
|
||||
ckpt_save_dir = os.path.join(config.ckpt_save_dir, str(rank))
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_tinydarknet_" + config.dataset_name, directory=ckpt_save_dir,
|
||||
config=config_ck)
|
||||
loss_cb = LossMonitor()
|
||||
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
|
||||
model.train(config.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
|
||||
print("train success")
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_train()
|
||||
|
|
|
@ -119,7 +119,7 @@ The dataset is self-generated using a third-party library called [captcha](https
|
|||
|
||||
- running on ModelArts
|
||||
If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/), and you can start training as follows.
|
||||
- 在ModelArt上使用8卡训练
|
||||
- Training with 8 cards on ModelArts
|
||||
|
||||
```python
|
||||
# (1) Upload the code folder to S3 bucket.
|
||||
|
@ -138,10 +138,11 @@ The dataset is self-generated using a third-party library called [captcha](https
|
|||
# (6) Upload the dataset or the zip package of dataset to S3 bucket.
|
||||
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path).
|
||||
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (9) Create your job.
|
||||
# (9) Under the item "resource pool selection", select the specification of 8 cards.
|
||||
# (10) Create your job.
|
||||
```
|
||||
|
||||
- 在ModelArts上使用单卡验证
|
||||
- Evaluating with a single card on ModelArts
|
||||
|
||||
```python
|
||||
# (1) Upload the code folder to S3 bucket.
|
||||
|
@ -160,7 +161,8 @@ The dataset is self-generated using a third-party library called [captcha](https
|
|||
# (6) Upload the dataset or the zip package of dataset to S3 bucket.
|
||||
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path).
|
||||
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (9) Create your job.
|
||||
# (9) Under the item "resource pool selection", select the specification of a single card.
|
||||
# (10) Create your job.
|
||||
```
|
||||
|
||||
## [Script Description](#contents)
|
||||
|
|
|
@ -122,7 +122,7 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
|||
```
|
||||
|
||||
- 在ModelArts上运行
|
||||
如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/))
|
||||
如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/)
|
||||
- 在ModelArts上使用8卡训练
|
||||
|
||||
```python
|
||||
|
@ -142,7 +142,8 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
|||
# (6) 上传你的 数据/数据zip压缩包 到 s3 桶上
|
||||
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包)
|
||||
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
|
||||
# (9) 创建训练作业
|
||||
# (9) 在网页上的‘资源池选择’项目下，选择8卡规格的资源
|
||||
# (10) 创建训练作业
|
||||
```
|
||||
|
||||
- 在ModelArts上使用单卡验证
|
||||
|
@ -164,7 +165,8 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
|||
# (6) 上传你的 数据/数据zip压缩包 到 s3 桶上
|
||||
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包)
|
||||
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
|
||||
# (9) 创建训练作业
|
||||
# (9) 在网页上的‘资源池选择’项目下，选择单卡规格的资源
|
||||
# (10) 创建训练作业
|
||||
```
|
||||
|
||||
## 脚本说明
|
||||
|
|
|
@ -42,15 +42,15 @@ def modelarts_pre_process():
|
|||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
print("Unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
|
|
|
@ -27,7 +27,7 @@ if config.device_target == "Ascend":
|
|||
context.set_context(device_id=get_device_id())
|
||||
|
||||
if config.file_format == "AIR" and config.device_target != "Ascend":
|
||||
raise ValueError("export AIR must on Ascend")
|
||||
raise ValueError("Export AIR must on Ascend")
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_size = m.ceil(config.captcha_height / 64) * 64 * 3
|
||||
|
|
|
@ -49,15 +49,15 @@ def modelarts_pre_process():
|
|||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
print("Unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
|
@ -88,7 +88,7 @@ def modelarts_pre_process():
|
|||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
config.save_checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.save_checkpoint_path)
|
||||
config.save_checkpoint_path = os.path.join(config.output_path, config.save_checkpoint_path)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
|
@ -100,16 +100,22 @@ def train():
|
|||
lr_scale = 1
|
||||
if config.run_distribute:
|
||||
if config.device_target == 'Ascend':
|
||||
device_num = int(os.environ.get("RANK_SIZE"))
|
||||
rank = int(os.environ.get("RANK_ID"))
|
||||
device_num = get_device_num()
|
||||
rank = get_rank_id()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
init()
|
||||
device_num = get_group_size()
|
||||
rank = get_rank()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
|
||||
else:
|
||||
device_num = 1
|
||||
rank = 0
|
||||
|
@ -149,7 +155,7 @@ def train():
|
|||
if config.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_ckpt_path = config.save_checkpoint_path
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, str(rank))
|
||||
ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=save_ckpt_path, config=config_ck)
|
||||
callbacks.append(ckpt_cb)
|
||||
model.train(config.epoch_size, dataset, callbacks=callbacks)
|
||||
|
|
Loading…
Reference in New Issue