forked from mindspore-Ecosystem/mindspore
crnn_seq2seq_ocr merge
This commit is contained in:
parent
18e3180ca4
commit
ebae2fb6a5
|
@ -104,14 +104,60 @@ The dataset is self-generated using a third-party library called [captcha](https
|
|||
# training example on CPU
|
||||
$ bash run_standalone_train.sh ../data/train CPU
|
||||
or
|
||||
python train.py --dataset_path=./data/train --platform=CPU
|
||||
python train.py --train_data_dir=./data/train --device_target=CPU
|
||||
|
||||
# evaluation example on CPU
|
||||
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU
|
||||
or
|
||||
python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU
|
||||
python eval.py --test_data_dir=./data/test --checkpoint_path=warpctc-30-97.ckpt --device_target=CPU
|
||||
```
|
||||
|
||||
- running on ModelArts
|
||||
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows
|
||||
- 在ModelArt上使用8卡训练
|
||||
|
||||
```python
|
||||
# (1) Upload the code folder to S3 bucket.
|
||||
# (2) Click to "create training task" on the website UI interface.
|
||||
# (3) Set the code directory to "/{path}/warpctc" on the website UI interface.
|
||||
# (4) Set the startup file to /{path}/warpctc/train.py" on the website UI interface.
|
||||
# (5) Perform a or b.
|
||||
# a. setting parameters in /{path}/warpctc/default_config.yaml.
|
||||
# 1. Set ”run_distributed=True“
|
||||
# 2. Set ”enable_modelarts=True“
|
||||
# 3. Set ”modelarts_dataset_unzip_name={filenmae}",if the data is uploaded in the form of zip package.
|
||||
# b. adding on the website UI interface.
|
||||
# 1. Add ”run_distributed=True“
|
||||
# 2. Add ”enable_modelarts=True“
|
||||
# 3. Add ”modelarts_dataset_unzip_name={filenmae}",if the data is uploaded in the form of zip package.
|
||||
# (6) Upload the dataset or the zip package of dataset to S3 bucket.
|
||||
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path).
|
||||
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (9) Create your job.
|
||||
```
|
||||
|
||||
- 在ModelArts上使用单卡验证
|
||||
|
||||
```python
|
||||
# (1) Upload the code folder to S3 bucket.
|
||||
# (2) Click to "create training task" on the website UI interface.
|
||||
# (3) Set the code directory to "/{path}/warpctc" on the website UI interface.
|
||||
# (4) Set the startup file to /{path}/warpctc/eval.py" on the website UI interface.
|
||||
# (5) Perform a or b.
|
||||
# a. 在 /path/warpctc 下的default_config.yaml 文件中设置参数
|
||||
# 1. Set ”enable_modelarts=True“
|
||||
# 2. Set “checkpoint_path={checkpoint_path}”({checkpoint_path} Indicates the path of the weight file to be evaluated relative to the file 'eval.py', and the weight file must be included in the code directory.)
|
||||
# 3. Add ”modelarts_dataset_unzip_name={filenmae}",if the data is uploaded in the form of zip package.
|
||||
# b. 在 网页上设置
|
||||
# 1. Set ”enable_modelarts=True“
|
||||
# 2. Set “checkpoint_path={checkpoint_path}”({checkpoint_path} Indicates the path of the weight file to be evaluated relative to the file 'eval.py', and the weight file must be included in the code directory.)
|
||||
# 3. Add ”modelarts_dataset_unzip_name={filenmae}",if the data is uploaded in the form of zip package.
|
||||
# (6) Upload the dataset or the zip package of dataset to S3 bucket.
|
||||
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path).
|
||||
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (9) Create your job.
|
||||
```
|
||||
|
||||
## [Script Description](#contents)
|
||||
|
||||
### [Script and Sample Code](#contents)
|
||||
|
@ -119,7 +165,8 @@ The dataset is self-generated using a third-party library called [captcha](https
|
|||
```shell
|
||||
.
|
||||
└──warpctc
|
||||
├── README.md
|
||||
├── README.md # descriptions of warpctc
|
||||
├── README_CN.md # chinese descriptions of warpctc
|
||||
├── script
|
||||
├── run_distribute_train.sh # launch distributed training in Ascend(8 pcs)
|
||||
├── run_distribute_train_for_gpu.sh # launch distributed training in GPU
|
||||
|
@ -127,13 +174,19 @@ The dataset is self-generated using a third-party library called [captcha](https
|
|||
├── run_process_data.sh # launch dataset generation
|
||||
└── run_standalone_train.sh # launch standalone training(1 pcs)
|
||||
├── src
|
||||
├── config.py # parameter configuration
|
||||
├── model_utils
|
||||
├── config.py # parsing parameter configuration file of "*.yaml"
|
||||
├── devcie_adapter.py # local or ModelArts training
|
||||
├── local_adapter.py # get related environment variables in local training
|
||||
└── moxing_adapter.py # get related environment variables in ModelArts training
|
||||
├── dataset.py # data preprocessing
|
||||
├── loss.py # ctcloss definition
|
||||
├── lr_generator.py # generate learning rate for each step
|
||||
├── metric.py # accuracy metric for warpctc network
|
||||
├── warpctc.py # warpctc network definition
|
||||
└── warpctc_for_train.py # warpctc network with grad, loss and gradient clip
|
||||
├── default_config.yaml # parameter configuration
|
||||
├── export.py # inference
|
||||
├── mindspore_hub_conf.py # mindspore hub interface
|
||||
├── eval.py # eval net
|
||||
├── process_data.py # dataset generation script
|
||||
|
@ -146,13 +199,13 @@ The dataset is self-generated using a third-party library called [captcha](https
|
|||
|
||||
```bash
|
||||
# distributed training in Ascend
|
||||
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
|
||||
|
||||
# distributed training in GPU
|
||||
Usage: bash run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH]
|
||||
Usage: bash run_distribute_train_for_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
|
||||
|
||||
# standalone training
|
||||
Usage: bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
|
||||
Usage: bash run_standalone_train.sh [TRAIN_DATA_DIR] [DEVICE_TARGET]
|
||||
```
|
||||
|
||||
#### Parameters Configuration
|
||||
|
@ -160,18 +213,18 @@ Usage: bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
|
|||
Parameters for both training and evaluation can be set in config.py.
|
||||
|
||||
```bash
|
||||
"max_captcha_digits": 4, # max number of digits in each
|
||||
"captcha_width": 160, # width of captcha images
|
||||
"captcha_height": 64, # height of capthca images
|
||||
"batch_size": 64, # batch size of input tensor
|
||||
"epoch_size": 30, # only valid for taining, which is always 1 for inference
|
||||
"hidden_size": 512, # hidden size in LSTM layers
|
||||
"learning_rate": 0.01, # initial learning rate
|
||||
"momentum": 0.9 # momentum of SGD optimizer
|
||||
"save_checkpoint": True, # whether save checkpoint or not
|
||||
"save_checkpoint_steps": 97, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
|
||||
"keep_checkpoint_max": 30, # only keep the last keep_checkpoint_max checkpoint
|
||||
"save_checkpoint_path": "./checkpoint", # path to save checkpoint
|
||||
max_captcha_digits: 4 # max number of digits in each
|
||||
captcha_width: 160 # width of captcha images
|
||||
captcha_height: 64 # height of capthca images
|
||||
batch_size: 64 # batch size of input tensor
|
||||
epoch_size: 30 # only valid for taining, which is always 1 for inference
|
||||
hidden_size: 512 # hidden size in LSTM layers
|
||||
learning_rate: 0.01 # initial learning rate
|
||||
momentum: 0.9 # momentum of SGD optimizer
|
||||
save_checkpoint: True # whether save checkpoint or not
|
||||
save_checkpoint_steps: 97 # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
|
||||
keep_checkpoint_max: 30 # only keep the last keep_checkpoint_max checkpoint
|
||||
save_checkpoint_path: "./checkpoint" # path to save checkpoint
|
||||
```
|
||||
|
||||
## [Dataset Preparation](#contents)
|
||||
|
@ -180,14 +233,14 @@ Parameters for both training and evaluation can be set in config.py.
|
|||
|
||||
### [Training Process](#contents)
|
||||
|
||||
- Set options in `config.py`, including learning rate and other network hyperparameters. Click [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset.
|
||||
- Set options in `default_config.yaml`, including learning rate and other network hyperparameters. Click [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset.
|
||||
|
||||
#### [Training](#contents)
|
||||
|
||||
- Run `run_standalone_train.sh` for non-distributed training of WarpCTC model, either on Ascend or on GPU.
|
||||
|
||||
``` bash
|
||||
bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
|
||||
bash run_standalone_train.sh [TRAIN_DATA_DIR] [DEVICE_TARGET]
|
||||
```
|
||||
|
||||
##### [Distributed Training](#contents)
|
||||
|
@ -195,13 +248,13 @@ bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
|
|||
- Run `run_distribute_train.sh` for distributed training of WarpCTC model on Ascend.
|
||||
|
||||
``` bash
|
||||
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
|
||||
```
|
||||
|
||||
- Run `run_distribute_train_gpu.sh` for distributed training of WarpCTC model on GPU.
|
||||
|
||||
``` bash
|
||||
bash run_distribute_train_gpu.sh [RANK_SIZE] [DATASET_PATH]
|
||||
bash run_distribute_train_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
|
||||
```
|
||||
|
||||
### [Evaluation Process](#contents)
|
||||
|
@ -211,7 +264,7 @@ bash run_distribute_train_gpu.sh [RANK_SIZE] [DATASET_PATH]
|
|||
- Run `run_eval.sh` for evaluation.
|
||||
|
||||
``` bash
|
||||
bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
|
||||
bash run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DEVICE_TARGET]
|
||||
```
|
||||
|
||||
## [Model Description](#contents)
|
||||
|
|
|
@ -108,14 +108,60 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
|||
# CPU训练示例
|
||||
$ bash run_standalone_train.sh ../data/train CPU
|
||||
或者
|
||||
python train.py --dataset_path=./data/train --platform=CPU
|
||||
python train.py --train_data_dir=./data/train --device_target=CPU
|
||||
|
||||
# CPU评估示例
|
||||
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU
|
||||
或者
|
||||
python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU
|
||||
python eval.py --test_data_dir=./data/test --checkpoint_path=warpctc-30-97.ckpt --device_target=CPU
|
||||
```
|
||||
|
||||
- 在ModelArts上运行
|
||||
如果你想在modelarts上运行,可以参考以下文档 [modelarts](https://support.huaweicloud.com/modelarts/))
|
||||
- 在ModelArt上使用8卡训练
|
||||
|
||||
```python
|
||||
# (1) 上传你的代码到 s3 桶上
|
||||
# (2) 在ModelArts上创建训练任务
|
||||
# (3) 选择代码目录 /{path}/warpctc
|
||||
# (4) 选择启动文件 /{path}/warpctc/train.py
|
||||
# (5) 执行a或b
|
||||
# a. 在 /{path}/warpctc/default_config.yaml 文件中设置参数
|
||||
# 1. 设置 ”run_distributed=True“
|
||||
# 2. 设置 ”enable_modelarts=True“
|
||||
# 3. 如果数据采用zip格式压缩包的形式上传,设置 ”modelarts_dataset_unzip_name={filenmae}"
|
||||
# b. 在 网页上设置
|
||||
# 1. 添加 ”run_distributed=True“
|
||||
# 2. 添加 ”enable_modelarts=True“
|
||||
# 3. 如果数据采用zip格式压缩包的形式上传,添加 ”modelarts_dataset_unzip_name={filenmae}"
|
||||
# (6) 上传你的 数据/数据zip压缩包 到 s3 桶上
|
||||
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包)
|
||||
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
|
||||
# (9) 创建训练作业
|
||||
```
|
||||
|
||||
- 在ModelArts上使用单卡验证
|
||||
|
||||
```python
|
||||
# (1) 上传你的代码到 s3 桶上
|
||||
# (2) 在ModelArts上创建训练任务
|
||||
# (3) 选择代码目录 /{path}/warpctc
|
||||
# (4) 选择启动文件 /{path}/warpctc/eval.py
|
||||
# (5) 执行a或b
|
||||
# a. 在 /path/warpctc 下的default_config.yaml 文件中设置参数
|
||||
# 1. 设置 ”enable_modelarts=True“
|
||||
# 2. 设置 “checkpoint_path={checkpoint_path}”({checkpoint_path}表示待评估的 权重文件 相对于 eval.py 的路径,权重文件须包含在代码目录下。)
|
||||
# 3. 如果数据采用zip格式压缩包的形式上传,设置 ”modelarts_dataset_unzip_name={filenmae}"
|
||||
# b. 在 网页上设置
|
||||
# 1. 设置 ”enable_modelarts=True“
|
||||
# 2. 设置 “checkpoint_path={checkpoint_path}”({checkpoint_path}表示待评估的 权重文件 相对于 eval.py 的路径,权重文件须包含在代码目录下。)
|
||||
# 3. 如果数据采用zip格式压缩包的形式上传,设置 ”modelarts_dataset_unzip_name={filenmae}"
|
||||
# (6) 上传你的 数据/数据zip压缩包 到 s3 桶上
|
||||
# (7) 在网页上勾选数据存储位置,设置“训练数据集”路径(该路径下仅有 数据/数据zip压缩包)
|
||||
# (8) 在网页上设置“训练输出文件路径”、“作业日志路径”
|
||||
# (9) 创建训练作业
|
||||
```
|
||||
|
||||
## 脚本说明
|
||||
|
||||
### 脚本及样例代码
|
||||
|
@ -123,7 +169,8 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
|||
```text
|
||||
.
|
||||
└──warpctc
|
||||
├── README.md
|
||||
├── README.md # warpctc文档说明
|
||||
├── README_CN.md # warpctc中文文档说明
|
||||
├── script
|
||||
├── run_distribute_train.sh # 启动Ascend分布式训练(8卡)
|
||||
├── run_distribute_train_for_gpu.sh # 启动GPU分布式训练
|
||||
|
@ -131,13 +178,19 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
|||
├── run_process_data.sh # 启动数据集生成
|
||||
└── run_standalone_train.sh # 启动单机训练(1卡)
|
||||
├── src
|
||||
├── config.py # 参数配置
|
||||
├── model_utils
|
||||
├── config.py # 解析 *.yaml参数配置文件
|
||||
├── devcie_adapter.py # 区分本地/ModelArts训练
|
||||
├── local_adapter.py # 本地训练获取相关环境变量
|
||||
└── moxing_adapter.py # ModelArts训练获取相关环境变量、交换数据
|
||||
├── dataset.py # 数据预处理
|
||||
├── loss.py # CTC损失定义
|
||||
├── lr_generator.py # 生成每个步骤的学习率
|
||||
├── metric.py # warpctc网络准确指标
|
||||
├── warpctc.py # warpctc网络定义
|
||||
└── warpctc_for_train.py # 带梯度、损失和梯度剪裁的warpctc网络
|
||||
├── default_config.yaml # 参数配置
|
||||
├── export.py # 推理
|
||||
├── mindspore_hub_conf.py # Mindspore Hub接口
|
||||
├── eval.py # 评估网络
|
||||
├── process_data.py # 数据集生成脚本
|
||||
|
@ -150,32 +203,32 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
|||
|
||||
```bash
|
||||
# Ascend分布式训练
|
||||
用法: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
用法: bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
|
||||
|
||||
# GPU分布式训练
|
||||
用法: bash run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH]
|
||||
用法: bash run_distribute_train_for_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
|
||||
|
||||
# 单机训练
|
||||
用法: bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
|
||||
用法: bash run_standalone_train.sh [TRAIN_DATA_DIR] [DEVICE_TARGET]
|
||||
```
|
||||
|
||||
### 参数配置
|
||||
|
||||
在config.py中可以同时配置训练参数和评估参数。
|
||||
在default_config.yaml中可以同时配置训练参数和评估参数。
|
||||
|
||||
```text
|
||||
"max_captcha_digits": 4, # 每张图像的数字个数上限。
|
||||
"captcha_width": 160, # captcha图片宽度。
|
||||
"captcha_height": 64, # capthca图片高度。
|
||||
"batch_size": 64, # 输入张量批次大小。
|
||||
"epoch_size": 30, # 只对训练有效,推理固定值为1。
|
||||
"hidden_size": 512, # LSTM层隐藏大小。
|
||||
"learning_rate": 0.01, # 初始学习率。
|
||||
"momentum": 0.9 # SGD优化器动量。
|
||||
"save_checkpoint": True, # 是否保存检查点。
|
||||
"save_checkpoint_steps": 97, # 两个检查点之间的迭代间隙。默认情况下,最后一个检查点将在最后一步迭代结束后保存。
|
||||
"keep_checkpoint_max": 30, # 只保留最后一个keep_checkpoint_max检查点。
|
||||
"save_checkpoint_path": "./checkpoint", # 检查点保存路径。
|
||||
max_captcha_digits: 4 # 每张图像的数字个数上限。
|
||||
captcha_width: 160 # captcha图片宽度。
|
||||
captcha_height: 64 # capthca图片高度。
|
||||
batch_size: 64 # 输入张量批次大小。
|
||||
epoch_size: 30 # 只对训练有效,推理固定值为1。
|
||||
hidden_size: 512 # LSTM层隐藏大小。
|
||||
learning_rate: 0.01 # 初始学习率。
|
||||
momentum: 0.9 # SGD优化器动量。
|
||||
save_checkpoint: True # 是否保存检查点。
|
||||
save_checkpoint_steps: 97 # 两个检查点之间的迭代间隙。默认情况下,最后一个检查点将在最后一步迭代结束后保存。
|
||||
keep_checkpoint_max: 30 # 只保留最后一个keep_checkpoint_max检查点。
|
||||
save_checkpoint_path: "./checkpoints" # 检查点保存路径,相对于train.py。
|
||||
```
|
||||
|
||||
## 数据集准备
|
||||
|
@ -184,14 +237,14 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
|||
|
||||
## 训练过程
|
||||
|
||||
- 在`config.py`中设置选项,包括学习率和网络超参数。单击[MindSpore加载数据集教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html),了解更多信息。
|
||||
- 在`default_config.yaml`中设置选项,包括学习率和网络超参数。单击[MindSpore加载数据集教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html),了解更多信息。
|
||||
|
||||
### 训练
|
||||
|
||||
- 在Ascend或GPU上运行`run_standalone_train.sh`进行WarpCTC模型的非分布式训练。
|
||||
|
||||
``` bash
|
||||
bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
|
||||
bash run_standalone_train.sh [TRAIN_DATA_DIR] [DEVICE_TARGET]
|
||||
```
|
||||
|
||||
### 分布式训练
|
||||
|
@ -199,13 +252,13 @@ bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
|
|||
- 在Ascend上运行`run_distribute_train.sh`进行WarpCTC模型的分布式训练。
|
||||
|
||||
``` bash
|
||||
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
|
||||
```
|
||||
|
||||
- 在GPU上运行`run_distribute_train_gpu.sh`进行WarpCTC模型的分布式训练。
|
||||
|
||||
``` bash
|
||||
bash run_distribute_train_gpu.sh [RANK_SIZE] [DATASET_PATH]
|
||||
bash run_distribute_train_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
|
||||
```
|
||||
|
||||
## 评估过程
|
||||
|
@ -215,7 +268,7 @@ bash run_distribute_train_gpu.sh [RANK_SIZE] [DATASET_PATH]
|
|||
- 运行`run_eval.sh`进行评估。
|
||||
|
||||
``` bash
|
||||
bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
|
||||
bash run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DEVICE_TARGET]
|
||||
```
|
||||
|
||||
## 模型描述
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
modelarts_dataset_unzip_name: ''
|
||||
# ==============================================================================
|
||||
#train-related
|
||||
run_distribute: False
|
||||
train_data_dir: None
|
||||
|
||||
max_captcha_digits: 4
|
||||
captcha_width: 160
|
||||
captcha_height: 64
|
||||
batch_size: 64
|
||||
epoch_size: 30
|
||||
hidden_size: 512
|
||||
learning_rate: 0.01
|
||||
momentum: 0.9
|
||||
save_checkpoint: True
|
||||
save_checkpoint_steps: 97
|
||||
keep_checkpoint_max: 30
|
||||
save_checkpoint_path: "./checkpoints"
|
||||
#eval-related
|
||||
test_data_dir: None
|
||||
checkpoint_path: None
|
||||
#export-related
|
||||
file_name: "warpctc"
|
||||
ckpt_file: ""
|
||||
file_format: "MINDIR"
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
|
||||
run_distribute: "Run distribute, default is false."
|
||||
train_data_dir: "tran Dataset path, default is None"
|
||||
|
||||
test_data_dir: "test Dataset path, default is None."
|
||||
checkpoint_path: "checkpoint file path, default is None"
|
||||
|
||||
file_name: "warpctc output file name, default: warpctc"
|
||||
ckpt_file: "required, warpctc ckpt file."
|
||||
file_format: "file format, choose from AIR, MINDIR, and default is MINDIR"
|
||||
|
||||
|
||||
|
|
@ -14,57 +14,109 @@
|
|||
# ============================================================================
|
||||
"""Warpctc evaluation"""
|
||||
import os
|
||||
import time
|
||||
import math as m
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.loss import CTCLoss
|
||||
from src.config import config as cf
|
||||
from src.dataset import create_dataset
|
||||
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
|
||||
from src.metric import WarpCTCAccuracy
|
||||
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
set_seed(1)
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
|
||||
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.')
|
||||
args_opt = parser.parse_args()
|
||||
if config.modelarts_dataset_unzip_name:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
||||
if args_opt.platform == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_eval():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
|
||||
if config.device_target == 'Ascend':
|
||||
context.set_context(device_id=get_device_id())
|
||||
max_captcha_digits = config.max_captcha_digits
|
||||
input_size = m.ceil(config.captcha_height / 64) * 64 * 3
|
||||
# create dataset
|
||||
if config.enable_modelarts:
|
||||
dataset_dir = config.data_path
|
||||
else:
|
||||
dataset_dir = config.test_data_dir
|
||||
dataset = create_dataset(dataset_path=dataset_dir,
|
||||
batch_size=config.batch_size,
|
||||
device_target=config.device_target)
|
||||
# step_size = dataset.get_dataset_size()
|
||||
loss = CTCLoss(max_sequence_length=config.captcha_width,
|
||||
max_label_length=max_captcha_digits,
|
||||
batch_size=config.batch_size)
|
||||
if config.device_target == 'Ascend':
|
||||
net = StackedRNN(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
|
||||
elif config.device_target == 'GPU':
|
||||
net = StackedRNNForGPU(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
|
||||
else:
|
||||
net = StackedRNNForCPU(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
|
||||
|
||||
# load checkpoint
|
||||
checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.checkpoint_path)
|
||||
param_dict = load_checkpoint(checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy(config.device_target)})
|
||||
# start evaluation
|
||||
res = model.eval(dataset, dataset_sink_mode=config.device_target == 'Ascend')
|
||||
print("result:", res, flush=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
max_captcha_digits = cf.max_captcha_digits
|
||||
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||
batch_size=cf.batch_size,
|
||||
device_target=args_opt.platform)
|
||||
step_size = dataset.get_dataset_size()
|
||||
loss = CTCLoss(max_sequence_length=cf.captcha_width,
|
||||
max_label_length=max_captcha_digits,
|
||||
batch_size=cf.batch_size)
|
||||
if args_opt.platform == 'Ascend':
|
||||
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
elif args_opt.platform == 'GPU':
|
||||
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
else:
|
||||
net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy(args_opt.platform)})
|
||||
# start evaluation
|
||||
res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
|
||||
print("result:", res, flush=True)
|
||||
run_eval()
|
||||
|
|
|
@ -13,29 +13,20 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air models"""
|
||||
import argparse
|
||||
import math as m
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
|
||||
from src.config import config
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
parser = argparse.ArgumentParser(description="warpctc_export")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="warpctc ckpt file.")
|
||||
parser.add_argument("--file_name", type=str, default="warpctc", help="warpctc output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
args = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if args.file_format == "AIR" and args.device_target != "Ascend":
|
||||
if config.file_format == "AIR" and config.device_target != "Ascend":
|
||||
raise ValueError("export AIR must on Ascend")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -45,14 +36,14 @@ if __name__ == "__main__":
|
|||
batch_size = config.batch_size
|
||||
hidden_size = config.hidden_size
|
||||
image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float32))
|
||||
if args.device_target == 'Ascend':
|
||||
if config.device_target == 'Ascend':
|
||||
net = StackedRNN(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
|
||||
image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float16))
|
||||
elif args.device_target == 'GPU':
|
||||
elif config.device_target == 'GPU':
|
||||
net = StackedRNNForGPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
|
||||
else:
|
||||
net = StackedRNNForCPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
param_dict = load_checkpoint(config.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
export(net, image, file_name=args.file_name, file_format=args.file_format)
|
||||
export(net, image, file_name=config.file_name, file_format=config.file_format)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -36,7 +36,7 @@ if [ ! -f $PATH1 ]; then
|
|||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]; then
|
||||
echo "error: DATASET_PATH=$PATH2 is not a directory"
|
||||
echo "error: TRAIN_DATA_DIR=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -51,11 +51,12 @@ for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
|||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp ../*.yaml ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute > log.txt 2>&1 &
|
||||
python train.py --device_target=Ascend --train_data_dir=$PATH2 --run_distribute True > log.txt 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_SIZE] [DATASET_PATH]"
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_SIZE] [TRAIN_DATA_DIR]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -28,10 +28,10 @@ get_real_path() {
|
|||
}
|
||||
|
||||
RANK_SIZE=$1
|
||||
DATASET_PATH=$(get_real_path $2)
|
||||
TRAIN_DATA_DIR=$(get_real_path $2)
|
||||
|
||||
if [ ! -d $DATASET_PATH ]; then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||
if [ ! -d $TRAIN_DATA_DIR ]; then
|
||||
echo "error: TRAIN_DATA_DIR=$TRAIN_DATA_DIR is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -41,12 +41,13 @@ fi
|
|||
|
||||
mkdir ./distribute_train
|
||||
cp ../*.py ./distribute_train
|
||||
cp ../*.yaml ./distribute_train
|
||||
cp -r ../src ./distribute_train
|
||||
cd ./distribute_train || exit
|
||||
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||
python train.py \
|
||||
--dataset_path=$DATASET_PATH \
|
||||
--platform=GPU \
|
||||
--run_distribute > log.txt 2>&1 &
|
||||
--train_data_dir=$TRAIN_DATA_DIR \
|
||||
--device_target=GPU \
|
||||
--run_distribute=True > log.txt 2>&1 &
|
||||
cd ..
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
|
||||
echo "Usage: sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DEVICE_TARGET]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -29,10 +29,10 @@ get_real_path() {
|
|||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
PLATFORM=$3
|
||||
DEVICE_TARGET=$3
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -53,11 +53,12 @@ run_ascend() {
|
|||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp ../*.yaml ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
|
||||
python eval.py --test_data_dir=$1 --checkpoint_path=$2 --device_target=Ascend > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
|
@ -67,16 +68,17 @@ run_gpu_cpu() {
|
|||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp ../*.yaml ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=$3 > log.txt 2>&1 &
|
||||
python eval.py --test_data_dir=$1 --checkpoint_path=$2 --device_target=$3 > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
if [ "Ascend" == $PLATFORM ]; then
|
||||
if [ "Ascend" == $DEVICE_TARGET ]; then
|
||||
run_ascend $PATH1 $PATH2
|
||||
else
|
||||
run_gpu_cpu $PATH1 $PATH2 $PLATFORM
|
||||
run_gpu_cpu $PATH1 $PATH2 $DEVICE_TARGET
|
||||
fi
|
||||
|
||||
|
|
|
@ -18,3 +18,4 @@ CUR_PATH=$(dirname $PWD/$0)
|
|||
cd $CUR_PATH/../ &&
|
||||
python process_data.py &&
|
||||
cd - &> /dev/null || exit
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]"
|
||||
echo "Usage: sh run_standalone_train.sh [TRAIN_DATA_DIR] [DEVICE_TARGET]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -28,10 +28,10 @@ get_real_path() {
|
|||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PLATFORM=$2
|
||||
DEVICE_TARGET=$2
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -44,13 +44,13 @@ run_ascend() {
|
|||
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
|
||||
python train.py --train_data_dir=$1 --device_target=Ascend > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
run_gpu_cpu() {
|
||||
env >env.log
|
||||
python train.py --dataset_path=$1 --platform=$2 > log.txt 2>&1 &
|
||||
python train.py --train_data_dir=$1 --device_target=$2 > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
|
@ -59,11 +59,12 @@ if [ -d "train" ]; then
|
|||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp ../*.yaml ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
|
||||
if [ "Ascend" == $PLATFORM ]; then
|
||||
if [ "Ascend" == $DEVICE_TARGET ]; then
|
||||
run_ascend $PATH1
|
||||
else
|
||||
run_gpu_cpu $PATH1 $PLATFORM
|
||||
run_gpu_cpu $PATH1 $DEVICE_TARGET
|
||||
fi
|
|
@ -21,7 +21,7 @@ import mindspore.common.dtype as mstype
|
|||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as c
|
||||
import mindspore.dataset.vision.c_transforms as vc
|
||||
from src.config import config as cf
|
||||
from src.model_utils.config import config
|
||||
|
||||
|
||||
class _CaptchaDataset:
|
||||
|
@ -79,18 +79,18 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_
|
|||
device_target(str): platform of training, support Ascend and GPU
|
||||
"""
|
||||
|
||||
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target)
|
||||
dataset = _CaptchaDataset(dataset_path, config.max_captcha_digits, device_target)
|
||||
data_set = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
|
||||
image_trans = [
|
||||
vc.Rescale(1.0 / 255.0, 0.0),
|
||||
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
|
||||
vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)),
|
||||
vc.Resize((m.ceil(config.captcha_height / 16) * 16, config.captcha_width)),
|
||||
c.TypeCast(mstype.float16)
|
||||
]
|
||||
image_trans_gpu = [
|
||||
vc.Rescale(1.0 / 255.0, 0.0),
|
||||
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
|
||||
vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)),
|
||||
vc.Resize((m.ceil(config.captcha_height / 16) * 16, config.captcha_width)),
|
||||
vc.HWC2CHW()
|
||||
]
|
||||
label_trans = [
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(config)
|
30
model_zoo/official/cv/warpctc/src/config.py → model_zoo/official/cv/warpctc/src/model_utils/device_adapter.py
Executable file → Normal file
30
model_zoo/official/cv/warpctc/src/config.py → model_zoo/official/cv/warpctc/src/model_utils/device_adapter.py
Executable file → Normal file
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,20 +12,16 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Network parameters."""
|
||||
from easydict import EasyDict
|
||||
|
||||
config = EasyDict({
|
||||
"max_captcha_digits": 4,
|
||||
"captcha_width": 160,
|
||||
"captcha_height": 64,
|
||||
"batch_size": 64,
|
||||
"epoch_size": 30,
|
||||
"hidden_size": 512,
|
||||
"learning_rate": 0.01,
|
||||
"momentum": 0.9,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_steps": 97,
|
||||
"keep_checkpoint_max": 30,
|
||||
"save_checkpoint_path": "./",
|
||||
})
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from src.model_utils.config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from mindspore.profiler import Profiler
|
||||
from src.model_utils.config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
# print("os.mknod({}) success".format(sync_lock))
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler = Profiler()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler.analyse()
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -14,8 +14,8 @@
|
|||
# ============================================================================
|
||||
"""Warpctc training"""
|
||||
import os
|
||||
import time
|
||||
import math as m
|
||||
import argparse
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
|
@ -26,33 +26,77 @@ from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig,
|
|||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
|
||||
from src.loss import CTCLoss
|
||||
from src.config import config as cf
|
||||
from src.dataset import create_dataset
|
||||
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
|
||||
from src.warpctc_for_train import TrainOneStepCellWithGradClip
|
||||
from src.lr_schedule import get_lr
|
||||
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
|
||||
|
||||
set_seed(1)
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
|
||||
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.')
|
||||
parser.set_defaults(run_distribute=False)
|
||||
args_opt = parser.parse_args()
|
||||
if config.modelarts_dataset_unzip_name:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
|
||||
if args_opt.platform == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
config.save_checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.save_checkpoint_path)
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train():
|
||||
"""Train function."""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == 'Ascend':
|
||||
context.set_context(device_id=get_device_id())
|
||||
lr_scale = 1
|
||||
if args_opt.run_distribute:
|
||||
init()
|
||||
if args_opt.platform == 'Ascend':
|
||||
if config.run_distribute:
|
||||
if config.device_target == 'Ascend':
|
||||
device_num = int(os.environ.get("RANK_SIZE"))
|
||||
rank = int(os.environ.get("RANK_ID"))
|
||||
else:
|
||||
|
@ -62,29 +106,34 @@ if __name__ == '__main__':
|
|||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
device_num = 1
|
||||
rank = 0
|
||||
|
||||
max_captcha_digits = cf.max_captcha_digits
|
||||
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
|
||||
max_captcha_digits = config.max_captcha_digits
|
||||
input_size = m.ceil(config.captcha_height / 64) * 64 * 3
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, batch_size=cf.batch_size,
|
||||
num_shards=device_num, shard_id=rank, device_target=args_opt.platform)
|
||||
if config.enable_modelarts:
|
||||
dataset_dir = config.data_path
|
||||
else:
|
||||
dataset_dir = config.train_data_dir
|
||||
dataset = create_dataset(dataset_path=dataset_dir, batch_size=config.batch_size,
|
||||
num_shards=device_num, shard_id=rank, device_target=config.device_target)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# define lr
|
||||
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale
|
||||
lr = get_lr(cf.epoch_size, step_size, lr_init)
|
||||
loss = CTCLoss(max_sequence_length=cf.captcha_width,
|
||||
lr_init = config.learning_rate if not config.run_distribute else config.learning_rate * device_num * lr_scale
|
||||
lr = get_lr(config.epoch_size, step_size, lr_init)
|
||||
loss = CTCLoss(max_sequence_length=config.captcha_width,
|
||||
max_label_length=max_captcha_digits,
|
||||
batch_size=cf.batch_size)
|
||||
if args_opt.platform == 'Ascend':
|
||||
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
elif args_opt.platform == 'GPU':
|
||||
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
batch_size=config.batch_size)
|
||||
if config.device_target == 'Ascend':
|
||||
net = StackedRNN(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
|
||||
elif config.device_target == 'GPU':
|
||||
net = StackedRNNForGPU(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
|
||||
else:
|
||||
net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
|
||||
net = StackedRNNForCPU(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
|
||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum)
|
||||
|
||||
net = WithLossCell(net, loss)
|
||||
net = TrainOneStepCellWithGradClip(net, opt).set_train()
|
||||
|
@ -92,10 +141,14 @@ if __name__ == '__main__':
|
|||
model = Model(net)
|
||||
# define callbacks
|
||||
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
|
||||
if cf.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps,
|
||||
keep_checkpoint_max=cf.keep_checkpoint_max)
|
||||
save_ckpt_path = os.path.join(cf.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
||||
if config.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_ckpt_path = config.save_checkpoint_path
|
||||
ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=save_ckpt_path, config=config_ck)
|
||||
callbacks.append(ckpt_cb)
|
||||
model.train(cf.epoch_size, dataset, callbacks=callbacks)
|
||||
model.train(config.epoch_size, dataset, callbacks=callbacks)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
|
|
Loading…
Reference in New Issue