deepspeech
model_zoo/README.md
@@ -76,6 +76,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
    - [AutoDis](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/autodis/README.md)
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
    - [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
    - [DeepSpeech2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2/README.md)
- [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
    - [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md)
    - [Molecular_Dynamics](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md)

model_zoo/research/audio/deepspeech2/README.md
@@ -0,0 +1,262 @@
# Contents

- [DeepSpeech2 Description](#deepspeech2-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training and Eval Process](#training-and-eval-process)
    - [Export MindIR](#export-mindir)
    - [Convert](#convert)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [ModelZoo Homepage](#modelzoo-homepage)

# [DeepSpeech2 Description](#contents)

DeepSpeech2 is a speech recognition model trained with CTC loss. It replaces entire pipelines of hand-engineered components with neural networks and can handle a diverse variety of speech, including noisy environments, accents and different languages. We support training and evaluation on GPU.

[Paper](https://arxiv.org/pdf/1512.02595v1.pdf): Amodei, Dario, et al. Deep Speech 2: End-to-end speech recognition in English and Mandarin.

# [Model Architecture](#contents)

The current reproduced model consists of:

- two convolutional layers:
    - number of channels is 32, kernel size is [41, 11], stride is [2, 2]
    - number of channels is 32, kernel size is [21, 11], stride is [2, 1]
- five bidirectional LSTM layers (size is 1024)
- one projection layer (size is the number of characters plus 1 for the CTC blank symbol, i.e. 29)

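As a quick check of the feature dimensions, the sketch below (plain Python, mirroring the size computation in `src/deepspeech2.py`) traces how a 161-bin spectrogram (16 kHz sample rate, 20 ms window) passes through the two convolutional layers before reaching the LSTM stack; the numbers are illustrative.

```python
import math

sample_rate, window_size = 16000, 0.02                               # values from config.py
freq_bins = int(math.floor((sample_rate * window_size) / 2) + 1)     # 161 spectrogram bins
# conv1: kernel 41, stride 2, padding 20 along the frequency axis
after_conv1 = int(math.floor((freq_bins + 2 * 20 - 41) / 2) + 1)     # 81
# conv2: kernel 21, stride 2, padding 10 along the frequency axis
after_conv2 = int(math.floor((after_conv1 + 2 * 10 - 21) / 2) + 1)   # 41
rnn_input_size = after_conv2 * 32                                    # 32 channels -> 1312 LSTM input features
print(freq_bins, after_conv1, after_conv2, rnn_input_size)           # 161 81 41 1312
```
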
# [Dataset](#contents)

Note that you can run the scripts with the dataset mentioned in the original paper or one widely used in this domain/network architecture. In the following sections, we will introduce how to run the scripts using the dataset below.

Dataset used: [LibriSpeech](<http://www.openslr.org/12>)

- Train Data:
    - train-clean-100.tar.gz [6.3G] (training set of 100 hours "clean" speech)
    - train-clean-360.tar.gz [23G] (training set of 360 hours "clean" speech)
    - train-other-500.tar.gz [30G] (training set of 500 hours "other" speech)
- Val Data:
    - dev-clean.tar.gz [337M] (development set, "clean" speech)
    - dev-other.tar.gz [314M] (development set, "other", more challenging, speech)
- Test Data:
    - test-clean.tar.gz [346M] (test set, "clean" speech)
    - test-other.tar.gz [328M] (test set, "other" speech)
- Data format: wav and txt files
    - Note: data will be processed in librispeech.py

# [Environment Requirements](#contents)

- Hardware (GPU)
    - Prepare hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```path
.
├── audio
    ├── deepspeech2
        ├── train.py                     // training script
        ├── eval.py                      // evaluation script
        ├── export.py                    // convert a MindSpore checkpoint to a MindIR model
        ├── labels.json                  // possible characters to map to
        ├── README.md                    // descriptions about DeepSpeech2
        ├── deepspeech_pytorch
        │   ├── decoder.py               // decoder from third-party code (MIT License)
        ├── src
            ├── __init__.py
            ├── deepspeech2.py           // DeepSpeech2 network definition
            ├── dataset.py               // data loader and data processing entry
            ├── config.py                // DeepSpeech2 configs
            ├── lr_generator.py          // learning rate generator
            ├── greedydecoder.py         // greedy decoder modified for MindSpore
            └── callback.py              // callbacks to monitor the training
```

## [Script Parameters](#contents)

### Training

```text
usage: train.py [--use_pretrained USE_PRETRAINED]
                [--pre_trained_model_path PRE_TRAINED_MODEL_PATH]
                [--is_distributed IS_DISTRIBUTED]
                [--bidirectional BIDIRECTIONAL]

options:
    --pre_trained_model_path    pretrained checkpoint path, default is ''
    --is_distributed            distributed training, default is False
    --bidirectional             whether or not to use bidirectional RNN, default is True. Currently, only the bidirectional model is implemented
```

### Evaluation

```text
usage: eval.py [--bidirectional BIDIRECTIONAL]
               [--pretrain_ckpt PRETRAIN_CKPT]

options:
    --bidirectional    whether to use bidirectional RNN, default is True. Currently, only the bidirectional model is implemented
    --pretrain_ckpt    saved checkpoint path, default is ''
```

### Options and Parameters

Parameters for training and evaluation can be set in file `config.py`.

```text
config for training.
    epochs                   number of training epochs, default is 70
```

```text
config for dataloader.
    train_manifest           train manifest path, default is 'data/libri_train_manifest.csv'
    val_manifest             dev manifest path, default is 'data/libri_val_manifest.csv'
    batch_size               batch size for training, default is 20
    labels_path              tokens json path for model output, default is "./labels.json"
    sample_rate              sample rate for the data/model features, default is 16000
    window_size              window size for spectrogram generation (seconds), default is 0.02
    window_stride            window stride for spectrogram generation (seconds), default is 0.01
    window                   window type for spectrogram generation, default is 'hamming'
    speed_volume_perturb     use random tempo and gain perturbations, default is False, not used in current model
    spec_augment             use simple spectral augmentation on mel spectrograms, default is False, not used in current model
    noise_dir                directory to inject noise into audio. If default, noise injection is not added, default is '', not used in current model
    noise_prob               probability of noise being added per sample, default is 0.4, not used in current model
    noise_min                minimum noise level to sample from (1.0 means all noise, no original signal), default is 0.0, not used in current model
    noise_max                maximum noise level to sample from, maximum 1.0, default is 0.5, not used in current model
```

```text
config for model.
    rnn_type                 type of RNN to use in the model, default is 'LSTM'. Currently, only LSTM is supported
    hidden_size              hidden size of the RNN layers, default is 1024
    hidden_layers            number of RNN layers, default is 5
    lookahead_context        look-ahead context, default is 20, not used in current model
```

```text
config for optimizer.
    learning_rate            initial learning rate, default is 3e-4
    learning_anneal          annealing applied to the learning rate after each epoch, default is 1.1
    weight_decay             weight decay, default is 1e-5
    momentum                 momentum, default is 0.9
    eps                      Adam eps, default is 1e-8
    betas                    Adam betas, default is (0.9, 0.999)
    loss_scale               loss scale, default is 1024
```

```text
config for checkpoint.
    ckpt_file_name_prefix    ckpt file name prefix, default is 'DeepSpeech'
    ckpt_path                path to save ckpt, default is './checkpoint'
    keep_checkpoint_max      max number of checkpoints to save; older checkpoints are deleted, default is 10
```

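The blocks above are `EasyDict` objects defined in `src/config.py`, so they can also be inspected or overridden from Python before a run. A minimal sketch (the override values below are only examples, not recommended settings):

```python
from src.config import train_config, eval_config

# fields are read as attributes, exactly as train.py and eval.py do
print(train_config.TrainingConfig.epochs)       # 70
print(train_config.DataConfig.batch_size)       # 20

# example overrides for a quick debug run (illustrative values)
train_config.TrainingConfig.epochs = 1
eval_config.DataConfig.test_manifest = 'data/libri_val_manifest.csv'
```
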
# [Training and Eval Process](#contents)

Before training, the dataset should be processed. We use the scripts provided by [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) to process the dataset.
The script from [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) will automatically download the dataset and process it. After processing, the
dataset directory structure is as follows:

```path
.
├─ LibriSpeech_dataset
│  ├── train
│  │   ├─ wav
│  │   └─ txt
│  ├── val
│  │   ├─ wav
│  │   └─ txt
│  ├── test_clean
│  │   ├─ wav
│  │   └─ txt
│  └── test_other
│      ├─ wav
│      └─ txt
└─ libri_test_clean_manifest.csv, libri_test_other_manifest.csv, libri_train_manifest.csv, libri_val_manifest.csv
```

The four *.csv files store the absolute paths of the corresponding data and are used during the training and evaluation process.

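Each manifest is a plain CSV in which every row pairs an audio file with its transcript file; `src/dataset.py` simply splits each line on a comma. The paths below are only illustrative:

```text
/path/to/LibriSpeech_dataset/train/wav/1234-5678-0001.wav,/path/to/LibriSpeech_dataset/train/txt/1234-5678-0001.txt
/path/to/LibriSpeech_dataset/train/wav/1234-5678-0002.wav,/path/to/LibriSpeech_dataset/train/txt/1234-5678-0002.txt
```
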
After installing MindSpore via the official website and finishing the dataset processing, you can start training as follows:

```shell

# standalone training
CUDA_VISIBLE_DEVICES='0' python train.py

# distributed training
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' mpirun --allow-run-as-root -n 8 python train.py --is_distributed > log 2>&1 &

```

The following script is used to evaluate the model. Note that only the greedy decoder is supported for now. Before running the script,
you should download the decoder code from [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) and place
the deepspeech_pytorch directory into the deepspeech2 directory. After that, the file layout matches [Script and Sample Code](#script-and-sample-code).

```shell

# eval
CUDA_VISIBLE_DEVICES='0' python eval.py --pretrain_ckpt='saved_model_path'
```

## [Export MindIR](#contents)

```bash
python export.py --pre_trained_model_path='ckpt_path'
```

# [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | DeepSpeech                                                      |
| -------------------------- | --------------------------------------------------------------- |
| Resource                   | NV SMX2 V100-32G                                                |
| Uploaded Date              | 12/29/2020 (month/day/year)                                     |
| MindSpore Version          | 1.0.0                                                           |
| Dataset                    | LibriSpeech                                                     |
| Training Parameters        | 2p, epoch=70, steps=5144 * epoch, batch_size=20, lr=3e-4        |
| Optimizer                  | Adam                                                            |
| Loss Function              | CTCLoss                                                         |
| Outputs                    | probability                                                     |
| Loss                       | 0.2-0.7                                                         |
| Speed                      | 2p: 2.139 s/step                                                |
| Total Training Time        | 2p: around 1 week                                               |
| Checkpoint                 | 991M (.ckpt file)                                               |
| Scripts                    | [DeepSpeech script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech) |

### Inference Performance

| Parameters                 | DeepSpeech                                                       |
| -------------------------- | ---------------------------------------------------------------- |
| Resource                   | NV SMX2 V100-32G                                                 |
| Uploaded Date              | 12/29/2020 (month/day/year)                                      |
| MindSpore Version          | 1.0.0                                                            |
| Dataset                    | LibriSpeech                                                      |
| batch_size                 | 20                                                               |
| Outputs                    | probability                                                      |
| Accuracy (test-clean)      | WER: 9.732, CER: 3.270                                           |
| Accuracy (test-other)      | WER: 28.198, CER: 12.253                                         |
| Model for inference        | 330M (.mindir file)                                              |

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

eval.py
@@ -0,0 +1,112 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================

"""
Eval DeepSpeech2
"""
import argparse
import json
import pickle
import numpy as np
from src.config import eval_config
from src.deepspeech2 import DeepSpeechModel, PredictWithSoftmax
from src.dataset import create_dataset
from src.greedydecoder import MSGreedyDecoder

from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net

context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)

parser = argparse.ArgumentParser(description='DeepSpeech evaluation')
parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN')
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
args = parser.parse_args()

if __name__ == '__main__':
    config = eval_config
    with open(config.DataConfig.labels_path) as label_file:
        labels = json.load(label_file)

    model = PredictWithSoftmax(DeepSpeechModel(batch_size=config.DataConfig.batch_size,
                                               rnn_hidden_size=config.ModelConfig.hidden_size,
                                               nb_layers=config.ModelConfig.hidden_layers,
                                               labels=labels,
                                               rnn_type=config.ModelConfig.rnn_type,
                                               audio_conf=config.DataConfig.SpectConfig,
                                               bidirectional=args.bidirectional))

    ds_eval = create_dataset(audio_conf=config.DataConfig.SpectConfig,
                             manifest_filepath=config.DataConfig.test_manifest,
                             labels=labels, normalize=True, train_mode=False,
                             batch_size=config.DataConfig.batch_size, rank=0, group_size=1)

    param_dict = load_checkpoint(args.pretrain_ckpt)
    load_param_into_net(model, param_dict)
    print('Successfully loaded the pre-trained model')

    if config.LMConfig.decoder_type == 'greedy':
        decoder = MSGreedyDecoder(labels=labels, blank_index=labels.index('_'))
    else:
        raise NotImplementedError("Only greedy decoder is supported now")
    target_decoder = MSGreedyDecoder(labels, blank_index=labels.index('_'))

    model.set_train(False)
    total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
    output_data = []
    for data in ds_eval.create_dict_iterator():
        inputs, input_length, target_indices, targets = data['inputs'], data['input_length'], \
                                                        data['target_indices'], data['label_values']

        split_targets = []
        start, count, last_id = 0, 0, 0
        target_indices, targets = target_indices.asnumpy(), targets.asnumpy()
        # regroup the flattened label values into one transcript per utterance,
        # using the utterance index stored in the first column of target_indices
        for i in range(np.shape(targets)[0]):
            if target_indices[i, 0] == last_id:
                count += 1
            else:
                split_targets.append(list(targets[start:count]))
                last_id += 1
                start = count
                count += 1
        split_targets.append(list(targets[start:]))  # keep the last utterance of the batch
        out, output_sizes = model(inputs, input_length)
        decoded_output, _ = decoder.decode(out, output_sizes)
        target_strings = target_decoder.convert_to_strings(split_targets)

        if config.save_output is not None:
            output_data.append((out.asnumpy(), output_sizes.asnumpy(), target_strings))
        for doutput, toutput in zip(decoded_output, target_strings):
            transcript, reference = doutput[0], toutput[0]
            wer_inst = decoder.wer(transcript, reference)
            cer_inst = decoder.cer(transcript, reference)
            total_wer += wer_inst
            total_cer += cer_inst
            num_tokens += len(reference.split())
            num_chars += len(reference.replace(' ', ''))
            if config.verbose:
                print("Ref:", reference.lower())
                print("Hyp:", transcript.lower())
                print("WER:", float(wer_inst) / len(reference.split()),
                      "CER:", float(cer_inst) / len(reference.replace(' ', '')), "\n")
    wer = float(total_wer) / num_tokens
    cer = float(total_cer) / num_chars

    print('Test Summary \t'
          'Average WER {wer:.3f}\t'
          'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100))

    if config.save_output is not None:
        with open(config.save_output + '.bin', 'wb') as output:
            pickle.dump(output_data, output)

export.py
@@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
export checkpoint file to mindir model
"""
import json
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.deepspeech2 import DeepSpeechModel
from src.config import train_config

parser = argparse.ArgumentParser(description='Export DeepSpeech model to MindIR')
parser.add_argument('--pre_trained_model_path', type=str, default='', help='existing checkpoint path')
args = parser.parse_args()

if __name__ == '__main__':
    config = train_config
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
    with open(config.DataConfig.labels_path) as label_file:
        labels = json.load(label_file)

    deepspeech_net = DeepSpeechModel(batch_size=1,
                                     rnn_hidden_size=config.ModelConfig.hidden_size,
                                     nb_layers=config.ModelConfig.hidden_layers,
                                     labels=labels,
                                     rnn_type=config.ModelConfig.rnn_type,
                                     audio_conf=config.DataConfig.SpectConfig,
                                     bidirectional=True)

    param_dict = load_checkpoint(args.pre_trained_model_path)
    load_param_into_net(deepspeech_net, param_dict)
    print('Successfully loaded the pre-trained model')
    # 3500 is the max length in the evaluation dataset (LibriSpeech). This is consistent with dataset.py.
    # The length is fixed to this value because MindSpore does not support dynamic shape currently.
    input_np = np.random.uniform(0.0, 1.0, size=[1, 1, 161, 3500]).astype(np.float32)
    length = np.array([15], dtype=np.int32)
    export(deepspeech_net, Tensor(input_np), Tensor(length), file_name="deepspeech2.mindir", file_format='MINDIR')

labels.json
@@ -0,0 +1,31 @@
[
"'",
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"J",
"K",
"L",
"M",
"N",
"O",
"P",
"Q",
"R",
"S",
"T",
"U",
"V",
"W",
"X",
"Y",
"Z",
" ",
"_"
]

src/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

src/callback.py
@@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Defined callbacks for DeepSpeech.
"""
import time
from mindspore.train.callback import Callback
from mindspore import Tensor
import numpy as np


class TimeMonitor(Callback):
    """
    Time monitor for calculating the cost of each epoch.

    Args:
        data_size (int): step size of an epoch.
    """

    def __init__(self, data_size):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size

    def epoch_begin(self, run_context):
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / self.data_size
        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        step_mseconds = (time.time() - self.step_time) * 1000
        print(f"step time {step_mseconds}", flush=True)


class Monitor(Callback):
    """
    Monitor loss and time.

    Args:
        lr_init (numpy array): learning rate array used for training.

    Returns:
        None
    """

    def __init__(self, lr_init=None):
        super(Monitor, self).__init__()
        self.lr_init = lr_init
        self.lr_init_len = len(lr_init)

    def epoch_begin(self, run_context):
        self.losses = []
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()

        epoch_seconds = time.time() - self.epoch_time
        per_step_seconds = epoch_seconds / cb_params.batch_num
        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_seconds,
                                                                                      per_step_seconds,
                                                                                      np.mean(self.losses)))

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        """
        Print the current loss, the running mean loss, the step time and the learning rate at each step.

        Args:
            run_context: context of the running process, provided by MindSpore.
        """
        cb_params = run_context.original_args()
        step_seconds = time.time() - self.step_time
        step_loss = cb_params.net_outputs

        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
            step_loss = step_loss[0]
        if isinstance(step_loss, Tensor):
            step_loss = np.mean(step_loss.asnumpy())

        self.losses.append(step_loss)
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num

        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:.9f}]".format(
            cb_params.cur_epoch_num - 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
            np.mean(self.losses), step_seconds, self.lr_init[cb_params.cur_step_num - 1].asnumpy()))

src/config.py
@@ -0,0 +1,113 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

train_config = ed({

    "TrainingConfig": {
        "epochs": 70,
    },

    "DataConfig": {
        "train_manifest": 'data/libri_train_manifest.csv',
        # "val_manifest": 'data/libri_val_manifest.csv',
        "batch_size": 20,
        "labels_path": "labels.json",

        "SpectConfig": {
            "sample_rate": 16000,
            "window_size": 0.02,
            "window_stride": 0.01,
            "window": "hamming"
        },

        "AugmentationConfig": {
            "speed_volume_perturb": False,
            "spec_augment": False,
            "noise_dir": '',
            "noise_prob": 0.4,
            "noise_min": 0.0,
            "noise_max": 0.5,
        }
    },

    "ModelConfig": {
        "rnn_type": "LSTM",
        "hidden_size": 1024,
        "hidden_layers": 5,
        "lookahead_context": 20,
    },

    "OptimConfig": {
        "learning_rate": 3e-4,
        "learning_anneal": 1.1,
        "weight_decay": 1e-5,
        "momentum": 0.9,
        "eps": 1e-8,
        "betas": (0.9, 0.999),
        "loss_scale": 1024,
        "epsilon": 0.00001
    },

    "CheckpointConfig": {
        "ckpt_file_name_prefix": 'DeepSpeech',
        "ckpt_path": './checkpoint',
        "keep_checkpoint_max": 10
    }
})

eval_config = ed({

    "save_output": 'librispeech_val_output',
    "verbose": True,

    "DataConfig": {
        "test_manifest": 'data/libri_test_clean_manifest.csv',
        # "test_manifest": 'data/libri_test_other_manifest.csv',
        # "test_manifest": 'data/libri_val_manifest.csv',
        "batch_size": 20,
        "labels_path": "labels.json",

        "SpectConfig": {
            "sample_rate": 16000,
            "window_size": 0.02,
            "window_stride": 0.01,
            "window": "hanning"
        },
    },

    "ModelConfig": {
        "rnn_type": "LSTM",
        "hidden_size": 1024,
        "hidden_layers": 5,
        "lookahead_context": 20,
    },

    "LMConfig": {
        "decoder_type": "greedy",
        "lm_path": './3-gram.pruned.3e-7.arpa',
        "top_paths": 1,
        "alpha": 1.818182,
        "beta": 0,
        "cutoff_top_n": 40,
        "cutoff_prob": 1.0,
        "beam_width": 1024,
        "lm_workers": 4
    },

})

src/dataset.py
@@ -0,0 +1,215 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Create train or eval dataset.
"""
import math
import numpy as np
import mindspore.dataset.engine as de
import librosa
import soundfile as sf

TRAIN_INPUT_PAD_LENGTH = 1501
TRAIN_LABEL_PAD_LENGTH = 350
TEST_INPUT_PAD_LENGTH = 3500


class LoadAudioAndTranscript():
    """
    parse audio and transcript
    """

    def __init__(self,
                 audio_conf=None,
                 normalize=False,
                 labels=None):
        super(LoadAudioAndTranscript, self).__init__()
        self.window_stride = audio_conf.window_stride
        self.window_size = audio_conf.window_size
        self.sample_rate = audio_conf.sample_rate
        self.window = audio_conf.window
        self.is_normalization = normalize
        self.labels = labels

    def load_audio(self, path):
        """
        load audio
        """
        sound, _ = sf.read(path, dtype='int16')
        sound = sound.astype('float32') / 32767
        if len(sound.shape) > 1:
            if sound.shape[1] == 1:
                sound = sound.squeeze()
            else:
                sound = sound.mean(axis=1)  # multiple channels, average them
        return sound

    def parse_audio(self, audio_path):
        """
        parse audio into a log-magnitude spectrogram
        """
        audio = self.load_audio(audio_path)
        n_fft = int(self.sample_rate * self.window_size)
        win_length = n_fft
        hop_length = int(self.sample_rate * self.window_stride)
        D = librosa.stft(y=audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=self.window)
        mag, _ = librosa.magphase(D)
        mag = np.log1p(mag)
        if self.is_normalization:
            mean = mag.mean()
            std = mag.std()
            mag = (mag - mean) / std
        return mag

    def parse_transcript(self, transcript_path):
        with open(transcript_path, 'r', encoding='utf8') as transcript_file:
            transcript = transcript_file.read().replace('\n', '')
        transcript = list(filter(None, [self.labels.get(x) for x in list(transcript)]))
        return transcript


class ASRDataset(LoadAudioAndTranscript):
    """
    create ASRDataset

    Args:
        audio_conf: Config containing the sample rate, window and the window length/stride in seconds
        manifest_filepath (str): manifest_file path.
        labels (list): List containing all the possible characters to map to
        normalize: Apply standard mean and deviation normalization to audio tensor
        batch_size (int): Dataset batch size (default=32)
    """

    def __init__(self, audio_conf=None,
                 manifest_filepath='',
                 labels=None,
                 normalize=False,
                 batch_size=32,
                 is_training=True):
        with open(manifest_filepath) as f:
            ids = f.readlines()

        ids = [x.strip().split(',') for x in ids]
        self.is_training = is_training
        self.ids = ids
        self.blank_id = int(labels.index('_'))
        self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]
        if len(self.ids) % batch_size != 0:
            # drop the incomplete last bin and use the last batch_size samples instead
            self.bins = self.bins[:-1]
            self.bins.append(ids[-batch_size:])
        self.size = len(self.bins)
        self.batch_size = batch_size
        self.labels_map = {labels[i]: i for i in range(len(labels))}
        super(ASRDataset, self).__init__(audio_conf, normalize, self.labels_map)

    def __getitem__(self, index):
        batch_idx = self.bins[index]
        batch_size = len(batch_idx)
        batch_spect, batch_script, target_indices = [], [], []
        input_length = np.zeros(batch_size, np.int32)
        for data in batch_idx:
            audio_path, transcript_path = data[0], data[1]
            spect = self.parse_audio(audio_path)
            transcript = self.parse_transcript(transcript_path)
            batch_spect.append(spect)
            batch_script.append(transcript)
            freq_size = np.shape(batch_spect[-1])[0]

        if self.is_training:
            # 1501 is the max length in the train dataset (LibriSpeech).
            # The length is fixed to this value because MindSpore does not support dynamic shape currently.
            inputs = np.zeros((batch_size, 1, freq_size, TRAIN_INPUT_PAD_LENGTH), dtype=np.float32)
            # The target length is fixed to this value because MindSpore does not support dynamic shape currently.
            # 350 may be greater than the max length of labels in the train dataset (LibriSpeech).
            targets = np.ones((self.batch_size, TRAIN_LABEL_PAD_LENGTH), dtype=np.int32) * self.blank_id
            for k, spect_, scripts_ in zip(range(batch_size), batch_spect, batch_script):
                seq_length = np.shape(spect_)[1]
                input_length[k] = seq_length
                script_length = len(scripts_)
                targets[k, :script_length] = scripts_
                for m in range(TRAIN_LABEL_PAD_LENGTH):
                    target_indices.append([k, m])
                inputs[k, 0, :, 0:seq_length] = spect_
            targets = np.reshape(targets, (-1,))
        else:
            inputs = np.zeros((batch_size, 1, freq_size, TEST_INPUT_PAD_LENGTH), dtype=np.float32)
            targets = []
            for k, spect_, scripts_ in zip(range(batch_size), batch_spect, batch_script):
                seq_length = np.shape(spect_)[1]
                input_length[k] = seq_length
                targets.extend(scripts_)
                for m in range(len(scripts_)):
                    target_indices.append([k, m])
                inputs[k, 0, :, 0:seq_length] = spect_

        return inputs, input_length, np.array(target_indices, dtype=np.int64), np.array(targets, dtype=np.int32)

    def __len__(self):
        return self.size


class DistributedSampler():
    """
    function to distribute and shuffle samples
    """

    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_len = len(self.dataset)
        self.num_samplers = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
        self.total_size = self.num_samplers * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xffffffff
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_len).tolist()
        else:
            indices = list(range(self.dataset_len))

        indices += indices[:(self.total_size - len(indices))]
        indices = indices[self.rank::self.group_size]
        return iter(indices)

    def __len__(self):
        return self.num_samplers


def create_dataset(audio_conf, manifest_filepath, labels, normalize, batch_size, train_mode=True,
                   rank=None, group_size=None):
    """
    create train or eval dataset

    Args:
        audio_conf: Config containing the sample rate, window and the window length/stride in seconds
        manifest_filepath (str): manifest_file path.
        labels (list): list containing all the possible characters to map to
        normalize: Apply standard mean and deviation normalization to audio tensor
        train_mode (bool): Whether dataset is used for train or eval (default=True).
        batch_size (int): Dataset batch size
        rank (int): The shard ID within num_shards (default=None).
        group_size (int): Number of shards that the dataset should be divided into (default=None).

    Returns:
        Dataset.
    """

    dataset = ASRDataset(audio_conf=audio_conf, manifest_filepath=manifest_filepath, labels=labels,
                         normalize=normalize, batch_size=batch_size, is_training=train_mode)

    sampler = DistributedSampler(dataset, rank, group_size, shuffle=True)

    ds = de.GeneratorDataset(dataset, ["inputs", "input_length", "target_indices", "label_values"], sampler=sampler)
    ds = ds.repeat(1)
    return ds

src/deepspeech2.py
@@ -0,0 +1,300 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
DeepSpeech2 model
"""

import math
import numpy as np
from mindspore.ops import operations as P
from mindspore import nn, Tensor, ParameterTuple, Parameter
from mindspore.common.initializer import initializer


class SequenceWise(nn.Cell):
    """
    SequenceWise FC Layers.
    """

    def __init__(self, module):
        super(SequenceWise, self).__init__()
        self.module = module
        self.reshape_op = P.Reshape()
        self.shape_op = P.Shape()
        self._initialize_weights()

    def construct(self, x):
        sizes = self.shape_op(x)
        t, n = sizes[0], sizes[1]
        x = self.reshape_op(x, (t * n, -1))
        x = self.module(x)
        x = self.reshape_op(x, (t, n, -1))
        return x

    def _initialize_weights(self):
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(
                    np.random.uniform(-1. / m.in_channels, 1. / m.in_channels, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(Tensor(
                        np.random.uniform(-1. / m.in_channels, 1. / m.in_channels, m.bias.data.shape).astype(
                            "float32")))


class MaskConv(nn.Cell):
    """
    MaskConv architecture. Masking is actually not implemented in this part because some operations in MindSpore
    are not supported. lengths is kept for future use.
    """

    def __init__(self):
        super(MaskConv, self).__init__()
        self.zeros = P.ZerosLike()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), pad_mode='pad', padding=(20, 20, 5, 5))
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), pad_mode='pad', padding=(10, 10, 5, 5))
        self.bn2 = nn.BatchNorm2d(num_features=32)
        self.tanh = nn.Tanh()
        self._initialize_weights()
        self.module_list = nn.CellList([self.conv1, self.bn1, self.tanh, self.conv2, self.bn2, self.tanh])

    def construct(self, x, lengths):
        for module in self.module_list:
            x = module(x)
        return x

    def _initialize_weights(self):
        """
        parameter initialization
        """
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
                                                          m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_data(
                    Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_data(
                    Tensor(np.zeros(m.beta.data.shape, dtype="float32")))


class BatchRNN(nn.Cell):
    """
    BatchRNN architecture.

    Args:
        batch_size (int): sample number per step in training
        input_size (int): dimension of the input tensor
        hidden_size (int): rnn hidden size
        num_layers (int): number of rnn layers
        bidirectional (bool): use bidirectional rnn (default=True). Currently, only bidirectional rnn is implemented.
        batch_norm (bool): whether to use batchnorm in RNN. Currently, GPU does not support batch_norm1D (default=False).
        rnn_type (str): rnn type to use (default='LSTM'). Currently, only LSTM is supported.
    """

    def __init__(self, batch_size, input_size, hidden_size, num_layers, bidirectional=False, batch_norm=False,
                 rnn_type='LSTM'):
        super(BatchRNN, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.has_bias = True
        self.is_batch_norm = batch_norm
        self.num_directions = 2 if bidirectional else 1
        self.reshape_op = P.Reshape()
        self.shape_op = P.Shape()
        self.sum_op = P.ReduceSum()

        input_size_list = [input_size]
        for i in range(num_layers - 1):
            input_size_list.append(hidden_size)
        layers = []

        for i in range(num_layers):
            layers.append(
                nn.LSTMCell(input_size=input_size_list[i], hidden_size=hidden_size, bidirectional=bidirectional,
                            has_bias=self.has_bias))

        weights = []
        for i in range(num_layers):
            # flattened LSTM weight size: input-hidden and hidden-hidden matrices (x4 gates), plus biases
            weight_size = (input_size_list[i] + hidden_size) * hidden_size * self.num_directions * 4
            if self.has_bias:
                bias_size = self.num_directions * hidden_size * 4 * 2
                weight_size = weight_size + bias_size

            stdv = 1 / math.sqrt(hidden_size)
            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)

            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))

        self.h, self.c = self.stack_lstm_default_state(batch_size, hidden_size, num_layers=num_layers,
                                                       bidirectional=bidirectional)
        self.lstms = layers
        self.weight = ParameterTuple(tuple(weights))

        if batch_norm:
            batch_norm_layer = []
            for i in range(num_layers - 1):
                batch_norm_layer.append(nn.BatchNorm1d(hidden_size))
            self.batch_norm_list = batch_norm_layer

    def stack_lstm_default_state(self, batch_size, hidden_size, num_layers, bidirectional):
        """init default input."""
        num_directions = 2 if bidirectional else 1
        # use two separate lists so that the h and c states do not alias each other
        h_list, c_list = [], []
        for _ in range(num_layers):
            h_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
            c_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
        h, c = tuple(h_list), tuple(c_list)
        return h, c

    def construct(self, x):
        for i in range(self.num_layers):
            if self.is_batch_norm and i > 0:
                x = self.batch_norm_list[i - 1](x)
            x, _, _, _, _ = self.lstms[i](x, self.h[i], self.c[i], self.weight[i])
            if self.bidirectional:
                # sum the forward and backward directions
                size = self.shape_op(x)
                x = self.reshape_op(x, (size[0], size[1], 2, -1))
                x = self.sum_op(x, 2)
        return x


class DeepSpeechModel(nn.Cell):
    """
    DeepSpeech2 architecture.

    Args:
        batch_size (int): sample number per step in training (default=128)
        rnn_type (str): rnn type to use (default="LSTM")
        labels (list): list containing all the possible characters to map to
        rnn_hidden_size (int): rnn hidden size
        nb_layers (int): number of rnn layers
        audio_conf: Config containing the sample rate, window and the window length/stride in seconds
        bidirectional (bool): use bidirectional rnn (default=True)
    """

    def __init__(self, batch_size, labels, rnn_hidden_size, nb_layers, audio_conf, rnn_type='LSTM', bidirectional=True):
        super(DeepSpeechModel, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = rnn_hidden_size
        self.hidden_layers = nb_layers
        self.rnn_type = rnn_type
        self.audio_conf = audio_conf
        self.labels = labels
        self.bidirectional = bidirectional
        self.reshape_op = P.Reshape()
        self.shape_op = P.Shape()
        self.transpose_op = P.Transpose()
        self.add = P.TensorAdd()
        self.div = P.Div()

        sample_rate = self.audio_conf.sample_rate
        window_size = self.audio_conf.window_size
        num_classes = len(self.labels)

        self.conv = MaskConv()
        # This is to calculate the effective padding and stride of each conv layer, used by get_seq_lens
        self.pre, self.stride = self.get_conv_num()

        # Based on the above convolutions and the spectrogram size, using the conv output formula (W - F + 2P) / S + 1:
        # 161 frequency bins -> 81 after conv1 -> 41 after conv2, and 41 * 32 channels = 1312 RNN input features
        rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
        rnn_input_size *= 32

        self.RNN = BatchRNN(batch_size=self.batch_size, input_size=rnn_input_size, num_layers=nb_layers,
                            hidden_size=rnn_hidden_size, bidirectional=bidirectional, batch_norm=False,
                            rnn_type=self.rnn_type)
        fully_connected = nn.Dense(rnn_hidden_size, num_classes, has_bias=False)
        self.fc = SequenceWise(fully_connected)

    def construct(self, x, lengths):
        """
        lengths is actually not used in this part since MindSpore does not support dynamic shape.
        """
        output_lengths = self.get_seq_lens(lengths)
        x = self.conv(x, lengths)
        sizes = self.shape_op(x)
        x = self.reshape_op(x, (sizes[0], sizes[1] * sizes[2], sizes[3]))
        x = self.transpose_op(x, (2, 0, 1))
        x = self.RNN(x)
        x = self.fc(x)
        return x, output_lengths

    def get_seq_lens(self, seq_len):
        """
        Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable
        containing the sizes of the sequences that will be output by the network.
        """
        for i in range(len(self.stride)):
            seq_len = self.add(self.div(self.add(seq_len, self.pre[i]), self.stride[i]), 1)
        return seq_len

    def get_conv_num(self):
        """collect the equivalent padding and stride of each conv layer along the time axis."""
        p, s = [], []
        for _, cell in self.conv.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                kernel_size = cell.kernel_size
                padding_1 = int((kernel_size[1] - 1) / 2)
                temp = 2 * padding_1 - cell.dilation[1] * (cell.kernel_size[1] - 1) - 1
                p.append(temp)
                s.append(cell.stride[1])
        return p, s


class NetWithLossClass(nn.Cell):
    """
    NetWithLossClass definition
    """

    def __init__(self, network):
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        self.loss = P.CTCLoss(ctc_merge_repeated=True)
        self.network = network
        self.ReduceMean_false = P.ReduceMean(keep_dims=False)
        self.squeeze_op = P.Squeeze(0)

    def construct(self, inputs, input_length, target_indices, label_values):
        predict, output_length = self.network(inputs, input_length)
        loss = self.loss(predict, target_indices, label_values, output_length)
        return self.ReduceMean_false(loss[0])


class PredictWithSoftmax(nn.Cell):
    """
    PredictWithSoftmax
    """

    def __init__(self, network):
        super(PredictWithSoftmax, self).__init__(auto_prefix=False)
        self.network = network
        self.inference_softmax = P.Softmax(axis=-1)
        self.transpose_op = P.Transpose()

    def construct(self, inputs, input_length):
        x, output_sizes = self.network(inputs, input_length)
        x = self.inference_softmax(x)
        x = self.transpose_op(x, (1, 0, 2))
        return x, output_sizes

src/greedydecoder.py
@@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
modify GreedyDecoder to adapt to MindSpore
"""

import numpy as np
from deepspeech_pytorch.decoder import GreedyDecoder


class MSGreedyDecoder(GreedyDecoder):
    """
    GreedyDecoder used for MindSpore
    """

    def process_string(self, sequence, size, remove_repetitions=False):
        """
        process string
        """
        string = ''
        offsets = []
        for i in range(size):
            char = self.int_to_char[sequence[i].item()]
            if char != self.int_to_char[self.blank_index]:
                if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i - 1].item()]:
                    pass
                elif char == self.labels[self.space_index]:
                    string += ' '
                    offsets.append(i)
                else:
                    string = string + char
                    offsets.append(i)
        return string, offsets

    def decode(self, probs, sizes=None):
        probs = probs.asnumpy()
        sizes = sizes.asnumpy()

        max_probs = np.argmax(probs, axis=-1)
        strings, offsets = self.convert_to_strings(max_probs, sizes, remove_repetitions=True, return_offsets=True)
        return strings, offsets

src/lr_generator.py
@@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""learning rate generator"""
import numpy as np


def get_lr(lr_init, total_epochs, steps_per_epoch):
    """
    generate learning rate array

    Args:
        lr_init (float): initial learning rate
        total_epochs (int): total number of training epochs
        steps_per_epoch (int): steps per epoch

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    half_epoch = total_epochs // 2
    for i in range(total_epochs * steps_per_epoch):
        # keep the initial learning rate for the first half of the epochs, then anneal it by a
        # factor of 1.1 per epoch (the learning_anneal value described in config.py)
        cur_epoch = i // steps_per_epoch
        if cur_epoch < half_epoch:
            lr_each_step.append(lr_init)
        else:
            lr_each_step.append(lr_init / (1.1 ** (cur_epoch - half_epoch)))
    learning_rate = np.array(lr_each_step).astype(np.float32)
    return learning_rate

train.py
@@ -0,0 +1,103 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Train DeepSpeech2."""
import os
import json
import argparse

from mindspore import context, Tensor, ParameterTuple
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore.nn import TrainOneStepCell
from mindspore.train import Model

from src.deepspeech2 import DeepSpeechModel, NetWithLossClass
from src.lr_generator import get_lr
from src.callback import Monitor
from src.config import train_config
from src.dataset import create_dataset

parser = argparse.ArgumentParser(description='DeepSpeech2 training')
parser.add_argument('--pre_trained_model_path', type=str, default='', help='Pretrained checkpoint path')
parser.add_argument('--is_distributed', action="store_true", default=False, help='Distributed training')
parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN')
args = parser.parse_args()

if __name__ == '__main__':
    rank_id = 0
    group_size = 1
    config = train_config
    if args.is_distributed:
        init('nccl')
        rank_id = get_rank()
        group_size = get_group_size()
        context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)

    with open(config.DataConfig.labels_path) as label_file:
        labels = json.load(label_file)

    ds_train = create_dataset(audio_conf=config.DataConfig.SpectConfig,
                              manifest_filepath=config.DataConfig.train_manifest,
                              labels=labels, normalize=True, train_mode=True,
                              batch_size=config.DataConfig.batch_size, rank=rank_id, group_size=group_size)
    steps_size = ds_train.get_dataset_size()

    # the learning rate array is passed to the Monitor callback for per-step display
    lr = get_lr(lr_init=config.OptimConfig.learning_rate, total_epochs=config.TrainingConfig.epochs,
                steps_per_epoch=steps_size)
    lr = Tensor(lr)

    deepspeech_net = DeepSpeechModel(batch_size=config.DataConfig.batch_size,
                                     rnn_hidden_size=config.ModelConfig.hidden_size,
                                     nb_layers=config.ModelConfig.hidden_layers,
                                     labels=labels,
                                     rnn_type=config.ModelConfig.rnn_type,
                                     audio_conf=config.DataConfig.SpectConfig,
                                     bidirectional=True)

    loss_net = NetWithLossClass(deepspeech_net)
    weights = ParameterTuple(deepspeech_net.trainable_params())

    optimizer = Adam(weights, learning_rate=config.OptimConfig.learning_rate, eps=config.OptimConfig.epsilon,
                     loss_scale=config.OptimConfig.loss_scale)
    train_net = TrainOneStepCell(loss_net, optimizer)

    if args.pre_trained_model_path != '':
        param_dict = load_checkpoint(args.pre_trained_model_path)
        load_param_into_net(train_net, param_dict)
        print('Successfully loaded the pre-trained model')

    model = Model(train_net)
    lr_cb = Monitor(lr)
    callback_list = [lr_cb]

    if args.is_distributed:
        config.CheckpointConfig.ckpt_file_name_prefix = config.CheckpointConfig.ckpt_file_name_prefix + str(get_rank())
        config.CheckpointConfig.ckpt_path = os.path.join(config.CheckpointConfig.ckpt_path,
                                                         'ckpt_' + str(get_rank()) + '/')
    config_ck = CheckpointConfig(save_checkpoint_steps=1,
                                 keep_checkpoint_max=config.CheckpointConfig.keep_checkpoint_max)
    ckpt_cb = ModelCheckpoint(prefix=config.CheckpointConfig.ckpt_file_name_prefix,
                              directory=config.CheckpointConfig.ckpt_path, config=config_ck)
    callback_list.append(ckpt_cb)
    model.train(config.TrainingConfig.epochs, ds_train, callbacks=callback_list)