commit 1db7bfde62

@@ -0,0 +1,293 @@

# Contents

[View English](./README.md)

<!-- TOC -->

- [Contents](#contents)
- [DeepSpeech2 Description](#deepspeech2-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Model Parameters](#model-parameters)
- [Training and Evaluation Process](#training-and-evaluation-process)
- [Export](#export)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [ModelZoo Homepage](#modelzoo-homepage)

# [DeepSpeech2 Description](#contents)

DeepSpeech2 is a speech recognition model trained with CTC loss. It replaces the entire hand-engineered pipeline with a neural network and can handle a wide variety of speech, including noisy environments, accents, and different languages.

[Paper](https://arxiv.org/pdf/1512.02595v1.pdf): Amodei, Dario, et al. Deep Speech 2: End-to-End Speech Recognition in English and Mandarin.

# [Model Architecture](#contents)

The model consists of:

- Two convolutional layers:
    - 32 channels, kernel size 41 x 11, stride 2 x 2
    - 32 channels, kernel size 41 x 11, stride 2 x 1
- Five bidirectional LSTM layers (hidden size 1024)
- A projection layer [size = number of characters plus 1 for the CTC blank symbol, i.e. 29]

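To make this structure concrete, below is a minimal sketch of these layers using MindSpore's `nn` API. It is not the repo's `src/DeepSpeech.py`, which adds batch normalization, sequence-length masking, and CTC plumbing; the RNN input size of 32 * 41 features is an assumption derived from 161 spectrogram bins (a 20 ms window at 16 kHz, see the data parameters below) being halved twice by the stride-2 frequency dimension of the convolutions.

```python
# A minimal structural sketch of the layers above, assuming MindSpore's nn API.
# Not the repo's src/DeepSpeech.py: masking and activation details are simplified.
import mindspore.nn as nn


class DeepSpeech2Sketch(nn.Cell):
    def __init__(self, num_classes=29, hidden_size=1024, num_layers=5):
        super(DeepSpeech2Sketch, self).__init__()
        # Two 2-D convolutions over the (frequency, time) spectrogram.
        self.conv = nn.SequentialCell([
            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2),
                      pad_mode='pad', padding=(20, 20, 5, 5)),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=(41, 11), stride=(2, 1),
                      pad_mode='pad', padding=(20, 20, 5, 5)),
            nn.ReLU(),
        ])
        # Five bidirectional LSTM layers of size 1024. 161 frequency bins are
        # reduced to 41 by the two stride-2 frequency convolutions: 32 * 41 = 1312.
        self.rnn = nn.LSTM(input_size=32 * 41, hidden_size=hidden_size,
                           num_layers=num_layers, batch_first=True,
                           bidirectional=True)
        # Projection onto the 29 output tokens (28 characters + 1 CTC blank).
        self.fc = nn.Dense(hidden_size * 2, num_classes)

    def construct(self, x):
        # x: (batch, 1, freq, time)
        y = self.conv(x)                               # (batch, 32, 41, time')
        b, c, f, t = y.shape
        y = y.reshape(b, c * f, t).transpose(0, 2, 1)  # (batch, time', 1312)
        out, _ = self.rnn(y)                           # (batch, time', 2 * hidden)
        b2, t2, h = out.shape
        # Per-frame token logits, projected timestep by timestep.
        return self.fc(out.reshape(b2 * t2, h)).reshape(b2, t2, -1)
```
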
# [Dataset](#contents)

The scripts can be run with the dataset mentioned in the paper or with datasets widely used in the relevant domain/network architecture. The section below describes how to run the scripts with the dataset listed here.

Dataset used: [LibriSpeech](<http://www.openslr.org/12>)

- Training sets:
    - train-clean-100.tar.gz [6.3G] (100 hours of clean speech)
    - train-clean-360.tar.gz [23G] (360 hours of clean speech)
    - train-other-500.tar.gz [30G] (500 hours of noisy speech)
- Validation sets:
    - dev-clean.tar.gz [337M] (clean)
    - dev-other.tar.gz [314M] (noisy)
- Test sets:
    - test-clean.tar.gz [346M] (test set, clean)
    - test-other.tar.gz [328M] (test set, noisy)
- Data format: wav and txt files
    - Note: the data must be processed with librispeech.py

# [Environment Requirements](#contents)

- Hardware (GPU)
    - GPU processor
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```path
.
├── audio
    ├── deepspeech2
        ├── scripts
        │   ├── run_distribute_train_gpu.sh   // 8-GPU distributed training script
        │   ├── run_eval_cpu.sh               // CPU evaluation script
        │   ├── run_eval_gpu.sh               // GPU evaluation script
        │   ├── run_standalone_train_cpu.sh   // CPU standalone training script
        │   └── run_standalone_train_gpu.sh   // single-GPU standalone training script
        ├── train.py                          // training script
        ├── eval.py                           // evaluation script
        ├── export.py                         // converts a MindSpore checkpoint into a MINDIR model
        ├── labels.json                       // the characters the model can map to
        ├── README.md                         // DeepSpeech2 description
        ├── deepspeech_pytorch
        │   ├── decoder.py                    // decoder from third-party code (MIT License)
        ├── src
        │   ├── __init__.py
        │   ├── DeepSpeech.py                 // DeepSpeech2 network architecture
        │   ├── dataset.py                    // data processing
        │   ├── config.py                     // DeepSpeech configuration
        │   ├── lr_generator.py               // learning-rate generation
        │   ├── greedydecoder.py              // greedy decoder adapted from MindSpore code
        │   └── callback.py                   // callbacks to monitor training
```

## [Model Parameters](#contents)

Parameters for both training and evaluation can be set in `config.py`.

```text
Training-related parameters:

epochs          number of training epochs, default is 70
```

```text
Data-processing parameters:

train_manifest          path of the data file used for training, default is 'data/libri_train_manifest.csv'
val_manifest            path of the data file used for testing, default is 'data/libri_val_manifest.csv'
batch_size              batch size, default is 8
labels_path             path of the token JSON for model output, default is "./labels.json"
sample_rate             sample rate of the data features, default is 16000
window_size             window size (seconds) for spectrogram generation, default is 0.02
window_stride           window stride (seconds) for spectrogram generation, default is 0.01
window                  window type for spectrogram generation, default is 'hamming'
speed_volume_perturb    use random speed and gain perturbation, default is False, not used in the current model
spec_augment            use simple spectral augmentation on Mel spectrograms, default is False, not used in the current model
noise_dir               directory to inject noise from; default is '' (no noise injection), not used in the current model
noise_prob              probability of injecting noise into each sample, default is 0.4, not used in the current model
noise_min               minimum noise level for a sample (1.0 means all noise and none of the original signal), default is 0.0, not used in the current model
noise_max               maximum noise level for a sample (the maximum is 1.0), default is 0.5, not used in the current model
```

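For intuition, here is a sketch of how the spectrogram parameters above are typically consumed. The actual feature extraction lives in `src/dataset.py`; normalization and padding are omitted, and the function name `compute_spectrogram` is illustrative.

```python
# How sample_rate, window_size, window_stride and window map onto an STFT.
# A sketch only; src/dataset.py adds normalization and padding on top of this.
import numpy as np
import librosa
import soundfile as sf


def compute_spectrogram(wav_path, sample_rate=16000, window_size=0.02,
                        window_stride=0.01, window='hamming'):
    audio, _ = sf.read(wav_path, dtype='float32')
    n_fft = int(sample_rate * window_size)         # 320 samples = 20 ms
    hop_length = int(sample_rate * window_stride)  # 160 samples = 10 ms
    stft = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length,
                        win_length=n_fft, window=window)
    return np.log1p(np.abs(stft))  # log magnitude: (161 bins, num_frames)
```
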
```text
Model-related parameters:

rnn_type            RNN type used in the model, default is 'LSTM'; only LSTM is supported at the moment
hidden_size         hidden size of the RNN layers, default is 1024
hidden_layers       number of RNN layers, default is 5
lookahead_context   lookahead context, default is 20, not used in the current model
```

```text
Optimizer-related parameters:

learning_rate     initial learning rate, default is 3e-4
learning_anneal   annealing factor applied to the learning rate after each epoch, default is 1.1
weight_decay      weight decay, default is 1e-5
momentum          momentum, default is 0.9
eps               Adam eps, default is 1e-8
betas             Adam betas, default is (0.9, 0.999)
loss_scale        loss scale, default is 1024
```

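The annealing rule reads: divide the learning rate by `learning_anneal` after every epoch. A hedged sketch of the resulting per-step schedule follows; the exact shape of `src/lr_generator.py` may differ, and `steps_per_epoch=5144` is taken from the training performance table below.

```python
# Per-step learning-rate schedule implied by learning_rate and learning_anneal:
# the rate is divided by 1.1 after each epoch. Illustrative sketch only.
def generate_lr(lr_init=3e-4, learning_anneal=1.1, total_epochs=70,
                steps_per_epoch=5144):
    lr_each_step = []
    lr = lr_init
    for _ in range(total_epochs):
        lr_each_step.extend([lr] * steps_per_epoch)  # constant within an epoch
        lr /= learning_anneal                        # anneal at the epoch boundary
    return lr_each_step
```
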
```text
Checkpoint-related parameters:

ckpt_file_name_prefix   prefix of checkpoint file names, default is 'DeepSpeech'
ckpt_path               directory where checkpoint files are saved, default is 'checkpoints'
keep_checkpoint_max     maximum number of checkpoint files to keep (older checkpoints are deleted), default is 10
```

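These settings are wired into MindSpore's checkpoint callbacks; the snippet below mirrors what `train.py` in this commit does, with the default values above filled in.

```python
# How the checkpoint parameters feed MindSpore's callbacks (see train.py below).
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

config_ck = CheckpointConfig(save_checkpoint_steps=5, keep_checkpoint_max=10)
ckpt_cb = ModelCheckpoint(prefix='DeepSpeech', directory='checkpoints',
                          config=config_ck)
# ckpt_cb is then appended to the callback list passed to model.train().
```
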
# [Training and Evaluation Process](#contents)

## Training

```text
Usage: train.py [--use_pretrained USE_PRETRAINED]
                [--pre_trained_model_path PRE_TRAINED_MODEL_PATH]
                [--is_distributed IS_DISTRIBUTED]
                [--bidirectional BIDIRECTIONAL]
                [--device_target DEVICE_TARGET]
Options:
    --pre_trained_model_path    path of a pre-trained model checkpoint, default is ''
    --is_distributed            distributed (multi-device) training, default is False
    --bidirectional             whether to use a bidirectional RNN, default is True; only the bidirectional model is implemented
    --device_target             device to run the code on: "GPU" | "CPU", default is "GPU"
```

## Evaluation

```text
Usage: eval.py [--bidirectional BIDIRECTIONAL]
               [--pretrain_ckpt PRETRAIN_CKPT]
               [--device_target DEVICE_TARGET]

Options:
    --bidirectional    whether to use a bidirectional RNN, default is True; only the bidirectional model is implemented
    --pretrain_ckpt    path of the checkpoint file, default is ''
    --device_target    device to run the code on: "GPU" | "CPU", default is "GPU"
```

Before training, the dataset should be processed. Use the scripts from [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) to process the data; they download the dataset and process it automatically.
After processing, the data directory structure is as follows:

```path
.
├─ LibriSpeech_dataset
│  ├── train
│  │   ├─ wav
│  │   └─ txt
│  ├── val
│  │   ├─ wav
│  │   └─ txt
│  ├── test_clean
│  │   ├─ wav
│  │   └─ txt
│  └── test_other
│      ├─ wav
│      └─ txt
└─ libri_test_clean_manifest.csv, libri_test_other_manifest.csv, libri_train_manifest.csv, libri_val_manifest.csv
```

The four *.csv manifest files hold the absolute paths of the corresponding data. After obtaining them, modify the configuration in src/config.py:
for the training configuration, set train_manifest to the path of `libri_train_manifest.csv`; for the eval configuration, set test_manifest to `libri_test_clean_manifest.csv` or `libri_test_other_manifest.csv`, depending on the dataset to be evaluated.

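Each manifest row pairs the absolute path of a wav file with the path of its transcript, separated by a comma. The sketch below (with illustrative paths and a hypothetical helper name) can serve as a quick sanity check before editing the configuration:

```python
# Manifest rows look like (paths are illustrative):
#   /data/LibriSpeech_dataset/train/wav/xxxx.wav,/data/LibriSpeech_dataset/train/txt/xxxx.txt
def check_manifest(manifest_path):
    with open(manifest_path) as f:
        for line in f:
            wav_path, txt_path = line.strip().split(',')[:2]
            # Every row must point to one audio file and one transcript file.
            assert wav_path.endswith('.wav') and txt_path.endswith('.txt')
```
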
```shell
...
# Training configuration
"DataConfig": {
    "train_manifest": 'path_to_csv/libri_train_manifest.csv'
}

# Evaluation configuration
"DataConfig": {
    "test_manifest": 'path_to_csv/libri_test_clean_manifest.csv'
}
```

Before training, `librosa` and `Levenshtein` need to be installed.
After installing MindSpore from the official website and finishing the dataset processing, training can be started as follows:

```shell
# standalone training on GPU
bash ./scripts/run_standalone_train_gpu.sh [DEVICE_ID]

# standalone training on CPU
bash ./scripts/run_standalone_train_cpu.sh

# distributed training on GPU
bash ./scripts/run_distribute_train_gpu.sh
```

Note for model evaluation: only the greedy decoder is currently supported. Before running the scripts, download the decoder from [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) and place the
deepspeech_pytorch directory inside the deepspeech2 directory; the file layout will then match [Script and Sample Code].

```shell
# evaluation on CPU
bash ./scripts/run_eval_cpu.sh [PATH_CHECKPOINT]

# evaluation on GPU
bash ./scripts/run_eval_gpu.sh [DEVICE_ID] [PATH_CHECKPOINT]
```

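For reference, greedy CTC decoding itself is simple; the repo reuses the decoder from [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) rather than the sketch below, which assumes per-frame logits and the usual `labels.json` character order (blank `_` first):

```python
# Greedy CTC decoding in a nutshell: argmax per frame, collapse repeats,
# drop the blank symbol '_'. A sketch, not the repo's decoder.
import numpy as np


def greedy_ctc_decode(logits, labels="_'ABCDEFGHIJKLMNOPQRSTUVWXYZ "):
    # logits: (time, num_classes) array of per-frame token scores.
    best = np.argmax(logits, axis=-1)
    collapsed = [t for i, t in enumerate(best) if i == 0 or t != best[i - 1]]
    return ''.join(labels[t] for t in collapsed if labels[t] != '_')
```
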
## [Export](#contents)

```bash
python export.py --pre_trained_model_path='ckpt_path'
```

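A hedged sketch of what `export.py` boils down to, assuming the standard MindSpore export flow; the input shapes (161 frequency bins, 3500 padded frames, per `src/dataset.py` in this commit) and variable names are illustrative:

```python
# Illustrative outline of a checkpoint-to-MINDIR export; not the repo's
# verbatim export.py.
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net

# net = DeepSpeechModel(...)  # built with the same config as in train.py
# load_param_into_net(net, load_checkpoint('ckpt_path'))
# inputs = Tensor(np.zeros((1, 1, 161, 3500), np.float32))
# length = Tensor(np.array([3500], np.int32))
# export(net, inputs, length, file_name='deepspeech2', file_format='MINDIR')
```
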
# [Performance](#contents)

## [Training and Evaluation Performance](#contents)

### Training Performance

| Parameters                 | DeepSpeech                                                      |
| -------------------------- | --------------------------------------------------------------- |
| Resource                   | NV SMX2 V100-32G                                                |
| Updated                    | 12/29/2020 (month/day/year)                                     |
| MindSpore Version          | 1.0.0                                                           |
| Dataset                    | LibriSpeech                                                     |
| Training Parameters        | 2p, epoch=70, steps=5144 * epoch, batch_size = 20, lr=3e-4      |
| Optimizer                  | Adam                                                            |
| Loss Function              | CTCLoss                                                         |
| Outputs                    | probability                                                     |
| Loss                       | 0.2-0.7                                                         |
| Speed                      | 2p: 2.139 s/step                                                |
| Total Training Time        | 2p: around 1 week                                               |
| Checkpoint Size            | 991M (.ckpt file)                                               |
| Scripts                    | [DeepSpeech script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2) |

### Inference Performance

| Parameters                 | DeepSpeech                                                       |
| -------------------------- | ---------------------------------------------------------------- |
| Resource                   | NV SMX2 V100-32G                                                 |
| Updated                    | 12/29/2020 (month/day/year)                                      |
| MindSpore Version          | 1.0.0                                                            |
| Dataset                    | LibriSpeech                                                      |
| Batch Size                 | 20                                                               |
| Outputs                    | probability                                                      |
| Accuracy (clean)           | 2p: WER 9.902, CER 3.317; 8p: WER 11.593, CER 3.907              |
| Accuracy (noisy)           | 2p: WER 28.693, CER 12.473; 8p: WER 31.397, CER 13.696           |
| Model Size                 | 330M (.mindir file)                                              |

# [ModelZoo Homepage](#contents)

[ModelZoo Homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)

@@ -59,8 +59,8 @@ Dataset used: [LibriSpeech](<http://www.openslr.org/12>)
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please check the resources below:
-    - [MindSpore tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
-    - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
+    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

 # [Script Description](#contents)

@@ -29,7 +29,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net

 parser = argparse.ArgumentParser(description='DeepSpeech evaluation')
 parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN')
-parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
+parser.add_argument('--pretrain_ckpt', type=str,
+                    default='./checkpoint/ckpt_0/DeepSpeech0-70_1287.ckpt', help='Pretrained checkpoint path')
 parser.add_argument('--device_target', type=str, default="GPU", choices=("GPU", "CPU"),
                     help='Device target, support GPU and CPU, Default: GPU')
 args = parser.parse_args()

@@ -57,7 +58,6 @@ if __name__ == '__main__':
     load_param_into_net(model, param_dict)
     print('Successfully loading the pre-trained model')

-
     if config.LMConfig.decoder_type == 'greedy':
         decoder = MSGreedyDecoder(labels=labels, blank_index=labels.index('_'))
     else:

@@ -14,4 +14,4 @@
 # limitations under the License.
 # ============================================================================
 mpirun --allow-run-as-root -n 8 --output-filename log_output --merge-stderr-to-stdout \
-python ./train.py --is_distributed --device_target 'GPU' > train.log 2>&1 &
+python ./train.py --is_distributed --device_target 'GPU' > train_8p.log 2>&1 &

@@ -15,5 +15,6 @@
 # ============================================================================
 DEVICE_ID=$1
 PATH_CHECKPOINT=$2
-CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./eval.py --pretrain_ckpt $PATH_CHECKPOINT \
+export CUDA_VISIBLE_DEVICES=$DEVICE_ID
+python ./eval.py --pretrain_ckpt $PATH_CHECKPOINT \
     --device_target 'GPU' > eval.log 2>&1 &

@@ -24,7 +24,7 @@ train_config = ed({
     },

     "DataConfig": {
-        "train_manifest": 'data/libri_train_manifest.csv',
+        "train_manifest": '../../deepspeech.pytorch/data/libri_train_manifest.csv',
         # "val_manifest": 'data/libri_val_manifest.csv',
         "batch_size": 20,
         "labels_path": "labels.json",

@@ -77,7 +77,7 @@ eval_config = ed({
     "verbose": True,

     "DataConfig": {
-        "test_manifest": 'data/libri_test_clean_manifest.csv',
+        "test_manifest": '../../deepspeech.pytorch/data/libri_test_clean_manifest.csv',
         # "test_manifest": 'data/libri_test_other_manifest.csv',
         # "test_manifest": 'data/libri_val_manifest.csv',
         "batch_size": 20,

@@ -16,19 +16,22 @@
 Create train or eval dataset.
 """
+import math
+
 import numpy as np
 import mindspore.dataset.engine as de
 import librosa
 import soundfile as sf

-TRAIN_INPUT_PAD_LENGTH = 1501
+TRAIN_INPUT_PAD_LENGTH = 1250
 TRAIN_LABEL_PAD_LENGTH = 350
 TEST_INPUT_PAD_LENGTH = 3500

+
 class LoadAudioAndTranscript():
     """
     parse audio and transcript
     """

     def __init__(self,
                  audio_conf=None,
                  normalize=False,

@@ -89,12 +92,19 @@ class ASRDataset(LoadAudioAndTranscript):
         normalize: Apply standard mean and deviation Normalization to audio tensor
         batch_size (int): Dataset batch size (default=32)
     """

     def __init__(self, audio_conf=None,
                  manifest_filepath='',
                  labels=None,
                  normalize=False,
                  batch_size=32,
                  is_training=True):
+        # with open(manifest_filepath) as f:
+        #     json_file = json.load(f)
+        #
+        # self.root_path = json_file.get('root_path')
+        # wav_txts = json_file.get('samples')
+        # ids = [list(x.values()) for x in wav_txts]
         with open(manifest_filepath) as f:
             ids = f.readlines()
+

@@ -117,6 +127,7 @@ class ASRDataset(LoadAudioAndTranscript):
         batch_spect, batch_script, target_indices = [], [], []
         input_length = np.zeros(batch_size, np.float32)
         for data in batch_idx:
+            # audio_path, transcript_path = os.path.join(self.root_path, data[0]), os.path.join(self.root_path, data[1])
             audio_path, transcript_path = data[0], data[1]
             spect = self.parse_audio(audio_path)
             transcript = self.parse_transcript(transcript_path)

@@ -133,12 +144,20 @@ class ASRDataset(LoadAudioAndTranscript):
             targets = np.ones((self.batch_size, TRAIN_LABEL_PAD_LENGTH), dtype=np.int32) * self.blank_id
             for k, spect_, scripts_ in zip(range(batch_size), batch_spect, batch_script):
                 seq_length = np.shape(spect_)[1]
-                input_length[k] = seq_length
+
+                # input_length[k] = seq_length
                 script_length = len(scripts_)
                 targets[k, :script_length] = scripts_
                 for m in range(350):
                     target_indices.append([k, m])
-                inputs[k, 0, :, 0:seq_length] = spect_
+                if seq_length <= TRAIN_INPUT_PAD_LENGTH:
+                    input_length[k] = seq_length
+                    inputs[k, 0, :, 0:seq_length] = spect_[:, :seq_length]
+                else:
+                    maxstart = seq_length - TRAIN_INPUT_PAD_LENGTH
+                    start = np.random.randint(maxstart)
+                    input_length[k] = TRAIN_INPUT_PAD_LENGTH
+                    inputs[k, 0, :, 0:TRAIN_INPUT_PAD_LENGTH] = spect_[:, start:start + TRAIN_INPUT_PAD_LENGTH]
             targets = np.reshape(targets, (-1,))
         else:
             inputs = np.zeros((batch_size, 1, freq_size, TEST_INPUT_PAD_LENGTH), dtype=np.float32)

@@ -156,10 +175,12 @@ class ASRDataset(LoadAudioAndTranscript):
     def __len__(self):
         return self.size

+
 class DistributedSampler():
     """
     function to distribute and shuffle sample
     """
+
     def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
         self.dataset = dataset
         self.rank = rank

@@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval CallBack of Deepspeech2"""

import json
import os
import logging
import numpy as np
from mindspore import save_checkpoint, load_checkpoint
from mindspore.train.callback import Callback

from src.config import eval_config
from src.dataset import create_dataset
from src.deepspeech2 import PredictWithSoftmax, DeepSpeechModel
from src.greedydecoder import MSGreedyDecoder


class SaveCallback(Callback):
    """
    EvalCallback body
    """

    def __init__(self, path):
        super(SaveCallback, self).__init__()
        self.logger = logging.getLogger(__name__)
        self.init_logger()
        self.interval = 5
        self.eval_start_epoch = 30
        self.config = eval_config
        with open(self.config.DataConfig.labels_path) as label_file:
            self.labels = json.load(label_file)
        self.model = PredictWithSoftmax(DeepSpeechModel(batch_size=self.config.DataConfig.batch_size,
                                                        rnn_hidden_size=self.config.ModelConfig.hidden_size,
                                                        nb_layers=self.config.ModelConfig.hidden_layers,
                                                        labels=self.labels,
                                                        rnn_type=self.config.ModelConfig.rnn_type,
                                                        audio_conf=self.config.DataConfig.SpectConfig,
                                                        bidirectional=True))
        self.ds_eval = create_dataset(audio_conf=self.config.DataConfig.SpectConfig,
                                      manifest_filepath=self.config.DataConfig.test_manifest,
                                      labels=self.labels, normalize=True, train_mode=False,
                                      batch_size=self.config.DataConfig.batch_size, rank=0, group_size=1)
        self.wer = float('inf')
        self.cer = float('inf')
        if self.config.LMConfig.decoder_type == 'greedy':
            self.decoder = MSGreedyDecoder(labels=self.labels, blank_index=self.labels.index('_'))
        else:
            raise NotImplementedError("Only greedy decoder is supported now")
        self.target_decoder = MSGreedyDecoder(self.labels, blank_index=self.labels.index('_'))
        self.path = path

    def epoch_end(self, run_context):
        """
        select ckpt after some epoch
        """
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num

        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            message = '------------Epoch {} :start eval------------'.format(cur_epoch)
            self.logger.info(message)
            if not os.path.exists(self.path):
                os.makedirs(self.path)
            filename = os.path.join(self.path, 'Deepspeech2' + '_' + str(cur_epoch) + '.ckpt')
            save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=filename)
            message = '------------Epoch {} :training ckpt saved------------'.format(cur_epoch)
            self.logger.info(message)
            load_checkpoint(ckpt_file_name=filename, net=self.model)
            message = '------------Epoch {} :training ckpt loaded------------'.format(cur_epoch)
            self.logger.info(message)
            total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
            output_data = []
            for data in self.ds_eval.create_dict_iterator():
                inputs, input_length, target_indices, targets = \
                    data['inputs'], data['input_length'], data['target_indices'], data['label_values']
                split_targets = []
                start, count, last_id = 0, 0, 0
                target_indices, targets = target_indices.asnumpy(), targets.asnumpy()
                for i in range(np.shape(targets)[0]):
                    if target_indices[i, 0] == last_id:
                        count += 1
                    else:
                        split_targets.append(list(targets[start:count]))
                        last_id += 1
                        start = count
                        count += 1
                split_targets.append(list(targets[start:]))
                out, output_sizes = self.model(inputs, input_length)
                decoded_output, _ = self.decoder.decode(out, output_sizes)
                target_strings = self.target_decoder.convert_to_strings(split_targets)

                if self.config.save_output is not None:
                    output_data.append((out.asnumpy(), output_sizes.asnumpy(), target_strings))
                for doutput, toutput in zip(decoded_output, target_strings):
                    transcript, reference = doutput[0], toutput[0]
                    wer_inst = self.decoder.wer(transcript, reference)
                    cer_inst = self.decoder.cer(transcript, reference)
                    total_wer += wer_inst
                    total_cer += cer_inst
                    num_tokens += len(reference.split())
                    num_chars += len(reference.replace(' ', ''))
                    if self.config.verbose:
                        print("Ref:", reference.lower())
                        print("Hyp:", transcript.lower())
                        print("WER:", float(wer_inst) / len(reference.split()),
                              "CER:", float(cer_inst) / len(reference.replace(' ', '')), "\n")
            wer = float(total_wer) / num_tokens
            cer = float(total_cer) / num_chars
            message = "----------Epoch {} :wer is {}------------".format(cur_epoch, wer)
            self.logger.info(message)
            message = "----------Epoch {} :cer is {}------------".format(cur_epoch, cer)
            self.logger.info(message)
            if wer < self.wer and cer < self.cer:
                self.wer = wer
                self.cer = cer
                file_name = os.path.join(self.path, 'Deepspeech2' + '_' + str(cur_epoch) + '_'
                                         + str(self.wer) + '_' + str(self.cer) + ".ckpt")
                save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
                message = "Saved the checkpoint with the minimum wer and cer so far: Epoch {}, wer {}, cer {}".format(
                    cur_epoch, self.wer, self.cer)
                self.logger.info(message)

    def init_logger(self):
        self.logger.setLevel(level=logging.INFO)
        handler = logging.FileHandler('eval_callback.log')
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)

@@ -14,23 +14,24 @@
 # ============================================================================

 """train_criteo."""
-import os
-import json
 import argparse
+import json
+import os

 from mindspore import context, Tensor, ParameterTuple
-from mindspore.context import ParallelMode
 from mindspore.communication.management import init, get_rank, get_group_size
+from mindspore.context import ParallelMode
+from mindspore.nn import TrainOneStepCell
+from mindspore.nn.optim import Adam
+from mindspore.train import Model
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.nn.optim import Adam
-from mindspore.nn import TrainOneStepCell
-from mindspore.train import Model

-from src.deepspeech2 import DeepSpeechModel, NetWithLossClass
-from src.lr_generator import get_lr
 from src.config import train_config
 from src.dataset import create_dataset
+from src.deepspeech2 import DeepSpeechModel, NetWithLossClass
+from src.eval_callback import SaveCallback
+from src.lr_generator import get_lr

 parser = argparse.ArgumentParser(description='DeepSpeech2 training')
 parser.add_argument('--pre_trained_model_path', type=str, default='', help='Pretrained checkpoint path')

@@ -41,6 +42,7 @@ parser.add_argument('--device_target', type=str, default="GPU", choices=("GPU",
 args = parser.parse_args()

 if __name__ == '__main__':
+
     rank_id = 0
     group_size = 1
     config = train_config

@@ -94,12 +96,18 @@ if __name__ == '__main__':
    callback_list = [TimeMonitor(steps_size), LossMonitor()]

    if args.is_distributed:
        config.CheckpointConfig.ckpt_file_name_prefix = config.CheckpointConfig.ckpt_file_name_prefix + str(get_rank())
        config.CheckpointConfig.ckpt_path = os.path.join(config.CheckpointConfig.ckpt_path,
                                                         'ckpt_' + str(get_rank()) + '/')
        config_ck = CheckpointConfig(save_checkpoint_steps=1,
                                     keep_checkpoint_max=config.CheckpointConfig.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix=config.CheckpointConfig.ckpt_file_name_prefix,
                                  directory=config.CheckpointConfig.ckpt_path, config=config_ck)
        callback_list.append(ckpt_cb)
        if rank_id == 0:
            callback_update = SaveCallback(config.CheckpointConfig.ckpt_path)
            callback_list += [callback_update]
    else:
        config_ck = CheckpointConfig(save_checkpoint_steps=5,
                                     keep_checkpoint_max=config.CheckpointConfig.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix=config.CheckpointConfig.ckpt_file_name_prefix,
                                  directory=config.CheckpointConfig.ckpt_path, config=config_ck)

        callback_list.append(ckpt_cb)
    print(callback_list)
    model.train(config.TrainingConfig.epochs, ds_train, callbacks=callback_list, dataset_sink_mode=data_sink)