forked from mindspore-Ecosystem/mindspore

DeepSpeech2: CPU training support

commit 7a09b16311 (parent 8a61767f32)

README.md

@@ -19,7 +19,7 @@ # [DeepSpeech2 Description](#contents)
 
 DeepSpeech2 is a speech recognition model trained with CTC loss. It replaces entire pipelines of hand-engineered components with neural networks and can handle a diverse variety of speech, including noisy
-environments, accents and different languages. We support training and evaluation on GPU.
+environments, accents and different languages. We support training and evaluation on CPU and GPU.
 
 [Paper](https://arxiv.org/pdf/1512.02595v1.pdf): Amodei, Dario, et al. Deep speech 2: End-to-end speech recognition in English and Mandarin.
 
 
@@ -97,10 +97,12 @@ usage: train.py [--use_pretrained USE_PRETRAINED]
                 [--pre_trained_model_path PRE_TRAINED_MODEL_PATH]
                 [--is_distributed IS_DISTRIBUTED]
                 [--bidirectional BIDIRECTIONAL]
+                [--device_target DEVICE_TARGET]
 options:
     --pre_trained_model_path    pretrained checkpoint path, default is ''
     --is_distributed            distributed training, default is False
     --bidirectional             whether or not to use bidirectional RNN, default is True. Currently, only bidirectional model is implemented
+    --device_target             device where the code will be implemented: "GPU" | "CPU", default is "GPU"
 ```
 
 ### Evaluation
@@ -108,10 +110,12 @@ options:
 ```text
 usage: eval.py [--bidirectional BIDIRECTIONAL]
                [--pretrain_ckpt PRETRAIN_CKPT]
+               [--device_target DEVICE_TARGET]
 
 options:
     --bidirectional    whether to use bidirectional RNN, default is True. Currently, only bidirectional model is implemented
     --pretrain_ckpt    saved checkpoint path, default is ''
+    --device_target    device where the code will be implemented: "GPU" | "CPU", default is "GPU"
 ```
 
 ### Options and Parameters
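With the new option, a CPU run is launched as `python train.py --device_target CPU`, and evaluation as `python eval.py --device_target CPU --pretrain_ckpt <ckpt>`; omitting the flag keeps the previous GPU behavior. A minimal, hypothetical smoke test of how the flag validates its value, assuming only the argparse lines added in the script diffs further below:

```python
# Mirrors the parser lines added in train.py/eval.py; not part of the commit.
import argparse

parser = argparse.ArgumentParser(description='DeepSpeech2 training')
parser.add_argument('--device_target', type=str, default="GPU", choices=("GPU", "CPU"),
                    help='Device target, support GPU and CPU, Default: GPU')

args = parser.parse_args(["--device_target", "CPU"])
assert args.device_target == "CPU"
# parser.parse_args(["--device_target", "Ascend"]) would exit with an error,
# since choices restricts the flag to "GPU" or "CPU".
```
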
@@ -210,7 +214,7 @@ for evaluation configuration
 
 ```
 
-The three *.csv files will be used in the training and evaluation process. Before training, some requirements should be installed, including `librosa` and `Levenshtein`.
+Before training, some requirements should be installed, including `librosa` and `Levenshtein`.
 After installing MindSpore via the official website and finishing dataset processing, you can start training as follows:
 
 ```shell
eval.py

@@ -24,17 +24,18 @@ from src.config import eval_config
 from src.deepspeech2 import DeepSpeechModel, PredictWithSoftmax
 from src.dataset import create_dataset
 from src.greedydecoder import MSGreedyDecoder
 
 from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
 
 parser = argparse.ArgumentParser(description='DeepSpeech evaluation')
 parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN')
 parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
+parser.add_argument('--device_target', type=str, default="GPU", choices=("GPU", "CPU"),
+                    help='Device target, support GPU and CPU, Default: GPU')
 args = parser.parse_args()
 
 if __name__ == '__main__':
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
     config = eval_config
     with open(config.DataConfig.labels_path) as label_file:
         labels = json.load(label_file)
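The point of this hunk is ordering: the old module-level `context.set_context(..., device_target="GPU", ...)` ran before the command line was read, so no flag could ever select CPU. A standalone sketch of the deferred-configuration pattern, assuming a local MindSpore install and the 1.x `context` API used in this commit:

```python
# Parse the CLI first, then configure the execution context from the chosen device.
import argparse
from mindspore import context

parser = argparse.ArgumentParser(description='DeepSpeech evaluation')
parser.add_argument('--device_target', type=str, default="GPU", choices=("GPU", "CPU"))
args = parser.parse_args()

if __name__ == '__main__':
    # GRAPH_MODE compiles the whole cell; device_target now follows the flag.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target,
                        save_graphs=False)
```
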
src/dataset.py

@@ -115,7 +115,7 @@ class ASRDataset(LoadAudioAndTranscript):
         batch_idx = self.bins[index]
         batch_size = len(batch_idx)
         batch_spect, batch_script, target_indices = [], [], []
-        input_length = np.zeros(batch_size, np.int32)
+        input_length = np.zeros(batch_size, np.float32)
         for data in batch_idx:
             audio_path, transcript_path = data[0], data[1]
             spect = self.parse_audio(audio_path)
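The lengths buffer switches from `np.int32` to `np.float32`, presumably to satisfy the dtype expectations of the CPU kernels fed by the data pipeline; the cast back to `int32` happens in-graph at the point of use (see the `P.Cast` hunks in `src/deepspeech2.py` below). A small numpy-only illustration of that handoff, with made-up lengths:

```python
# The dataset now emits float32 sequence lengths; the graph casts them back to
# int32 before CTC loss, which needs integer lengths. Values are illustrative.
import numpy as np

batch_size = 4
input_length = np.zeros(batch_size, np.float32)      # as built in ASRDataset
input_length[:] = [151, 188, 203, 211]               # per-utterance frame counts

lengths_for_ctc = input_length.astype(np.int32)      # what P.Cast does in-graph
assert lengths_for_ctc.dtype == np.int32
assert (lengths_for_ctc == np.array([151, 188, 203, 211])).all()
```
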
src/deepspeech2.py

@@ -19,6 +19,7 @@ DeepSpeech2 model
 
 import math
 import numpy as np
+import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore import nn, Tensor, ParameterTuple, Parameter
 from mindspore.common.initializer import initializer
@@ -112,7 +113,7 @@ class BatchRNN(nn.Cell):
     """
 
     def __init__(self, batch_size, input_size, hidden_size, num_layers, bidirectional=False, batch_norm=False,
-                 rnn_type='LSTM'):
+                 rnn_type='LSTM', device_target="GPU"):
         super(BatchRNN, self).__init__()
         self.batch_size = batch_size
         self.input_size = input_size
@@ -141,7 +142,10 @@ class BatchRNN(nn.Cell):
         for i in range(num_layers):
             weight_size = (input_size_list[i] + hidden_size) * hidden_size * self.num_directions * 4
             if self.has_bias:
-                bias_size = self.num_directions * hidden_size * 4 * 2
+                if device_target == "GPU":
+                    bias_size = self.num_directions * hidden_size * 4 * 2
+                else:
+                    bias_size = self.num_directions * hidden_size * 4
                 weight_size = weight_size + bias_size
 
             stdv = 1 / math.sqrt(hidden_size)
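The `* 2` in the GPU branch reflects that the cuDNN-style LSTM carries two bias vectors per direction (an input-hidden and a hidden-hidden bias, four gates each), while the CPU path packs a single fused bias, so the flattened parameter buffer differs per target. A worked example of the arithmetic, with small made-up sizes:

```python
# Toy parameter-size bookkeeping matching the hunk above (bidirectional LSTM).
import math

input_size, hidden_size, num_directions = 8, 8, 2

# 4 gates, each (input_size + hidden_size) x hidden_size, per direction.
weight_size = (input_size + hidden_size) * hidden_size * num_directions * 4

for device_target in ("GPU", "CPU"):
    if device_target == "GPU":
        bias_size = num_directions * hidden_size * 4 * 2   # b_ih and b_hh
    else:
        bias_size = num_directions * hidden_size * 4       # single fused bias
    print(device_target, weight_size + bias_size)          # GPU 1152, CPU 1088

print(1 / math.sqrt(hidden_size))   # stdv init bound used above, ~0.3536
```
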
@@ -195,7 +199,8 @@ class DeepSpeechModel(nn.Cell):
         bidirectional(bool): use bidirectional rnn (default=True)
     """
 
-    def __init__(self, batch_size, labels, rnn_hidden_size, nb_layers, audio_conf, rnn_type='LSTM', bidirectional=True):
+    def __init__(self, batch_size, labels, rnn_hidden_size, nb_layers, audio_conf, rnn_type='LSTM',
+                 bidirectional=True, device_target='GPU'):
         super(DeepSpeechModel, self).__init__()
         self.batch_size = batch_size
         self.hidden_size = rnn_hidden_size
@@ -226,7 +231,7 @@ class DeepSpeechModel(nn.Cell):
 
         self.RNN = BatchRNN(batch_size=self.batch_size, input_size=rnn_input_size, num_layers=nb_layers,
                             hidden_size=rnn_hidden_size, bidirectional=bidirectional, batch_norm=False,
-                            rnn_type=self.rnn_type)
+                            rnn_type=self.rnn_type, device_target=device_target)
         fully_connected = nn.Dense(rnn_hidden_size, num_classes, has_bias=False)
         self.fc = SequenceWise(fully_connected)
 
@@ -275,10 +280,11 @@ class NetWithLossClass(nn.Cell):
         self.network = network
         self.ReduceMean_false = P.ReduceMean(keep_dims=False)
         self.squeeze_op = P.Squeeze(0)
+        self.cast_op = P.Cast()
 
     def construct(self, inputs, input_length, target_indices, label_values):
         predict, output_length = self.network(inputs, input_length)
-        loss = self.loss(predict, target_indices, label_values, output_length)
+        loss = self.loss(predict, target_indices, label_values, self.cast_op(output_length, mstype.int32))
         return self.ReduceMean_false(loss[0])
 
 
@@ -292,9 +298,10 @@ class PredictWithSoftmax(nn.Cell):
         self.network = network
         self.inference_softmax = P.Softmax(axis=-1)
         self.transpose_op = P.Transpose()
+        self.cast_op = P.Cast()
 
     def construct(self, inputs, input_length):
-        x, output_sizes = self.network(inputs, input_length)
+        x, output_sizes = self.network(inputs, self.cast_op(input_length, mstype.int32))
        x = self.inference_softmax(x)
        x = self.transpose_op(x, (1, 0, 2))
        return x, output_sizes
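Both cells now cast the length tensor at the graph boundary instead of relying on the dataset's dtype. A self-contained sketch of the same `P.Cast` pattern, assuming MindSpore is installed; values are arbitrary:

```python
# Minimal cast-at-the-boundary demo: float32 lengths in, int32 lengths out.
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn
from mindspore.ops import operations as P

class CastDemo(nn.Cell):
    def __init__(self):
        super(CastDemo, self).__init__()
        self.cast_op = P.Cast()

    def construct(self, x):
        # Same call shape as in NetWithLossClass/PredictWithSoftmax above.
        return self.cast_op(x, mstype.int32)

lengths = Tensor(np.array([151.0, 188.0], np.float32))
print(CastDemo()(lengths))   # [151 188], now int32
```
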
train.py

@@ -21,7 +21,7 @@ import argparse
 from mindspore import context, Tensor, ParameterTuple
 from mindspore.context import ParallelMode
 from mindspore.communication.management import init, get_rank, get_group_size
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.optim import Adam
 from mindspore.nn import TrainOneStepCell
@@ -29,7 +29,6 @@ from mindspore.train import Model
 
 from src.deepspeech2 import DeepSpeechModel, NetWithLossClass
 from src.lr_generator import get_lr
-from src.callback import Monitor
 from src.config import train_config
 from src.dataset import create_dataset
 
@@ -37,22 +36,23 @@ parser = argparse.ArgumentParser(description='DeepSpeech2 training')
 parser.add_argument('--pre_trained_model_path', type=str, default='', help='Pretrained checkpoint path')
 parser.add_argument('--is_distributed', action="store_true", default=False, help='Distributed training')
 parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN')
+parser.add_argument('--device_target', type=str, default="GPU", choices=("GPU", "CPU"),
+                    help='Device target, support GPU and CPU, Default: GPU')
 args = parser.parse_args()
 
 if __name__ == '__main__':
     rank_id = 0
     group_size = 1
     config = train_config
+    data_sink = (args.device_target == "GPU")
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
     if args.is_distributed:
         init('nccl')
         rank_id = get_rank()
         group_size = get_group_size()
-        context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
-    else:
-        context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
 
     with open(config.DataConfig.labels_path) as label_file:
         labels = json.load(label_file)
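Two things happen in this hunk: `context.set_context` is hoisted out of the `if/else` so a single call follows the CLI flag instead of hard-coding `'GPU'`, and `data_sink` records whether dataset sink mode may be used, which this commit enables only for GPU. A trivial sketch of the sink-mode decision; the names mirror the diff, and the CPU restriction is taken from this change rather than from MindSpore documentation:

```python
# Sink mode streams batches to the device queue; this commit turns it on only
# when training on GPU and falls back to per-step host feeding on CPU.
def pick_sink_mode(device_target: str) -> bool:
    return device_target == "GPU"

assert pick_sink_mode("GPU") is True
assert pick_sink_mode("CPU") is False   # later passed as dataset_sink_mode
```
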
@@ -73,7 +73,8 @@ if __name__ == '__main__':
                                      labels=labels,
                                      rnn_type=config.ModelConfig.rnn_type,
                                      audio_conf=config.DataConfig.SpectConfig,
-                                     bidirectional=True)
+                                     bidirectional=True,
+                                     device_target=args.device_target)
 
     loss_net = NetWithLossClass(deepspeech_net)
     weights = ParameterTuple(deepspeech_net.trainable_params())
@@ -88,8 +89,7 @@ if __name__ == '__main__':
         print('Successfully loading the pre-trained model')
 
     model = Model(train_net)
-    lr_cb = Monitor(lr)
-    callback_list = [lr_cb]
+    callback_list = [LossMonitor()]
 
     if args.is_distributed:
         config.CheckpointConfig.ckpt_file_name_prefix = config.CheckpointConfig.ckpt_file_name_prefix + str(get_rank())
@@ -100,4 +100,4 @@ if __name__ == '__main__':
     ckpt_cb = ModelCheckpoint(prefix=config.CheckpointConfig.ckpt_file_name_prefix,
                               directory=config.CheckpointConfig.ckpt_path, config=config_ck)
     callback_list.append(ckpt_cb)
-    model.train(config.TrainingConfig.epochs, ds_train, callbacks=callback_list)
+    model.train(config.TrainingConfig.epochs, ds_train, callbacks=callback_list, dataset_sink_mode=data_sink)
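With the custom `Monitor(lr)` gone, loss reporting comes from MindSpore's built-in `LossMonitor`, and the final `model.train` call threads the `data_sink` decision through `dataset_sink_mode`. A standalone sketch of the callback construction; the prefix, directory, and checkpoint cadence are illustrative stand-ins, not the project's `config.CheckpointConfig` values:

```python
# Builds the same callback-list shape as train.py above.
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig

callback_list = [LossMonitor()]   # prints step number and loss; no lr argument needed

config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10)
callback_list.append(ModelCheckpoint(prefix='DeepSpeech2',
                                     directory='./checkpoint',
                                     config=config_ck))

# Later: model.train(epochs, ds_train, callbacks=callback_list,
#                    dataset_sink_mode=data_sink)
```
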
README.md

@@ -50,10 +50,10 @@ Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
 - Hardware (GPU)
     - Prepare hardware environment with GPU processor.
 - Framework
-    - [MindSpore](https://cmc-szv.clouddragon.huawei.com/cmcversion/index/search?searchKey=Do-MindSpore%20V100R001C00B622)
+    - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please check the resources below:
-    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
-    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
+    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
 
 # [Script Description](#contents)
 