crnn_seq2seq_ocr merge

zhengbin 2021-05-22 12:15:07 +08:00
parent 18e3180ca4
commit ebae2fb6a5
16 changed files with 742 additions and 187 deletions

View File

@ -104,14 +104,60 @@ The dataset is self-generated using a third-party library called [captcha](https
# training example on CPU
$ bash run_standalone_train.sh ../data/train CPU
or
python train.py --dataset_path=./data/train --platform=CPU
python train.py --train_data_dir=./data/train --device_target=CPU
# evaluation example on CPU
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU
or
python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU
python eval.py --test_data_dir=./data/test --checkpoint_path=warpctc-30-97.ckpt --device_target=CPU
```
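The `--train_data_dir` and `--device_target` flags above replace the old `--dataset_path` and `--platform` flags. The sketch below shows one way command-line flags can override defaults that would normally come from `default_config.yaml`. It assumes a plain dictionary stands in for the parsed YAML; `DEFAULTS`, `build_parser` and the sample argument list are illustrative names, not part of the repository.

```python
import argparse
import ast

# Stand-in for values parsed from default_config.yaml (assumption: a flat dict).
DEFAULTS = {
    "device_target": "Ascend",
    "run_distribute": False,
    "batch_size": 64,
    "learning_rate": 0.01,
}


def build_parser(defaults):
    """Create one --flag per default value (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="override sketch")
    for key, value in defaults.items():
        # bool("False") is True, so booleans are parsed as Python literals.
        arg_type = ast.literal_eval if isinstance(value, bool) else type(value)
        parser.add_argument("--" + key, type=arg_type, default=value)
    return parser


# Both "--flag=value" and "--flag value" are accepted by argparse.
args = build_parser(DEFAULTS).parse_args(
    ["--device_target=CPU", "--run_distribute", "True"]
)
print(args.device_target, args.run_distribute, args.batch_size)
```

Flags that are not passed keep their default, so only the values that differ from the YAML file need to appear on the command line.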
- Running on ModelArts
If you want to run on ModelArts, please check the official [ModelArts](https://support.huaweicloud.com/modelarts/) documentation. You can then start training as follows.
- Training with 8 cards on ModelArts
```python
# (1) Upload the code folder to an S3 bucket.
# (2) Click "create training task" on the website UI.
# (3) Set the code directory to "/{path}/warpctc" on the website UI.
# (4) Set the startup file to "/{path}/warpctc/train.py" on the website UI.
# (5) Perform a or b.
#     a. Set parameters in /{path}/warpctc/default_config.yaml.
#         1. Set "run_distribute=True"
#         2. Set "enable_modelarts=True"
#         3. Set "modelarts_dataset_unzip_name={filename}" if the data is uploaded as a zip package.
#     b. Add parameters on the website UI.
#         1. Add "run_distribute=True"
#         2. Add "enable_modelarts=True"
#         3. Add "modelarts_dataset_unzip_name={filename}" if the data is uploaded as a zip package.
# (6) Upload the dataset or its zip package to the S3 bucket.
# (7) Check the "data storage location" on the website UI and set the "Dataset path" (this path must contain only the data or the zip package).
# (8) Set the "Output file path" and "Job log path" to your paths on the website UI.
# (9) Create your job.
```
- Evaluating with a single card on ModelArts
```python
# (1) Upload the code folder to an S3 bucket.
# (2) Click "create training task" on the website UI.
# (3) Set the code directory to "/{path}/warpctc" on the website UI.
# (4) Set the startup file to "/{path}/warpctc/eval.py" on the website UI.
# (5) Perform a or b.
#     a. Set parameters in /{path}/warpctc/default_config.yaml.
#         1. Set "enable_modelarts=True"
#         2. Set "checkpoint_path={checkpoint_path}" ({checkpoint_path} is the path of the weight file to evaluate, relative to eval.py; the weight file must be inside the code directory.)
#         3. Set "modelarts_dataset_unzip_name={filename}" if the data is uploaded as a zip package.
#     b. Add parameters on the website UI.
#         1. Add "enable_modelarts=True"
#         2. Add "checkpoint_path={checkpoint_path}" ({checkpoint_path} is the path of the weight file to evaluate, relative to eval.py; the weight file must be inside the code directory.)
#         3. Add "modelarts_dataset_unzip_name={filename}" if the data is uploaded as a zip package.
# (6) Upload the dataset or its zip package to the S3 bucket.
# (7) Check the "data storage location" on the website UI and set the "Dataset path" (this path must contain only the data or the zip package).
# (8) Set the "Output file path" and "Job log path" to your paths on the website UI.
# (9) Create your job.
```
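When `modelarts_dataset_unzip_name` is set, the pre-process step in this commit looks for `{data_path}/{name}.zip` and extracts it only if a folder called `{name}` does not exist yet. This suggests the archive should unpack into a top-level folder with that name. A minimal sketch of the check, using a temporary directory in place of `/cache/data`; `extract_once` and the sample file name are hypothetical:

```python
import os
import tempfile
import zipfile


def extract_once(data_path, unzip_name):
    """Extract {data_path}/{unzip_name}.zip unless it is already unpacked."""
    zip_file = os.path.join(data_path, unzip_name + ".zip")
    if os.path.exists(os.path.join(data_path, unzip_name)):
        return "already extracted"
    if not zipfile.is_zipfile(zip_file):
        return "not a zip"
    with zipfile.ZipFile(zip_file, "r") as fz:
        fz.extractall(data_path)
    return "extracted"


with tempfile.TemporaryDirectory() as data_path:
    # Build a tiny archive whose top-level folder matches the unzip name.
    with zipfile.ZipFile(os.path.join(data_path, "train.zip"), "w") as fz:
        fz.writestr("train/0001-1234.png", b"fake image bytes")
    first = extract_once(data_path, "train")
    second = extract_once(data_path, "train")
    print(first, second)
```

The second call returns early because the `train` folder already exists, which is what keeps repeated jobs from extracting the same data again.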
## [Script Description](#contents)
### [Script and Sample Code](#contents)
@ -119,7 +165,8 @@ The dataset is self-generated using a third-party library called [captcha](https
```shell
.
└──warpctc
├── README.md
├── README.md # descriptions of warpctc
├── README_CN.md # Chinese descriptions of warpctc
├── script
├── run_distribute_train.sh # launch distributed training in Ascend(8 pcs)
├── run_distribute_train_for_gpu.sh # launch distributed training in GPU
@ -127,13 +174,19 @@ The dataset is self-generated using a third-party library called [captcha](https
├── run_process_data.sh # launch dataset generation
└── run_standalone_train.sh # launch standalone training(1 pcs)
├── src
├── config.py # parameter configuration
├── model_utils
├── config.py # parsing parameter configuration file of "*.yaml"
├── device_adapter.py # local or ModelArts training
├── local_adapter.py # get related environment variables in local training
└── moxing_adapter.py # get related environment variables in ModelArts training
├── dataset.py # data preprocessing
├── loss.py # ctcloss definition
├── lr_generator.py # generate learning rate for each step
├── metric.py # accuracy metric for warpctc network
├── warpctc.py # warpctc network definition
└── warpctc_for_train.py # warpctc network with grad, loss and gradient clip
├── default_config.yaml # parameter configuration
├── export.py # inference
├── mindspore_hub_conf.py # mindspore hub interface
├── eval.py # eval net
├── process_data.py # dataset generation script
@ -146,13 +199,13 @@ The dataset is self-generated using a third-party library called [captcha](https
```bash
# distributed training in Ascend
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
# distributed training in GPU
Usage: bash run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH]
Usage: bash run_distribute_train_for_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
# standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
Usage: bash run_standalone_train.sh [TRAIN_DATA_DIR] [DEVICE_TARGET]
```
#### Parameters Configuration
@ -160,18 +213,18 @@ Usage: bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
Parameters for both training and evaluation can be set in default_config.yaml.
```bash
"max_captcha_digits": 4, # max number of digits in each captcha
"captcha_width": 160, # width of captcha images
"captcha_height": 64, # height of captcha images
"batch_size": 64, # batch size of input tensor
"epoch_size": 30, # only valid for training; always 1 for inference
"hidden_size": 512, # hidden size in LSTM layers
"learning_rate": 0.01, # initial learning rate
"momentum": 0.9 # momentum of SGD optimizer
"save_checkpoint": True, # whether to save checkpoints
"save_checkpoint_steps": 97, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
"keep_checkpoint_max": 30, # keep only the last keep_checkpoint_max checkpoints
"save_checkpoint_path": "./checkpoint", # path to save checkpoints
max_captcha_digits: 4 # max number of digits in each captcha
captcha_width: 160 # width of captcha images
captcha_height: 64 # height of captcha images
batch_size: 64 # batch size of input tensor
epoch_size: 30 # only valid for training; always 1 for inference
hidden_size: 512 # hidden size in LSTM layers
learning_rate: 0.01 # initial learning rate
momentum: 0.9 # momentum of SGD optimizer
save_checkpoint: True # whether to save checkpoints
save_checkpoint_steps: 97 # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
keep_checkpoint_max: 30 # keep only the last keep_checkpoint_max checkpoints
save_checkpoint_path: "./checkpoint" # path to save checkpoints
```
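The interplay of `epoch_size`, `save_checkpoint_steps` and `keep_checkpoint_max` can be worked through with a small calculation. The sketch assumes one epoch is 97 steps and that checkpoints are named `warpctc-{epoch}-{step}.ckpt`; both assumptions are inferred from the example file `warpctc-30-97.ckpt` and are not stated by the configuration itself.

```python
from collections import deque

epoch_size = 30
steps_per_epoch = 97          # assumption, inferred from "warpctc-30-97.ckpt"
save_checkpoint_steps = 97
keep_checkpoint_max = 30

# A bounded deque drops the oldest entry once the limit is reached.
kept = deque(maxlen=keep_checkpoint_max)
for step in range(1, epoch_size * steps_per_epoch + 1):
    if step % save_checkpoint_steps == 0:
        epoch = (step - 1) // steps_per_epoch + 1
        step_in_epoch = (step - 1) % steps_per_epoch + 1
        kept.append("warpctc-{}-{}.ckpt".format(epoch, step_in_epoch))

print(len(kept), kept[0], kept[-1])
```

Under these assumptions one checkpoint is written per epoch and all 30 are retained; a smaller `keep_checkpoint_max` would leave only the most recent ones.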
## [Dataset Preparation](#contents)
@ -180,14 +233,14 @@ Parameters for both training and evaluation can be set in config.py.
### [Training Process](#contents)
- Set options in `config.py`, including the learning rate and other network hyperparameters. See the [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about datasets.
- Set options in `default_config.yaml`, including the learning rate and other network hyperparameters. See the [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about datasets.
#### [Training](#contents)
- Run `run_standalone_train.sh` for non-distributed training of WarpCTC model, either on Ascend or on GPU.
``` bash
bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
bash run_standalone_train.sh [TRAIN_DATA_DIR] [DEVICE_TARGET]
```
##### [Distributed Training](#contents)
@ -195,13 +248,13 @@ bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
- Run `run_distribute_train.sh` for distributed training of WarpCTC model on Ascend.
``` bash
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```
- Run `run_distribute_train_gpu.sh` for distributed training of WarpCTC model on GPU.
``` bash
bash run_distribute_train_gpu.sh [RANK_SIZE] [DATASET_PATH]
bash run_distribute_train_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
```
### [Evaluation Process](#contents)
@ -211,7 +264,7 @@ bash run_distribute_train_gpu.sh [RANK_SIZE] [DATASET_PATH]
- Run `run_eval.sh` for evaluation.
``` bash
bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
bash run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DEVICE_TARGET]
```
## [Model Description](#contents)

View File

@ -108,14 +108,60 @@ WarpCTC is a two-layer stacked LSTM model with one FC layer. For details, please
# training example on CPU
$ bash run_standalone_train.sh ../data/train CPU
or
python train.py --dataset_path=./data/train --platform=CPU
python train.py --train_data_dir=./data/train --device_target=CPU
# evaluation example on CPU
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU
or
python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU
python eval.py --test_data_dir=./data/test --checkpoint_path=warpctc-30-97.ckpt --device_target=CPU
```
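- Running on ModelArts
If you want to run on ModelArts, please refer to the [ModelArts](https://support.huaweicloud.com/modelarts/) documentation.
- Training with 8 cards on ModelArts
```python
# (1) Upload your code to an S3 bucket.
# (2) Create a training task on ModelArts.
# (3) Select the code directory /{path}/warpctc.
# (4) Select the startup file /{path}/warpctc/train.py.
# (5) Perform a or b.
#     a. Set parameters in /{path}/warpctc/default_config.yaml.
#         1. Set "run_distribute=True"
#         2. Set "enable_modelarts=True"
#         3. Set "modelarts_dataset_unzip_name={filename}" if the data is uploaded as a zip package.
#     b. Set parameters on the web page.
#         1. Add "run_distribute=True"
#         2. Add "enable_modelarts=True"
#         3. Add "modelarts_dataset_unzip_name={filename}" if the data is uploaded as a zip package.
# (6) Upload your data or data zip package to the S3 bucket.
# (7) On the web page, check the data storage location and set the "training dataset" path (this path must contain only the data or the data zip package).
# (8) On the web page, set the "training output file path" and "job log path".
# (9) Create the training job.
```
- Evaluating with a single card on ModelArts
```python
# (1) Upload your code to an S3 bucket.
# (2) Create a training task on ModelArts.
# (3) Select the code directory /{path}/warpctc.
# (4) Select the startup file /{path}/warpctc/eval.py.
# (5) Perform a or b.
#     a. Set parameters in /{path}/warpctc/default_config.yaml.
#         1. Set "enable_modelarts=True"
#         2. Set "checkpoint_path={checkpoint_path}" ({checkpoint_path} is the path of the weight file to evaluate, relative to eval.py; the weight file must be inside the code directory.)
#         3. Set "modelarts_dataset_unzip_name={filename}" if the data is uploaded as a zip package.
#     b. Set parameters on the web page.
#         1. Set "enable_modelarts=True"
#         2. Set "checkpoint_path={checkpoint_path}" ({checkpoint_path} is the path of the weight file to evaluate, relative to eval.py; the weight file must be inside the code directory.)
#         3. Set "modelarts_dataset_unzip_name={filename}" if the data is uploaded as a zip package.
# (6) Upload your data or data zip package to the S3 bucket.
# (7) On the web page, check the data storage location and set the "training dataset" path (this path must contain only the data or the data zip package).
# (8) On the web page, set the "training output file path" and "job log path".
# (9) Create the training job.
```
## Script Description
### Script and Sample Code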
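@ -123,7 +169,8 @@ WarpCTC is a two-layer stacked LSTM model with one FC layer. For details, please
```text
.
└──warpctc
├── README.md
├── README.md # descriptions of warpctc
├── README_CN.md # Chinese descriptions of warpctc
├── script
├── run_distribute_train.sh # launch distributed training on Ascend (8 cards)
├── run_distribute_train_for_gpu.sh # launch distributed training on GPU
@ -131,13 +178,19 @@ WarpCTC is a two-layer stacked LSTM model with one FC layer. For details, please
├── run_process_data.sh # launch dataset generation
└── run_standalone_train.sh # launch standalone training (1 card)
├── src
├── config.py # parameter configuration
├── model_utils
├── config.py # parse the *.yaml parameter configuration file
├── device_adapter.py # distinguish local and ModelArts training
├── local_adapter.py # get related environment variables for local training
└── moxing_adapter.py # get related environment variables and exchange data for ModelArts training
├── dataset.py # data preprocessing
├── loss.py # CTC loss definition
├── lr_generator.py # generate the learning rate for each step
├── metric.py # accuracy metric for the warpctc network
├── warpctc.py # warpctc network definition
└── warpctc_for_train.py # warpctc network with grad, loss and gradient clip
├── default_config.yaml # parameter configuration
├── export.py # inference
├── mindspore_hub_conf.py # MindSpore Hub interface
├── eval.py # evaluate the network
├── process_data.py # dataset generation script
@ -150,32 +203,32 @@ WarpCTC is a two-layer stacked LSTM model with one FC layer. For details, please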
```bash
# distributed training on Ascend
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
# distributed training on GPU
Usage: bash run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH]
Usage: bash run_distribute_train_for_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
# standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
Usage: bash run_standalone_train.sh [TRAIN_DATA_DIR] [DEVICE_TARGET]
```
### Parameters Configuration
Parameters for both training and evaluation can be set in config.py.
Parameters for both training and evaluation can be set in default_config.yaml.
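```text
"max_captcha_digits": 4, # max number of digits in each image
"captcha_width": 160, # width of captcha images
"captcha_height": 64, # height of captcha images
"batch_size": 64, # batch size of input tensor
"epoch_size": 30, # only valid for training; always 1 for inference
"hidden_size": 512, # hidden size in LSTM layers
"learning_rate": 0.01, # initial learning rate
"momentum": 0.9 # momentum of SGD optimizer
"save_checkpoint": True, # whether to save checkpoints
"save_checkpoint_steps": 97, # step interval between two checkpoints. By default, the last checkpoint is saved after the last step
"keep_checkpoint_max": 30, # keep only the last keep_checkpoint_max checkpoints
"save_checkpoint_path": "./checkpoint", # path to save checkpoints
max_captcha_digits: 4 # max number of digits in each image
captcha_width: 160 # width of captcha images
captcha_height: 64 # height of captcha images
batch_size: 64 # batch size of input tensor
epoch_size: 30 # only valid for training; always 1 for inference
hidden_size: 512 # hidden size in LSTM layers
learning_rate: 0.01 # initial learning rate
momentum: 0.9 # momentum of SGD optimizer
save_checkpoint: True # whether to save checkpoints
save_checkpoint_steps: 97 # step interval between two checkpoints. By default, the last checkpoint is saved after the last step
keep_checkpoint_max: 30 # keep only the last keep_checkpoint_max checkpoints
save_checkpoint_path: "./checkpoints" # path to save checkpoints, relative to train.py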
```
## Dataset Preparation
@ -184,14 +237,14 @@ WarpCTC is a two-layer stacked LSTM model with one FC layer. For details, please
## Training Process
- Set options in `config.py`, including the learning rate and network hyperparameters. See the [MindSpore dataset loading tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information.
- Set options in `default_config.yaml`, including the learning rate and network hyperparameters. See the [MindSpore dataset loading tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information.
### Training
- Run `run_standalone_train.sh` for non-distributed training of the WarpCTC model on Ascend or GPU.
``` bash
bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
bash run_standalone_train.sh [TRAIN_DATA_DIR] [DEVICE_TARGET]
```
### Distributed Training
@ -199,13 +252,13 @@ bash run_standalone_train.sh [DATASET_PATH] [PLATFORM]
- Run `run_distribute_train.sh` for distributed training of the WarpCTC model on Ascend.
``` bash
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
```
- Run `run_distribute_train_gpu.sh` for distributed training of the WarpCTC model on GPU.
``` bash
bash run_distribute_train_gpu.sh [RANK_SIZE] [DATASET_PATH]
bash run_distribute_train_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
```
## Evaluation Process
@ -215,7 +268,7 @@ bash run_distribute_train_gpu.sh [RANK_SIZE] [DATASET_PATH]
- Run `run_eval.sh` for evaluation.
``` bash
bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
bash run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DEVICE_TARGET]
```
## Model Description

View File

@ -0,0 +1,62 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
enable_profiling: False
modelarts_dataset_unzip_name: ''
# ==============================================================================
#train-related
run_distribute: False
train_data_dir: None
max_captcha_digits: 4
captcha_width: 160
captcha_height: 64
batch_size: 64
epoch_size: 30
hidden_size: 512
learning_rate: 0.01
momentum: 0.9
save_checkpoint: True
save_checkpoint_steps: 97
keep_checkpoint_max: 30
save_checkpoint_path: "./checkpoints"
#eval-related
test_data_dir: None
checkpoint_path: None
#export-related
file_name: "warpctc"
ckpt_file: ""
file_format: "MINDIR"
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
enable_profiling: 'Whether enable profiling while training, default: False'
run_distribute: "Run distribute, default is false."
train_data_dir: "train dataset path, default is None."
test_data_dir: "test Dataset path, default is None."
checkpoint_path: "checkpoint file path, default is None"
file_name: "warpctc output file name, default: warpctc"
ckpt_file: "required, warpctc ckpt file."
file_format: "file format, choose from AIR, MINDIR, and default is MINDIR"
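The first YAML document above holds the values and the second, after `---`, holds the help text for each key. Once parsed, the values are exposed as attributes such as `config.batch_size`. The sketch below shows that dictionary-to-attribute step with a hand-written dictionary in place of the parsed file; the `nested` entry is a hypothetical addition used only to show the recursion. Note that a bare `None` in YAML is read as the string "None" rather than a null value, which is why it appears as a string here.

```python
from pprint import pformat


class Config:
    """Expose dictionary keys as attributes, recursing into nested dicts."""

    def __init__(self, cfg_dict):
        for key, value in cfg_dict.items():
            if isinstance(value, (list, tuple)):
                setattr(self, key, [Config(x) if isinstance(x, dict) else x for x in value])
            else:
                setattr(self, key, Config(value) if isinstance(value, dict) else value)

    def __repr__(self):
        return pformat(self.__dict__)


# Hand-written stand-in for the first YAML document (assumption: already parsed).
config = Config({
    "device_target": "Ascend",
    "batch_size": 64,
    "train_data_dir": "None",
    "nested": {"momentum": 0.9},  # hypothetical nested section
})
print(config.batch_size, config.device_target, config.nested.momentum)
```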

View File

@ -14,57 +14,109 @@
# ============================================================================
"""Warpctc evaluation"""
import os
import time
import math as m
import argparse
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.loss import CTCLoss
from src.config import config as cf
from src.dataset import create_dataset
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
from src.metric import WarpCTCAccuracy
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id, get_device_num
set_seed(1)
def modelarts_pre_process():
    '''modelarts pre process function.'''
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done.")
            else:
                print("This is not zip.")
        else:
            print("Zip has been extracted.")
parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
                    help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.')
args_opt = parser.parse_args()
    if config.modelarts_dataset_unzip_name:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(device_id=device_id)
        sync_lock = "/tmp/unzip_sync.lock"
        # Each server contains 8 devices at most.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass
        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)
        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
    if config.device_target == 'Ascend':
        context.set_context(device_id=get_device_id())
    max_captcha_digits = config.max_captcha_digits
    input_size = m.ceil(config.captcha_height / 64) * 64 * 3
    # create dataset
    if config.enable_modelarts:
        dataset_dir = config.data_path
    else:
        dataset_dir = config.test_data_dir
    dataset = create_dataset(dataset_path=dataset_dir,
                             batch_size=config.batch_size,
                             device_target=config.device_target)
    # step_size = dataset.get_dataset_size()
    loss = CTCLoss(max_sequence_length=config.captcha_width,
                   max_label_length=max_captcha_digits,
                   batch_size=config.batch_size)
    if config.device_target == 'Ascend':
        net = StackedRNN(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
    elif config.device_target == 'GPU':
        net = StackedRNNForGPU(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
    else:
        net = StackedRNNForCPU(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
    # load checkpoint
    checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.checkpoint_path)
    param_dict = load_checkpoint(checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # define model
    model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy(config.device_target)})
    # start evaluation
    res = model.eval(dataset, dataset_sink_mode=config.device_target == 'Ascend')
    print("result:", res, flush=True)
if __name__ == '__main__':
    max_captcha_digits = cf.max_captcha_digits
    input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
    # create dataset
    dataset = create_dataset(dataset_path=args_opt.dataset_path,
                             batch_size=cf.batch_size,
                             device_target=args_opt.platform)
    step_size = dataset.get_dataset_size()
    loss = CTCLoss(max_sequence_length=cf.captcha_width,
                   max_label_length=max_captcha_digits,
                   batch_size=cf.batch_size)
    if args_opt.platform == 'Ascend':
        net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
    elif args_opt.platform == 'GPU':
        net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
    else:
        net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # define model
    model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy(args_opt.platform)})
    # start evaluation
    res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
    print("result:", res, flush=True)
    run_eval()
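
The pre-process step in this file lets one device per server extract the dataset while the other devices poll for a lock file. The sketch below imitates that hand-off with threads standing in for device processes and a temporary directory in place of `/tmp/unzip_sync.lock`. The `worker` function and `events` list are hypothetical, and `open()` replaces `os.mknod` only to keep the sketch portable.

```python
import os
import tempfile
import threading
import time


def worker(device_id, device_num, sync_lock, log):
    """Device 0 of each server does the work; the rest wait for the lock file."""
    if device_id % min(device_num, 8) == 0 and not os.path.exists(sync_lock):
        log.append("device {} extracts".format(device_id))
        # Stand-in for unzip(); creating the file signals completion.
        open(sync_lock, "w").close()
    while not os.path.exists(sync_lock):
        time.sleep(0.01)
    log.append("device {} continues".format(device_id))


with tempfile.TemporaryDirectory() as tmp_dir:
    lock_path = os.path.join(tmp_dir, "unzip_sync.lock")
    events = []
    threads = [threading.Thread(target=worker, args=(i, 4, lock_path, events)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

print(events)
```

Because the lock file is created only after the extraction step, no waiting device can continue before the data is in place. One caveat of the pattern is that a stale lock file from an earlier run would let every device skip the wait.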

View File

@ -13,29 +13,20 @@
# limitations under the License.
# ============================================================================
"""export checkpoint file into air models"""
import argparse
import math as m
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
from src.config import config
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
parser = argparse.ArgumentParser(description="warpctc_export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--ckpt_file", type=str, required=True, help="warpctc ckpt file.")
parser.add_argument("--file_name", type=str, default="warpctc", help="warpctc output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
                    help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
    context.set_context(device_id=get_device_id())
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)
if args.file_format == "AIR" and args.device_target != "Ascend":
if config.file_format == "AIR" and config.device_target != "Ascend":
    raise ValueError("export AIR must on Ascend")
if __name__ == "__main__":
@ -45,14 +36,14 @@ if __name__ == "__main__":
    batch_size = config.batch_size
    hidden_size = config.hidden_size
    image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float32))
    if args.device_target == 'Ascend':
    if config.device_target == 'Ascend':
        net = StackedRNN(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
        image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float16))
    elif args.device_target == 'GPU':
    elif config.device_target == 'GPU':
        net = StackedRNNForGPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
    else:
        net = StackedRNNForCPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
    param_dict = load_checkpoint(args.ckpt_file)
    param_dict = load_checkpoint(config.ckpt_file)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    export(net, image, file_name=args.file_name, file_format=args.file_format)
    export(net, image, file_name=config.file_name, file_format=config.file_format)
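
The export step feeds the network a dummy image of shape `[batch_size, 3, captcha_height, captcha_width]`, in float16 on Ascend and float32 otherwise. The sketch below works through those sizes for the default configuration. It assumes `input_size` follows the formula shown in `eval.py` (the corresponding lines of `export.py` are not visible in this diff); `rnn_input_size` and `dummy_dtype` are hypothetical helper names.

```python
import math

captcha_height, captcha_width, batch_size = 64, 160, 64


def rnn_input_size(height):
    """Round the height up to a multiple of 64, then multiply by 3 channels."""
    return math.ceil(height / 64) * 64 * 3


def dummy_dtype(device_target):
    """float16 on Ascend, float32 on GPU and CPU, as in the export branches."""
    return "float16" if device_target == "Ascend" else "float32"


dummy_shape = (batch_size, 3, captcha_height, captcha_width)
print(rnn_input_size(captcha_height), dummy_shape, dummy_dtype("Ascend"))
```

With the default height of 64 the input size is 192; a height of 65 would round up to 128 rows and give 384.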

View File

@ -15,7 +15,7 @@
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]"
exit 1
fi
@ -36,7 +36,7 @@ if [ ! -f $PATH1 ]; then
fi
if [ ! -d $PATH2 ]; then
echo "error: DATASET_PATH=$PATH2 is not a directory"
echo "error: TRAIN_DATA_DIR=$PATH2 is not a directory"
exit 1
fi
@ -51,11 +51,12 @@ for ((i = 0; i < ${DEVICE_NUM}; i++)); do
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp ../*.yaml ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute > log.txt 2>&1 &
python train.py --device_target=Ascend --train_data_dir=$PATH2 --run_distribute True > log.txt 2>&1 &
cd ..
done

View File

@ -15,7 +15,7 @@
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_distribute_train.sh [RANK_SIZE] [DATASET_PATH]"
echo "Usage: sh run_distribute_train.sh [RANK_SIZE] [TRAIN_DATA_DIR]"
exit 1
fi
@ -28,10 +28,10 @@ get_real_path() {
}
RANK_SIZE=$1
DATASET_PATH=$(get_real_path $2)
TRAIN_DATA_DIR=$(get_real_path $2)
if [ ! -d $DATASET_PATH ]; then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
if [ ! -d $TRAIN_DATA_DIR ]; then
echo "error: TRAIN_DATA_DIR=$TRAIN_DATA_DIR is not a directory"
exit 1
fi
@ -41,12 +41,13 @@ fi
mkdir ./distribute_train
cp ../*.py ./distribute_train
cp ../*.yaml ./distribute_train
cp -r ../src ./distribute_train
cd ./distribute_train || exit
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py \
--dataset_path=$DATASET_PATH \
--platform=GPU \
--run_distribute > log.txt 2>&1 &
--train_data_dir=$TRAIN_DATA_DIR \
--device_target=GPU \
--run_distribute=True > log.txt 2>&1 &
cd ..

View File

@ -15,7 +15,7 @@
# ============================================================================
if [ $# != 3 ]; then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
echo "Usage: sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DEVICE_TARGET]"
exit 1
fi
@ -29,10 +29,10 @@ get_real_path() {
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PLATFORM=$3
DEVICE_TARGET=$3
if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory"
echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
exit 1
fi
@ -53,11 +53,12 @@ run_ascend() {
fi
mkdir ./eval
cp ../*.py ./eval
cp ../*.yaml ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
python eval.py --test_data_dir=$1 --checkpoint_path=$2 --device_target=Ascend > log.txt 2>&1 &
cd ..
}
@ -67,16 +68,17 @@ run_gpu_cpu() {
fi
mkdir ./eval
cp ../*.py ./eval
cp ../*.yaml ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=$3 > log.txt 2>&1 &
python eval.py --test_data_dir=$1 --checkpoint_path=$2 --device_target=$3 > log.txt 2>&1 &
cd ..
}
if [ "Ascend" == $PLATFORM ]; then
if [ "Ascend" == $DEVICE_TARGET ]; then
run_ascend $PATH1 $PATH2
else
run_gpu_cpu $PATH1 $PATH2 $PLATFORM
run_gpu_cpu $PATH1 $PATH2 $DEVICE_TARGET
fi

View File

@ -18,3 +18,4 @@ CUR_PATH=$(dirname $PWD/$0)
cd $CUR_PATH/../ &&
python process_data.py &&
cd - &> /dev/null || exit

View File

@ -15,7 +15,7 @@
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]"
echo "Usage: sh run_standalone_train.sh [TRAIN_DATA_DIR] [DEVICE_TARGET]"
exit 1
fi
@ -28,10 +28,10 @@ get_real_path() {
}
PATH1=$(get_real_path $1)
PLATFORM=$2
DEVICE_TARGET=$2
if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory"
echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
exit 1
fi
@ -44,13 +44,13 @@ run_ascend() {
echo "start training for device $DEVICE_ID"
env >env.log
python train.py --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
python train.py --train_data_dir=$1 --device_target=Ascend > log.txt 2>&1 &
cd ..
}
run_gpu_cpu() {
env >env.log
python train.py --dataset_path=$1 --platform=$2 > log.txt 2>&1 &
python train.py --train_data_dir=$1 --device_target=$2 > log.txt 2>&1 &
cd ..
}
@ -59,11 +59,12 @@ if [ -d "train" ]; then
fi
mkdir ./train
cp ../*.py ./train
cp ../*.yaml ./train
cp -r ../src ./train
cd ./train || exit
if [ "Ascend" == $PLATFORM ]; then
if [ "Ascend" == $DEVICE_TARGET ]; then
run_ascend $PATH1
else
run_gpu_cpu $PATH1 $PLATFORM
run_gpu_cpu $PATH1 $DEVICE_TARGET
fi

View File

@ -21,7 +21,7 @@ import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.vision.c_transforms as vc
from src.config import config as cf
from src.model_utils.config import config
class _CaptchaDataset:
@ -79,18 +79,18 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_
        device_target(str): platform of training, support Ascend and GPU
    """
    dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target)
    dataset = _CaptchaDataset(dataset_path, config.max_captcha_digits, device_target)
    data_set = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
    image_trans = [
        vc.Rescale(1.0 / 255.0, 0.0),
        vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
        vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)),
        vc.Resize((m.ceil(config.captcha_height / 16) * 16, config.captcha_width)),
        c.TypeCast(mstype.float16)
    ]
    image_trans_gpu = [
        vc.Rescale(1.0 / 255.0, 0.0),
        vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
        vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)),
        vc.Resize((m.ceil(config.captcha_height / 16) * 16, config.captcha_width)),
        vc.HWC2CHW()
    ]
    label_trans = [

View File

@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml


class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Allowed values for each argument.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        # Only scalar entries become command line flags; lists and dicts stay yaml-only.
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please refer to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                # bool("False") is True, so booleans are parsed with ast.literal_eval.
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            print(cfg_helper)
        # Catch only yaml errors so the more specific ValueError above is not masked.
        except yaml.YAMLError as err:
            raise ValueError("Failed to parse yaml") from err
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)


config = get_config()

if __name__ == '__main__':
    print(config)
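
The new `config.py` turns each scalar default from the yaml file into a command line flag and then lets the parsed values override the defaults. The following is a minimal sketch of that flow, not the module itself. A plain dict stands in for the first yaml document so that PyYAML is not needed, and the function name `build_parser` and the key `lr_steps` are hypothetical.

```python
import argparse
import ast


def build_parser(defaults):
    """Create one --flag per scalar default, typed after the default's value."""
    parser = argparse.ArgumentParser(description="config sketch")
    for key, value in defaults.items():
        if isinstance(value, (list, dict)):
            continue  # nested values get no flag, mirroring parse_cli_to_yaml
        # bool("False") is True, so booleans are parsed with ast.literal_eval instead.
        arg_type = ast.literal_eval if isinstance(value, bool) else type(value)
        parser.add_argument("--" + key, type=arg_type, default=value)
    return parser


# Stand-in for the yaml defaults; these keys and values are illustrative only.
defaults = {"batch_size": 64, "run_distribute": False, "device_target": "Ascend", "lr_steps": [10, 20]}

args = build_parser(defaults).parse_args(["--batch_size", "32", "--run_distribute", "True"])
merged = dict(defaults)
merged.update(vars(args))  # command line values override the defaults
```

Because list and dict entries get no flag, they can only be changed by editing the yaml file or by pointing `--config_path` at a different one.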


@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,20 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict
config = EasyDict({
"max_captcha_digits": 4,
"captcha_width": 160,
"captcha_height": 64,
"batch_size": 64,
"epoch_size": 30,
"hidden_size": 512,
"learning_rate": 0.01,
"momentum": 0.9,
"save_checkpoint": True,
"save_checkpoint_steps": 97,
"keep_checkpoint_max": 30,
"save_checkpoint_path": "./",
})
"""Device adapter for ModelArts"""
from src.model_utils.config import config

if config.enable_modelarts:
    from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]


@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os


def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"
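
The local adapter reads its device information from environment variables and falls back to single-device defaults when they are unset. The following is a minimal sketch of that lookup with a hypothetical helper `env_int`; the values written to `os.environ` only simulate what a launcher might export.

```python
import os


def env_int(name, default):
    """Read an integer from the environment, falling back to `default` (hypothetical helper)."""
    return int(os.getenv(name, str(default)))


os.environ.pop("DEVICE_ID", None)   # simulate a plain local run with nothing exported
local_device = env_int("DEVICE_ID", 0)

os.environ["RANK_SIZE"] = "8"       # simulate a distributed launcher exporting the device count
device_num = env_int("RANK_SIZE", 1)
```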


@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from src.model_utils.config import config
_global_sync_count = 0


def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    # os.getenv returns None when JOB_ID is unset, so test truthiness rather than != "".
    job_id = os.getenv('JOB_ID')
    job_id = job_id if job_id else "default"
    return job_id


def sync_data(from_path, to_path):
    """
    Download data from remote OBS to a local directory when the first path is a remote URL
    and the second is a local path; upload from the local directory to remote OBS otherwise.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains at most 8 devices.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
            # print("os.mknod({}) success".format(sync_lock))
        except IOError:
            pass
        print("===save flag===")

    # Every device waits here until the copying device has created the lock file.
    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            if config.enable_profiling:
                profiler = Profiler()

            run_func(*args, **kwargs)

            if config.enable_profiling:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
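
`sync_data` coordinates the devices on one server through a lock file: one device does the copy and then creates the file, while every device polls until the file exists. The following is a minimal single-process sketch of that pattern. The function name `run_once_per_host` and the `is_leader` flag are hypothetical, and the lock file is created with `open()` rather than `os.mknod`, which is only available on Unix.

```python
import os
import tempfile
import time


def run_once_per_host(lock_path, is_leader, action, poll_interval=0.01):
    """Let the leader run `action` and create `lock_path`; every caller waits for the file."""
    if is_leader and not os.path.exists(lock_path):
        action()
        # Creating the file signals that the work is finished.
        with open(lock_path, "w"):
            pass
    while not os.path.exists(lock_path):
        time.sleep(poll_interval)


copied = []
lock = os.path.join(tempfile.mkdtemp(), "copy_sync.lock0")
# The leader performs the copy; the follower finds the lock file and returns without copying.
run_once_per_host(lock, is_leader=True, action=lambda: copied.append("dataset"))
run_once_per_host(lock, is_leader=False, action=lambda: copied.append("dataset"))
```

Neither this sketch nor `sync_data` removes the lock file afterwards. `sync_data` therefore appends a per-process counter to the lock name so that successive calls within one run use distinct files.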


@ -14,8 +14,8 @@
# ============================================================================
"""Warpctc training"""
import os
import time
import math as m
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
@ -26,33 +26,77 @@ from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig,
from mindspore.communication.management import init, get_group_size, get_rank
from src.loss import CTCLoss
from src.config import config as cf
from src.dataset import create_dataset
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
from src.warpctc_for_train import TrainOneStepCellWithGradClip
from src.lr_schedule import get_lr
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
set_seed(1)
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.')
parser.set_defaults(run_distribute=False)
args_opt = parser.parse_args()
if config.modelarts_dataset_unzip_name:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
if __name__ == '__main__':
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
config.save_checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.save_checkpoint_path)
@moxing_wrapper(pre_process=modelarts_pre_process)
def train():
"""Train function."""
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == 'Ascend':
context.set_context(device_id=get_device_id())
lr_scale = 1
if args_opt.run_distribute:
init()
if args_opt.platform == 'Ascend':
if config.run_distribute:
if config.device_target == 'Ascend':
device_num = int(os.environ.get("RANK_SIZE"))
rank = int(os.environ.get("RANK_ID"))
else:
@ -62,29 +106,34 @@ if __name__ == '__main__':
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
device_num = 1
rank = 0
max_captcha_digits = cf.max_captcha_digits
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
max_captcha_digits = config.max_captcha_digits
input_size = m.ceil(config.captcha_height / 64) * 64 * 3
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, batch_size=cf.batch_size,
num_shards=device_num, shard_id=rank, device_target=args_opt.platform)
if config.enable_modelarts:
dataset_dir = config.data_path
else:
dataset_dir = config.train_data_dir
dataset = create_dataset(dataset_path=dataset_dir, batch_size=config.batch_size,
num_shards=device_num, shard_id=rank, device_target=config.device_target)
step_size = dataset.get_dataset_size()
# define lr
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale
lr = get_lr(cf.epoch_size, step_size, lr_init)
loss = CTCLoss(max_sequence_length=cf.captcha_width,
lr_init = config.learning_rate if not config.run_distribute else config.learning_rate * device_num * lr_scale
lr = get_lr(config.epoch_size, step_size, lr_init)
loss = CTCLoss(max_sequence_length=config.captcha_width,
max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
if args_opt.platform == 'Ascend':
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
elif args_opt.platform == 'GPU':
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
batch_size=config.batch_size)
if config.device_target == 'Ascend':
net = StackedRNN(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
elif config.device_target == 'GPU':
net = StackedRNNForGPU(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
else:
net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
net = StackedRNNForCPU(input_size=input_size, batch_size=config.batch_size, hidden_size=config.hidden_size)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum)
net = WithLossCell(net, loss)
net = TrainOneStepCellWithGradClip(net, opt).set_train()
@ -92,10 +141,14 @@ if __name__ == '__main__':
model = Model(net)
# define callbacks
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
if cf.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps,
keep_checkpoint_max=cf.keep_checkpoint_max)
save_ckpt_path = os.path.join(cf.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
save_ckpt_path = config.save_checkpoint_path
ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=save_ckpt_path, config=config_ck)
callbacks.append(ckpt_cb)
model.train(cf.epoch_size, dataset, callbacks=callbacks)
model.train(config.epoch_size, dataset, callbacks=callbacks)
if __name__ == '__main__':
train()
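
`train()` is now wrapped by `@moxing_wrapper(pre_process=modelarts_pre_process)`, a decorator factory that runs optional hooks before and after the wrapped function. The following is a minimal sketch of that decorator shape without any MindSpore or ModelArts dependency. The names `with_hooks`, `enabled` and `train_stub` are hypothetical stand-ins.

```python
import functools


def with_hooks(pre_process=None, post_process=None, enabled=True):
    """Decorator factory: run optional hooks around a function when `enabled` is true."""
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            if enabled and pre_process:
                pre_process()
            result = run_func(*args, **kwargs)
            if enabled and post_process:
                post_process()
            return result
        return wrapped_func
    return wrapper


events = []


@with_hooks(pre_process=lambda: events.append("pre"), post_process=lambda: events.append("post"))
def train_stub():
    """Stand-in for train(); records that it ran."""
    events.append("train")
    return "done"


status = train_stub()
```

Unlike `moxing_wrapper`, which discards the return value of `run_func`, this sketch passes it through to the caller.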