!21373 DeepSpeech2 update

Merge pull request !21373 from 懒人谈/deepspeech2
i-robot 2021-08-23 07:54:12 +00:00 committed by Gitee
commit 1db7bfde62
9 changed files with 503 additions and 25 deletions

View File

@ -0,0 +1,293 @@
# Contents
[View English](./README.md)
<!-- TOC -->
- [Contents](#contents)
- [DeepSpeech2 Description](#deepspeech2-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Model Parameters](#model-parameters)
- [Training and Inference Process](#training-and-inference-process)
    - [Export](#export)
- [Performance](#performance)
    - [Training Performance](#training-performance)
    - [Inference Performance](#inference-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [DeepSpeech2 Description](#contents)
DeepSpeech2 is a speech recognition model trained with CTC loss. It replaces the entire hand-engineered pipeline with neural networks and can handle a wide variety of speech, including noisy environments, accents, and different languages.
[Paper](https://arxiv.org/pdf/1512.02595v1.pdf): Amodei, Dario, et al. Deep Speech 2: End-to-end speech recognition in English and Mandarin.
# [Model Architecture](#contents)
The model consists of the following components (a minimal sketch follows this list):
- Two convolutional layers:
    - 32 channels, kernel size (41, 11), stride (2, 2)
    - 32 channels, kernel size (41, 11), stride (2, 1)
- Five bidirectional LSTM layers (hidden size 1024)
- One projection layer whose size is the number of characters plus 1 for the CTC blank symbol, i.e. 29
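As a reading aid, here is a condensed, illustrative MindSpore sketch of this topology. It is not the repository's `src/DeepSpeech.py`; the padding choices, the number of frequency bins, and the reshaping between the CNN front end and the RNN stack are assumptions made for readability.
```python
# Condensed, illustrative sketch of the DeepSpeech2 topology described above.
# NOT the repository's implementation; shapes and padding are assumptions.
import numpy as np
import mindspore as ms
import mindspore.nn as nn


class DeepSpeech2Sketch(nn.Cell):
    def __init__(self, batch_size=20, freq_bins=161, rnn_hidden_size=1024,
                 rnn_layers=5, num_classes=29):
        super().__init__()
        # Two 2-D convolutions over the (frequency, time) spectrogram.
        self.conv = nn.SequentialCell([
            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2),
                      pad_mode='pad', padding=(20, 20, 5, 5)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=(41, 11), stride=(2, 1),
                      pad_mode='pad', padding=(20, 20, 5, 5)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        ])
        # Rough feature size after the two convolutions (depends on freq_bins).
        rnn_input_size = 32 * ((freq_bins + 1) // 2 // 2 + 1)
        # Five stacked bidirectional LSTM layers with hidden size 1024.
        self.rnn = nn.LSTM(rnn_input_size, rnn_hidden_size,
                           num_layers=rnn_layers, bidirectional=True)
        self.h0 = ms.Tensor(np.zeros((2 * rnn_layers, batch_size, rnn_hidden_size), np.float32))
        self.c0 = ms.Tensor(np.zeros((2 * rnn_layers, batch_size, rnn_hidden_size), np.float32))
        # Projection onto 29 symbols: 28 characters plus the CTC blank.
        self.fc = nn.Dense(2 * rnn_hidden_size, num_classes)

    def construct(self, x):
        # x: (batch, 1, freq, time)
        y = self.conv(x)
        b, c, f, t = y.shape
        y = y.reshape(b, c * f, t).transpose(2, 0, 1)   # -> (time, batch, features)
        y, _ = self.rnn(y, (self.h0, self.c0))
        t, b, h = y.shape
        y = self.fc(y.reshape(t * b, h)).reshape(t, b, -1)
        return y                                        # per-time-step scores for the CTC loss
```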
# [Dataset](#contents)
The scripts can be run with the dataset mentioned in the paper, or with other datasets widely used in this field and for this network architecture. The sections below describe how to run the scripts with the dataset listed here.
Dataset used: [LibriSpeech](<http://www.openslr.org/12>)
- Training set:
    - train-clean-100: [6.3G] (100 hours of clean speech for training)
    - train-clean-360.tar.gz [23G] (360 hours of clean speech for training)
    - train-other-500.tar.gz [30G] (500 hours of noisy speech for training)
- Validation set:
    - dev-clean.tar.gz [337M] (clean)
    - dev-other.tar.gz [314M] (noisy)
- Test set:
    - test-clean.tar.gz [346M] (test set, clean)
    - test-other.tar.gz [328M] (test set, noisy)
- Data format: wav and txt files
    - Note: the data must be processed with librispeech.py; the resulting manifest format is illustrated below
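For reference, each line of the generated `*_manifest.csv` files pairs an audio file with its transcript file, following the deepspeech.pytorch convention; the paths below are purely illustrative:
```text
/abs/path/LibriSpeech_dataset/train/wav/1089-134686-0000.wav,/abs/path/LibriSpeech_dataset/train/txt/1089-134686-0000.txt
/abs/path/LibriSpeech_dataset/train/wav/1089-134686-0001.wav,/abs/path/LibriSpeech_dataset/train/txt/1089-134686-0001.txt
```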
# [Environment Requirements](#contents)
- Hardware (GPU)
    - Prepare a hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```path
.
├── audio
    ├── deepspeech2
        ├── scripts
        │   ├──run_distribute_train_gpu.sh    // launches distributed training on 8 GPUs
        │   ├──run_eval_cpu.sh                // launches evaluation on CPU
        │   ├──run_eval_gpu.sh                // launches evaluation on GPU
        │   ├──run_standalone_train_cpu.sh    // launches standalone training on CPU
        │   └──run_standalone_train_gpu.sh    // launches standalone training on a single GPU
        ├── train.py                          // training script
        ├── eval.py                           // evaluation script
        ├── export.py                         // converts a MindSpore checkpoint into a MINDIR model
        ├── labels.json                       // the characters the model can map to
        ├── README.md                         // DeepSpeech2 description
        ├── deepspeech_pytorch                //
            ├──decoder.py                     // decoder from third-party code (MIT License)
        ├── src
            ├──__init__.py
            ├──DeepSpeech.py                  // DeepSpeech2 network architecture
            ├──dataset.py                     // data processing
            ├──config.py                      // DeepSpeech2 configuration file
            ├──lr_generator.py                // learning-rate generation
            ├──greedydecoder.py               // greedy decoder modified from MindSpore code
            └──callback.py                    // callbacks to monitor training
```
## [Model Parameters](#contents)
Parameters for both training and evaluation are set in `config.py` (the sketch after these listings shows how they are grouped in that file).
```text
Training-related parameters:
    epochs                  number of training epochs, default is 70
```
```text
Data-processing parameters:
    train_manifest          path of the manifest file used for training, default is 'data/libri_train_manifest.csv'
    val_manifest            path of the manifest file used for testing, default is 'data/libri_val_manifest.csv'
    batch_size              batch size, default is 8
    labels_path             path of the token JSON for model output, default is "./labels.json"
    sample_rate             sample rate of the data features, default is 16000
    window_size             window size for spectrogram generation, default is 0.02
    window_stride           window stride for spectrogram generation, default is 0.01
    window                  window type for spectrogram generation, default is 'hamming'
    speed_volume_perturb    use random speed and gain perturbation, default is False, not used in the current model
    spec_augment            use simple spectral augmentation on Mel spectrograms, default is False, not used in the current model
    noise_dir               directory from which to inject noise into the audio; with the default '' no noise is injected, not used in the current model
    noise_prob              probability of injecting noise into each sample, default is 0.4, not used in the current model
    noise_min               minimum noise level of a sample (1.0 means all noise and no original signal), default is 0.0, not used in the current model
    noise_max               maximum noise level of a sample (maximum is 1.0), default is 0.5, not used in the current model
```
```text
Model-related parameters:
    rnn_type                RNN type used in the model, default is 'LSTM'; only LSTM is supported at present
    hidden_size             hidden size of the RNN layers, default is 1024
    hidden_layers           number of RNN layers, default is 5
    lookahead_context       lookahead context, default is 20, not used in the current model
```
```text
Optimizer-related parameters:
    learning_rate           initial learning rate, default is 3e-4
    learning_anneal         learning-rate annealing applied after each epoch, default is 1.1
    weight_decay            weight decay, default is 1e-5
    momentum                momentum, default is 0.9
    eps                     Adam eps, default is 1e-8
    betas                   Adam betas, default is (0.9, 0.999)
    loss_scale              loss scale, default is 1024
```
```text
Checkpoint-related parameters:
    ckpt_file_name_prefix   file-name prefix of checkpoint files, default is 'DeepSpeech'
    ckpt_path               directory where checkpoint files are saved, default is 'checkpoints'
    keep_checkpoint_max     maximum number of checkpoint files to keep (older checkpoints are deleted), default is 10
```
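For orientation, these parameters are grouped into nested dictionaries in `src/config.py`, built with `easydict`. The abridged sketch below only illustrates the layout; the group names (`TrainingConfig`, `DataConfig`, `ModelConfig`, `LMConfig`, `CheckpointConfig`) come from the code in this repository, while the omitted fields follow the listings above rather than this snippet.
```python
# Abridged layout of src/config.py (illustrative; fields not shown follow the parameter lists above).
from easydict import EasyDict as ed

train_config = ed({
    "TrainingConfig": {
        "epochs": 70,
    },
    "DataConfig": {
        "train_manifest": 'data/libri_train_manifest.csv',
        "batch_size": 20,
        "labels_path": "labels.json",
        # SpectConfig and the augmentation/noise options are defined here as well
    },
    # ModelConfig, optimizer settings and CheckpointConfig follow
})

eval_config = ed({
    "save_output": None,     # set to a path to keep raw model outputs (field used by the eval callback; default assumed here)
    "verbose": True,
    "DataConfig": {
        "test_manifest": 'data/libri_test_clean_manifest.csv',
        "batch_size": 20,
    },
    # ModelConfig, LMConfig (decoder_type 'greedy') and CheckpointConfig follow
})
```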
# [Training and Inference Process](#contents)
## Training
```text
Usage: train.py [--use_pretrained USE_PRETRAINED]
                [--pre_trained_model_path PRE_TRAINED_MODEL_PATH]
                [--is_distributed IS_DISTRIBUTED]
                [--bidirectional BIDIRECTIONAL]
                [--device_target DEVICE_TARGET]
Options:
    --pre_trained_model_path    path to a pre-trained checkpoint, default is ''
    --is_distributed            distributed (multi-device) training, default is False
    --bidirectional             whether to use a bidirectional RNN, default is True; only the bidirectional model is implemented
    --device_target             device on which the code runs: "GPU" | "CPU", default is "GPU"
```
## Inference
```text
Usage: eval.py [--bidirectional BIDIRECTIONAL]
               [--pretrain_ckpt PRETRAIN_CKPT]
               [--device_target DEVICE_TARGET]
Options:
    --bidirectional             whether to use a bidirectional RNN, default is True; only the bidirectional model is implemented
    --pretrain_ckpt             checkpoint file path, default is ''
    --device_target             device on which the code runs: "GPU" | "CPU", default is "GPU"
```
Before training, the dataset must be processed. Use the scripts from [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) to do so; they download the LibriSpeech data and process it automatically.
After processing finishes, the data directory structure is as follows:
```path
.
├─ LibriSpeech_dataset
│  ├── train
│  │   ├─ wav
│  │   └─ txt
│  ├── val
│  │   ├─ wav
│  │   └─ txt
│  ├── test_clean
│  │   ├─ wav
│  │   └─ txt
│  └── test_other
│      ├─ wav
│      └─ txt
└─ libri_test_clean_manifest.csv, libri_test_other_manifest.csv, libri_train_manifest.csv, libri_val_manifest.csv
```
The four *.csv manifest files contain the absolute paths of the corresponding data. After obtaining them, modify the configuration in src/config.py.
For the training configuration, train_manifest should be set to the path of `libri_train_manifest.csv`; for the eval configuration, test_manifest should be set to `libri_test_clean_manifest.csv` or `libri_test_other_manifest.csv`, depending on which dataset is evaluated.
```shell
...
Training configuration
"DataConfig":{
    train_manifest:'path_to_csv/libri_train_manifest.csv'
}
Evaluation configuration
"DataConfig":{
    test_manifest:'path_to_csv/libri_test_clean_manifest.csv'
}
```
Before training, `librosa` and `Levenshtein` also need to be installed.
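For example (these are common PyPI package names; the exact name of the Levenshtein binding may differ in your environment):
```shell
pip install librosa python-Levenshtein
```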
After installing MindSpore from the official website and finishing the dataset processing, training can be started as follows:
```shell
# standalone training on GPU
bash ./scripts/run_standalone_train_gpu.sh [DEVICE_ID]
# standalone training on CPU
bash ./scripts/run_standalone_train_cpu.sh
# distributed training on GPU (8 devices)
bash ./scripts/run_distribute_train_gpu.sh
```
For model evaluation, note that only the greedy decoder is currently supported. Before running the evaluation scripts, download the decoder from [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) and place the deepspeech_pytorch directory inside the deepspeech2 directory (a sketch of this step follows); the file layout will then match the one shown in [Script and Sample Code].
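A minimal sketch of that step, assuming the third-party repository is cloned next to the model directory (paths are illustrative):
```shell
# fetch the third-party repository and copy only the decoder package (MIT License)
git clone https://github.com/SeanNaren/deepspeech.pytorch.git
cp -r deepspeech.pytorch/deepspeech_pytorch ./deepspeech2/
```
Then run evaluation: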
```shell
# evaluation on CPU
bash ./scripts/run_eval_cpu.sh [PATH_CHECKPOINT]
# evaluation on GPU
bash ./scripts/run_eval_gpu.sh [DEVICE_ID] [PATH_CHECKPOINT]
```
## [Export](#contents)
```bash
python export.py --pre_trained_model_path='ckpt_path'
```
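`export.py` loads a trained checkpoint and serializes the network to MINDIR. The sketch below only illustrates the typical flow; the dummy input shapes, the import path `src.deepspeech2`, and the way the checkpoint path is passed are assumptions based on the rest of this repository, not the exact contents of export.py.
```python
# Illustrative export flow (a sketch under stated assumptions, not the repository's export.py).
import json
import numpy as np
from mindspore import Tensor, export, load_checkpoint, load_param_into_net

from src.config import eval_config as config
from src.deepspeech2 import DeepSpeechModel

# Load the output-token list used by the model.
with open(config.DataConfig.labels_path) as label_file:
    labels = json.load(label_file)

net = DeepSpeechModel(batch_size=config.DataConfig.batch_size,
                      rnn_hidden_size=config.ModelConfig.hidden_size,
                      nb_layers=config.ModelConfig.hidden_layers,
                      labels=labels,
                      rnn_type=config.ModelConfig.rnn_type,
                      audio_conf=config.DataConfig.SpectConfig,
                      bidirectional=True)
load_param_into_net(net, load_checkpoint('ckpt_path'))   # value of --pre_trained_model_path

# Dummy spectrogram batch and per-sample lengths; 161 frequency bins and the
# 3500-frame test padding follow src/dataset.py but are assumptions here.
inputs = Tensor(np.zeros((config.DataConfig.batch_size, 1, 161, 3500), np.float32))
input_length = Tensor(np.zeros((config.DataConfig.batch_size,), np.float32))
export(net, inputs, input_length, file_name='deepspeech2', file_format='MINDIR')
```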
# [Performance](#contents)
## [Training and Testing Performance Analysis](#contents)
### Training Performance
| Parameters                 | DeepSpeech                                                      |
| -------------------------- | --------------------------------------------------------------- |
| Resource                   | NV SMX2 V100-32G                                                |
| Uploaded Date              | 12/29/2020 (month/day/year)                                     |
| MindSpore Version          | 1.0.0                                                           |
| Dataset                    | LibriSpeech                                                     |
| Training Parameters        | 2p, epoch=70, steps=5144 * epoch, batch_size = 20, lr=3e-4      |
| Optimizer                  | Adam                                                            |
| Loss Function              | CTCLoss                                                         |
| Outputs                    | probability                                                     |
| Loss                       | 0.2-0.7                                                         |
| Speed                      | 2p: 2.139 s/step                                                |
| Total Time                 | 2p: around 1 week                                               |
| Checkpoint Size            | 991M (.ckpt file)                                               |
| Scripts                    | [DeepSpeech script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2) |
### Inference Performance
| Parameters                 | DeepSpeech                                                       |
| -------------------------- | ---------------------------------------------------------------- |
| Resource                   | NV SMX2 V100-32G                                                 |
| Uploaded Date              | 12/29/2020 (month/day/year)                                      |
| MindSpore Version          | 1.0.0                                                            |
| Dataset                    | LibriSpeech                                                      |
| Batch Size                 | 20                                                               |
| Outputs                    | probability                                                      |
| Accuracy (clean)           | 2p: WER 9.902, CER 3.317; 8p: WER 11.593, CER 3.907              |
| Accuracy (noisy)           | 2p: WER 28.693, CER 12.473; 8p: WER 31.397, CER 13.696           |
| Model Size                 | 330M (.mindir file)                                              |
# [ModelZoo Homepage](#contents)
Please check the official [ModelZoo homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -59,8 +59,8 @@ Dataset used: [LibriSpeech](<http://www.openslr.org/12>)
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)

View File

@ -29,7 +29,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
parser = argparse.ArgumentParser(description='DeepSpeech evaluation')
parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN')
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
parser.add_argument('--pretrain_ckpt', type=str,
default='./checkpoint/ckpt_0/DeepSpeech0-70_1287.ckpt', help='Pretrained checkpoint path')
parser.add_argument('--device_target', type=str, default="GPU", choices=("GPU", "CPU"),
help='Device target, support GPU and CPU, Default: GPU')
args = parser.parse_args()
@ -57,7 +58,6 @@ if __name__ == '__main__':
load_param_into_net(model, param_dict)
print('Successfully loading the pre-trained model')
if config.LMConfig.decoder_type == 'greedy':
decoder = MSGreedyDecoder(labels=labels, blank_index=labels.index('_'))
else:

View File

@ -14,4 +14,4 @@
# limitations under the License.
# ============================================================================
mpirun --allow-run-as-root -n 8 --output-filename log_output --merge-stderr-to-stdout \
python ./train.py --is_distributed --device_target 'GPU' > train.log 2>&1 &
python ./train.py --is_distributed --device_target 'GPU' > train_8p.log 2>&1 &

View File

@ -15,5 +15,6 @@
# ============================================================================
DEVICE_ID=$1
PATH_CHECKPOINT=$2
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./eval.py --pretrain_ckpt $PATH_CHECKPOINT \
export CUDA_VISIBLE_DEVICES=$DEVICE_ID
python ./eval.py --pretrain_ckpt $PATH_CHECKPOINT \
--device_target 'GPU' > eval.log 2>&1 &

View File

@ -24,7 +24,7 @@ train_config = ed({
},
"DataConfig": {
"train_manifest": 'data/libri_train_manifest.csv',
"train_manifest": '../../deepspeech.pytorch/data/libri_train_manifest.csv',
# "val_manifest": 'data/libri_val_manifest.csv',
"batch_size": 20,
"labels_path": "labels.json",
@ -77,7 +77,7 @@ eval_config = ed({
"verbose": True,
"DataConfig": {
"test_manifest": 'data/libri_test_clean_manifest.csv',
"test_manifest": '../../deepspeech.pytorch/data/libri_test_clean_manifest.csv',
# "test_manifest": 'data/libri_test_other_manifest.csv',
# "test_manifest": 'data/libri_val_manifest.csv',
"batch_size": 20,

View File

@ -16,19 +16,22 @@
Create train or eval dataset.
"""
import math
import numpy as np
import mindspore.dataset.engine as de
import librosa
import soundfile as sf
TRAIN_INPUT_PAD_LENGTH = 1501
TRAIN_INPUT_PAD_LENGTH = 1250
TRAIN_LABEL_PAD_LENGTH = 350
TEST_INPUT_PAD_LENGTH = 3500
class LoadAudioAndTranscript():
"""
parse audio and transcript
"""
def __init__(self,
audio_conf=None,
normalize=False,
@ -89,12 +92,19 @@ class ASRDataset(LoadAudioAndTranscript):
normalize: Apply standard mean and deviation Normalization to audio tensor
batch_size (int): Dataset batch size (default=32)
"""
def __init__(self, audio_conf=None,
manifest_filepath='',
labels=None,
normalize=False,
batch_size=32,
is_training=True):
# with open(manifest_filepath) as f:
# json_file = json.load(f)
#
# self.root_path = json_file.get('root_path')
# wav_txts = json_file.get('samples')
# ids = [list(x.values()) for x in wav_txts]
with open(manifest_filepath) as f:
ids = f.readlines()
@ -117,6 +127,7 @@ class ASRDataset(LoadAudioAndTranscript):
batch_spect, batch_script, target_indices = [], [], []
input_length = np.zeros(batch_size, np.float32)
for data in batch_idx:
# audio_path, transcript_path = os.path.join(self.root_path, data[0]), os.path.join(self.root_path, data[1])
audio_path, transcript_path = data[0], data[1]
spect = self.parse_audio(audio_path)
transcript = self.parse_transcript(transcript_path)
@ -133,12 +144,20 @@ class ASRDataset(LoadAudioAndTranscript):
targets = np.ones((self.batch_size, TRAIN_LABEL_PAD_LENGTH), dtype=np.int32) * self.blank_id
for k, spect_, scripts_ in zip(range(batch_size), batch_spect, batch_script):
seq_length = np.shape(spect_)[1]
input_length[k] = seq_length
# input_length[k] = seq_length
script_length = len(scripts_)
targets[k, :script_length] = scripts_
for m in range(350):
target_indices.append([k, m])
inputs[k, 0, :, 0:seq_length] = spect_
if seq_length <= TRAIN_INPUT_PAD_LENGTH:
input_length[k] = seq_length
inputs[k, 0, :, 0:seq_length] = spect_[:, :seq_length]
else:
maxstart = seq_length - TRAIN_INPUT_PAD_LENGTH
start = np.random.randint(maxstart)
input_length[k] = TRAIN_INPUT_PAD_LENGTH
inputs[k, 0, :, 0:TRAIN_INPUT_PAD_LENGTH] = spect_[:, start:start + TRAIN_INPUT_PAD_LENGTH]
targets = np.reshape(targets, (-1,))
else:
inputs = np.zeros((batch_size, 1, freq_size, TEST_INPUT_PAD_LENGTH), dtype=np.float32)
@ -156,10 +175,12 @@ class ASRDataset(LoadAudioAndTranscript):
def __len__(self):
return self.size
class DistributedSampler():
"""
function to distribute and shuffle sample
"""
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
self.dataset = dataset
self.rank = rank

View File

@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval CallBack of Deepspeech2"""
import json
import os
import logging
import numpy as np
from mindspore import save_checkpoint, load_checkpoint
from mindspore.train.callback import Callback
from src.config import eval_config
from src.dataset import create_dataset
from src.deepspeech2 import PredictWithSoftmax, DeepSpeechModel
from src.greedydecoder import MSGreedyDecoder
class SaveCallback(Callback):
    """
    EvalCallback body
    """

    def __init__(self, path):
        super(SaveCallback, self).__init__()
        self.logger = logging.getLogger(__name__)
        self.init_logger()
        self.interval = 5
        self.eval_start_epoch = 30
        self.config = eval_config
        with open(self.config.DataConfig.labels_path) as label_file:
            self.labels = json.load(label_file)
        self.model = PredictWithSoftmax(DeepSpeechModel(batch_size=self.config.DataConfig.batch_size,
                                                        rnn_hidden_size=self.config.ModelConfig.hidden_size,
                                                        nb_layers=self.config.ModelConfig.hidden_layers,
                                                        labels=self.labels,
                                                        rnn_type=self.config.ModelConfig.rnn_type,
                                                        audio_conf=self.config.DataConfig.SpectConfig,
                                                        bidirectional=True))
        self.ds_eval = create_dataset(audio_conf=self.config.DataConfig.SpectConfig,
                                      manifest_filepath=self.config.DataConfig.test_manifest,
                                      labels=self.labels, normalize=True, train_mode=False,
                                      batch_size=self.config.DataConfig.batch_size, rank=0, group_size=1)
        self.wer = float('inf')
        self.cer = float('inf')
        if self.config.LMConfig.decoder_type == 'greedy':
            self.decoder = MSGreedyDecoder(
                labels=self.labels, blank_index=self.labels.index('_'))
        else:
            raise NotImplementedError("Only greedy decoder is supported now")
        self.target_decoder = MSGreedyDecoder(
            self.labels, blank_index=self.labels.index('_'))
        self.path = path

    def epoch_end(self, run_context):
        """
        select ckpt after some epoch
        """
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            message = '------------Epoch {} :start eval------------'.format(
                cur_epoch)
            self.logger.info(message)
            if not os.path.exists(self.path):
                os.makedirs(self.path)
            filename = os.path.join(
                self.path, 'Deepspeech2' + '_' + str(cur_epoch) + '.ckpt')
            save_checkpoint(save_obj=cb_params.train_network,
                            ckpt_file_name=filename)
            message = '------------Epoch {} :training ckpt saved------------'.format(
                cur_epoch)
            self.logger.info(message)
            load_checkpoint(ckpt_file_name=filename, net=self.model)
            message = '------------Epoch {} :training ckpt loaded------------'.format(
                cur_epoch)
            self.logger.info(message)
            total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
            output_data = []
            for data in self.ds_eval.create_dict_iterator():
                inputs, input_length, target_indices, targets = data['inputs'], data['input_length'], data[
                    'target_indices'], data['label_values']
                split_targets = []
                start, count, last_id = 0, 0, 0
                target_indices, targets = target_indices.asnumpy(), targets.asnumpy()
                for i in range(np.shape(targets)[0]):
                    if target_indices[i, 0] == last_id:
                        count += 1
                    else:
                        split_targets.append(list(targets[start:count]))
                        last_id += 1
                        start = count
                        count += 1
                split_targets.append(list(targets[start:]))
                out, output_sizes = self.model(inputs, input_length)
                decoded_output, _ = self.decoder.decode(out, output_sizes)
                target_strings = self.target_decoder.convert_to_strings(
                    split_targets)
                if self.config.save_output is not None:
                    output_data.append(
                        (out.asnumpy(), output_sizes.asnumpy(), target_strings))
                for doutput, toutput in zip(decoded_output, target_strings):
                    transcript, reference = doutput[0], toutput[0]
                    wer_inst = self.decoder.wer(transcript, reference)
                    cer_inst = self.decoder.cer(transcript, reference)
                    total_wer += wer_inst
                    total_cer += cer_inst
                    num_tokens += len(reference.split())
                    num_chars += len(reference.replace(' ', ''))
                    if self.config.verbose:
                        print("Ref:", reference.lower())
                        print("Hyp:", transcript.lower())
                        print("WER:", float(wer_inst) / len(reference.split()),
                              "CER:", float(cer_inst) / len(reference.replace(' ', '')), "\n")
            wer = float(total_wer) / num_tokens
            cer = float(total_cer) / num_chars
            message = "----------Epoch {} :wer is {}------------".format(
                cur_epoch, wer)
            self.logger.info(message)
            message = "----------Epoch {} :cer is {}------------".format(
                cur_epoch, cer)
            self.logger.info(message)
            if wer < self.wer and cer < self.cer:
                self.wer = wer
                self.cer = cer
                file_name = os.path.join(self.path,
                                         'Deepspeech2' + '_' + str(cur_epoch) + '_' + str(self.wer) + '_' + str(
                                             self.cer) + ".ckpt")
                save_checkpoint(save_obj=cb_params.train_network,
                                ckpt_file_name=file_name)
                message = "Save the minimum wer and cer checkpoint,now Epoch {} : and ,the wer is {}, the cer is \
                {}".format(cur_epoch, self.wer, self.cer)
                self.logger.info(message)

    def init_logger(self):
        self.logger.setLevel(level=logging.INFO)
        handler = logging.FileHandler('eval_callback.log')
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)

View File

@ -14,23 +14,24 @@
# ============================================================================
"""train_criteo."""
import os
import json
import argparse
import json
import os
from mindspore import context, Tensor, ParameterTuple
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Adam
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore.nn import TrainOneStepCell
from mindspore.train import Model
from src.deepspeech2 import DeepSpeechModel, NetWithLossClass
from src.lr_generator import get_lr
from src.config import train_config
from src.dataset import create_dataset
from src.deepspeech2 import DeepSpeechModel, NetWithLossClass
from src.eval_callback import SaveCallback
from src.lr_generator import get_lr
parser = argparse.ArgumentParser(description='DeepSpeech2 training')
parser.add_argument('--pre_trained_model_path', type=str, default='', help='Pretrained checkpoint path')
@ -41,6 +42,7 @@ parser.add_argument('--device_target', type=str, default="GPU", choices=("GPU",
args = parser.parse_args()
if __name__ == '__main__':
rank_id = 0
group_size = 1
config = train_config
@ -94,12 +96,18 @@ if __name__ == '__main__':
callback_list = [TimeMonitor(steps_size), LossMonitor()]
if args.is_distributed:
config.CheckpointConfig.ckpt_file_name_prefix = config.CheckpointConfig.ckpt_file_name_prefix + str(get_rank())
config.CheckpointConfig.ckpt_path = os.path.join(config.CheckpointConfig.ckpt_path,
'ckpt_' + str(get_rank()) + '/')
config_ck = CheckpointConfig(save_checkpoint_steps=1,
keep_checkpoint_max=config.CheckpointConfig.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=config.CheckpointConfig.ckpt_file_name_prefix,
directory=config.CheckpointConfig.ckpt_path, config=config_ck)
callback_list.append(ckpt_cb)
if rank_id == 0:
callback_update = SaveCallback(config.CheckpointConfig.ckpt_path)
callback_list += [callback_update]
else:
config_ck = CheckpointConfig(save_checkpoint_steps=5,
keep_checkpoint_max=config.CheckpointConfig.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=config.CheckpointConfig.ckpt_file_name_prefix,
directory=config.CheckpointConfig.ckpt_path, config=config_ck)
callback_list.append(ckpt_cb)
print(callback_list)
model.train(config.TrainingConfig.epochs, ds_train, callbacks=callback_list, dataset_sink_mode=data_sink)