!11773 Add WaveNet to Model Zoo

From: @wanyiming
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-30 08:55:28 +08:00 committed by Gitee
commit b82df95b43
15 changed files with 2584 additions and 0 deletions

View File

@ -0,0 +1,254 @@
# Contents
- [WaveNet Description](#WaveNet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Convert Process](#convert-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [WaveNet Description](#contents)
WaveNet is a deep neural network for generating raw audio waveforms. The model is fully probabilistic and autoregressive, with the predictive distribution for each audio sample conditioned on all previous ones. We support training and evaluation on GPU.
[Paper](https://arxiv.org/pdf/1609.03499.pdf): Oord A, Dieleman S, Zen H, et al. WaveNet: A generative model for raw audio
# [Model Architecture](#contents)
The current model consists of a pre-convolution layer, followed by several residual blocks, each with residual and skip connections and gated activation units.
Finally, post-convolution layers are added to predict the distribution.
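As a rough sketch of one such block (illustrative numpy pseudocode, not the project's MindSpore implementation; the weight names are hypothetical):

```python
import numpy as np

def gated_residual_block(x, w_filter, w_gate, w_res, w_skip):
    """One WaveNet-style residual block (illustrative; real blocks use dilated convs).

    x: (channels, time); the w_* matrices stand in for the dilated/1x1 conv weights.
    """
    filt = np.tanh(w_filter @ x)                # filter branch
    gate = 1.0 / (1.0 + np.exp(-(w_gate @ x)))  # gate branch (sigmoid)
    z = filt * gate                             # gated activation unit
    residual = w_res @ z + x                    # residual connection into the next block
    skip = w_skip @ z                           # skip connection, summed at the output head
    return residual, skip

# Toy shapes: 4 residual channels, 8 time steps
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
res, skip = gated_residual_block(x, *(rng.standard_normal((4, 4)) for _ in range(4)))
```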
# [Dataset](#contents)
In the following sections, we will introduce how to run the scripts using the related dataset below.
Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
- Dataset size: 2.6 GB
- Data format: audio clips (13,100) and transcriptions
- The dataset structure is as follows:
```path
.
└── LJSpeech-1.1
├─ wavs //audio clips files
└─ metadata.csv //transcripts
```
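For a quick sanity check of the download, `metadata.csv` is pipe-delimited with three fields per line (clip id, raw transcript, normalized transcript); a minimal sketch, assuming the dataset sits in the current directory:

```python
# metadata.csv has no header; each line reads: clip_id|transcript|normalized_transcript
with open("LJSpeech-1.1/metadata.csv", encoding="utf-8") as f:
    rows = [line.rstrip("\n").split("|", 2) for line in f]
print(len(rows))            # expect 13100 clips
clip_id, text, norm = rows[0]
print(clip_id, norm)        # e.g. LJ001-0001 and its normalized transcript
```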
# [Environment Requirements](#contents)
- Hardware: GPU
- Prepare hardware environment with GPU processor.
- Framework
- [MindSpore](https://cmc-szv.clouddragon.huawei.com/cmcversion/index/search?searchKey=Do-MindSpore%20V100R001C00B622)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
**Note that some of the scripts described below are not included in our code**. They should first be downloaded from [r9y9](https://github.com/r9y9/wavenet_vocoder) and added to this project.
```path
.
├── audio
└──wavenet
├──datasets // Note the datasets folder should be downloaded from the above link
├──egs // Note the egs folder should be downloaded from the above link
├──utils // Note the utils folder should be downloaded from the above link
├── audio.py // audio utils. Note this script should be downloaded from a third party
├── compute-meanvar-stats.py // Compute mean-variance normalization stats. Note this script should be downloaded from the above link
├── evaluate.py // evaluation
├── export.py // convert the MindSpore model to a MINDIR model
├── hparams.py // hyper-parameter configuration. Note this script should be downloaded from the above link
├── mksubset.py // Make a subset of the dataset. Note this script should be downloaded from the above link
├── preprocess.py // Preprocess the dataset. Note this script should be downloaded from the above link
├── preprocess_normalize.py // Perform mean-variance normalization on preprocessed features. Note this script should be downloaded from the above link
├── README.md // descriptions about WaveNet
├── train.py // training script
├── train_pytorch.py // Note this script should be downloaded from the above link. Its original name in that project is train.py
├── src
│ ├──__init__.py
│ ├──dataset.py // generate dataloader and data processing entry
│ ├──callback.py // callbacks to monitor the training
│ ├──lr_generator.py // learning rate generator
│ └──loss.py // loss function definition
└── wavenet_vocoder
├──__init__.py
├──conv.py // extended 1D convolution
├──mixture.py // loss function for training and sample function for testing
├──modules.py // modules for Wavenet construction
├──upsample.py // upsample layer definition
├──util.py // utils. Note this script should be downloaded from the above link
├──wavenet.py // WaveNet networks
└──tfcompat // Note this folder should be downloaded from the above link
├──__init__.py
└──hparam.py // param management tools
```
## [Script Parameters](#contents)
### Training
```text
usage: train.py [--data_path DATA_PATH] [--preset PRESET]
[--checkpoint_dir CHECKPOINT_DIR] [--checkpoint CHECKPOINT]
[--speaker_id SPEAKER_ID] [--is_distributed IS_DISTRIBUTED]
options:
--data_path dataset path
--preset path of preset parameters (json)
--checkpoint_dir directory of saving model checkpoints
--checkpoint restore the model from a pre-trained ckpt path if given
--speaker_id specific speaker id for multi-speaker datasets (not used currently)
--is_distributed whether distributed training or not
```
### Evaluation
```text
usage: evaluate.py [--data_path DATA_PATH] [--preset PRESET]
[--pretrain_ckpt PRETRAIN_CKPT] [--is_numpy] [--output_path OUTPUT_PATH]
[--speaker_id SPEAKER_ID]
options:
--data_path dataset path
--preset path of preset parameters (json)
--pretrain_ckpt pre-trained ckpt path
--is_numpy if set, use MindSpore ops for inference instead of numpy (numpy is the default and recommended)
--output_path path to save synthesized audio
--speaker_id specific speaker id for multi-speaker datasets (not used currently)
```
More parameters for training and evaluation can be set in the file `hparams.py`. The preset json passed via `--preset` overrides those defaults, as sketched below.
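A minimal sketch of how the scripts apply a preset (this mirrors the logic in `train.py` and `evaluate.py`; the preset path is just an example):

```python
from hparams import hparams, hparams_debug_string

# Values found in the preset json replace the defaults defined in hparams.py;
# anything the json does not mention keeps its hparams.py default.
with open("egs/gaussian/conf/gaussian_wavenet.json") as f:
    hparams.parse_json(f.read())
print(hparams_debug_string())
```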
## [Training Process](#contents)
Before your first training, some dependency scripts should be downloaded and placed in the correct directories as described in [Script and Sample Code](#script-and-sample-code).
After that, the raw data should be pre-processed using the scripts in `egs`. The directory structure of `egs` is as follows:
```path
.
├── egs
├──gaussian
│ ├──conf
│ │ ├──gaussian_wavenet.json
│ │ └──gaussian_wavenet_demo.json
│ └──run.sh
├──mol
│ ├──conf
│ │ ├──mol_wavenet.json
│ │ └──mol_wavenet_demo.json
│ └──run.sh
├──mulaw256
│ ├──conf
│ │ ├──mulaw_wavenet.json
│ │ └──mulaw_wavenet_demo.json
│ └──run.sh
└──README.md
```
In this project, three different losses are implemented to train the network:
- mulaw256: categorical output distribution. The input is an 8-bit mu-law quantized waveform (see the sketch after this list).
- mol: discretized mix logistic loss. The input is 16-bit raw audio.
- gaussian: mix gaussian loss. The input is 16-bit raw audio.
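As a concrete illustration of the mulaw256 input format, the following round trip uses the same `nnmnkwii.preprocessing` helpers the training and evaluation code calls (the sine waveform is just a stand-in):

```python
import numpy as np
from nnmnkwii import preprocessing as P

# A 1-second 440 Hz tone in [-1, 1] as a stand-in waveform
x = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 22050)).astype(np.float32)
q = P.mulaw_quantize(x, mu=255)          # 8-bit classes in [0, 255] (the mulaw256 input)
x_rec = P.inv_mulaw_quantize(q, mu=255)  # approximate inverse back to [-1, 1]
print(q.min(), q.max(), float(np.abs(x - x_rec).max()))
```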
The three folders (gaussian, mol, mulaw256) are used to generate the corresponding training data. For example, to generate the training data for
the mix gaussian loss, first modify line 28 of `run.sh`: change `conf/gaussian_wavenet_demo.json` to
`conf/gaussian_wavenet.json`. We use the default parameters in `gaussian_wavenet.json`. With this setting, data will be generated to fit the mix gaussian loss, and
some parameters in `hparams.py` will be overridden by those in `gaussian_wavenet.json`. You can also define your own hyper-parameter json here. After the modification,
run the following commands for data generation. Note that if you want to change the value of a parameter, you may need to modify it in `gaussian_wavenet.json` instead of `hparams.py`, since `gaussian_wavenet.json` may override it.
```bash
bash run.sh --stage 0 --stop-stage 0 --db-root /path_to_dataset/LJSpeech-1.1/wavs
bash run.sh --stage 1 --stop-stage 1
```
After the processing, the directory of gaussian will be as follows:
```path
.
├── gaussian
├──conf
├──data
├──exp
└──dump
└──lj
└──logmelspectrogram
├──org
└──norm
├──train_no_dev
├──dev
└──eval
```
The train_no_dev folder contains the final training data. For mulaw256 and mol, the process is the same. Once the training data is prepared,
you can run the following commands to train the network:
```bash
# standalone training
python train.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt
# distributed training
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' mpirun --allow-run-as-root -n 8 python train.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt --is_distributed
# eval
python evaluate.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --output_path=path_to_save_audio
```
## [Evaluation Process](#contents)
WaveNet involves an auto-regressive generation process that currently cannot run in Graph mode (i.e., placing the auto-regression inside `construct`). Therefore, we implement the process in an ordinary Python function and provide two ways to realize it: using numpy or using MindSpore ops, selected via `is_numpy`. We recommend numpy since it is much faster than using MindSpore ops here: the auto-regressive loop only calls a few simple operators such as MatMul and BiasAdd, and unlike Graph mode, each per-step operator call carries a fixed dispatch overhead, which dominates the runtime. For more information, please refer to
this [link](https://bbs.huaweicloud.com/forum/thread-94852-1-1.html)
```bash
# eval
python evaluate.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --output_path=path_to_save_audio
```
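For intuition, a toy sketch of why the per-step cost matters (purely illustrative; the names and shapes are hypothetical): generation issues one tiny matrix product per output audio sample, so thousands of per-operator dispatches per second of audio add a fixed overhead that plain numpy avoids.

```python
import numpy as np

def toy_incremental_generation(w, x0, steps):
    """Toy auto-regressive loop: each sample depends on the previous one.

    w: (C, C) linearized weight, x0: (C,) initial input. Real WaveNet inference
    runs one such step per audio sample (22050+ per second of audio), so any
    fixed per-step dispatch cost is multiplied accordingly.
    """
    x = x0
    out = []
    for _ in range(steps):
        x = np.tanh(w @ x)  # stand-in for one incremental network step
        out.append(x[0])
    return np.array(out)

samples = toy_incremental_generation(np.eye(4, dtype=np.float32) * 0.5,
                                     np.ones(4, dtype=np.float32), steps=100)
```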
## [Convert Process](#contents)
```bash
python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt
```
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | WaveNet |
| -------------------------- | ---------------------------------------------------------------|
| Resource | NV SMX2 V100-32G |
| uploaded Date | 01/14/2021 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | LJSpeech-1.1 |
| Training Parameters | 1p, epoch=600(max), steps=1635 * epoch, batch_size = 8, lr=1e-3 |
| Optimizer | Adam |
| Loss Function | SoftmaxCrossEntropyWithLogits/discretized_mix_logistic/mix_gaussian |
| Loss | around 2.0(mulaw256)/around 4.5(mol)/around -6.0(gaussian) |
| Speed | 1p 1.467s/step |
| Total time: training | 1p(mol/gaussian): around 4 days; 2p(mulaw256):around 1 week |
| Checkpoint | 59.79M/54.87M/54.83M (.ckpt file) |
| Scripts | [WaveNet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/wavenet) |
### Inference Performance On GPU
Audio samples will be demonstrated online soon.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,253 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation"""
import os
from os.path import join
import argparse
import glob
from hparams import hparams, hparams_debug_string
import audio
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
from nnmnkwii import preprocessing as P
from nnmnkwii.datasets import FileSourceDataset
from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_scalar_input
from src.dataset import RawAudioDataSource, MelSpecDataSource, DualDataset
parser = argparse.ArgumentParser(description='TTS evaluation')
parser.add_argument('--data_path', type=str, required=True, default='',
help='Directory containing preprocessed features.')
parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).')
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
parser.add_argument('--is_numpy', action="store_false", default=True, help='If set, use MindSpore ops for inference instead of numpy (numpy is the default)')
parser.add_argument('--output_path', type=str, default='./out_wave/', help='Path to save generated audios')
parser.add_argument('--speaker_id', type=str, default='',
help=' Use specific speaker of data in case for multi-speaker datasets.')
args = parser.parse_args()
def get_data_loader(hparam, data_dir):
"""
test data loader
"""
wav_paths = glob.glob(os.path.join(data_dir, "*-wave.npy"))
if wav_paths:
X = FileSourceDataset(RawAudioDataSource(data_dir,
hop_size=audio.get_hop_size(),
max_steps=None, cin_pad=hparam.cin_pad))
else:
X = None
C = FileSourceDataset(MelSpecDataSource(data_dir,
hop_size=audio.get_hop_size(),
max_steps=None, cin_pad=hparam.cin_pad))
length_x = np.array(C.file_data_source.lengths)
if C[0].shape[-1] != hparam.cin_channels:
raise RuntimeError("Invalid cin_channnels {}. Expected to be {}.".format(hparam.cin_channels, C[0].shape[-1]))
dataset = DualDataset(X, C, length_x, batch_size=hparam.batch_size, hparams=hparam)
data_loader = de.GeneratorDataset(dataset, ["x_batch", "y_batch", "c_batch", "g_batch", "input_lengths", "mask"])
return data_loader, dataset
def batch_wavegen(hparam, net, c_input=None, g_input=None, tqdm_=None, is_numpy=True):
"""
generate audio
"""
assert c_input is not None
B = c_input.shape[0]
net.set_train(False)
if hparam.upsample_conditional_features:
length = (c_input.shape[-1] - hparam.cin_pad * 2) * audio.get_hop_size()
else:
# already duplicated
length = c_input.shape[-1]
y_hat = net.incremental_forward(c=c_input, g=g_input, T=length, tqdm=tqdm_, softmax=True, quantize=True,
log_scale_min=hparam.log_scale_min, is_numpy=is_numpy)
if is_mulaw_quantize(hparam.input_type):
# needs to be float since mulaw_inv returns in range of [-1, 1]
y_hat = np.reshape(np.argmax(y_hat, 1), (B, -1))
y_hat = y_hat.astype(np.float32)
for k in range(B):
y_hat[k] = P.inv_mulaw_quantize(y_hat[k], hparam.quantize_channels - 1)
elif is_mulaw(hparam.input_type):
y_hat = np.reshape(y_hat, (B, -1))
for k in range(B):
y_hat[k] = P.inv_mulaw(y_hat[k], hparam.quantize_channels - 1)
else:
y_hat = np.reshape(y_hat, (B, -1))
if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]:
for k in range(B):
y_hat[k] = getattr(audio, hparam.postprocess)(y_hat[k])
if hparam.global_gain_scale > 0:
for k in range(B):
y_hat[k] /= hparam.global_gain_scale
return y_hat
def to_int16(x_):
"""
convert datatype to int16
"""
if x_.dtype == np.int16:
return x_
assert x_.dtype == np.float32
assert x_.min() >= -1 and x_.max() <= 1.0
return (x_ * 32767).astype(np.int16)
def get_reference_file(hparam, dataset_source, idx):
"""
get reference files
"""
reference_files = []
reference_feats = []
for _ in range(hparam.batch_size):
if hasattr(dataset_source, "X"):
reference_files.append(dataset_source.X.collected_files[idx][0])
else:
pass
if hasattr(dataset_source, "Mel"):
reference_feats.append(dataset_source.Mel.collected_files[idx][0])
else:
reference_feats.append(dataset_source.collected_files[idx][0])
idx += 1
return reference_files, reference_feats, idx
def get_saved_audio_name(has_ref_file_, ref_file, ref_feat, g_fp):
"""get path to save reference audio"""
if has_ref_file_:
target_audio_path = ref_file
name = os.path.splitext(os.path.basename(target_audio_path))[0].replace("-wave", "")
else:
target_feat_path = ref_feat
name = os.path.splitext(os.path.basename(target_feat_path))[0].replace("-feats", "")
# Paths
if g_fp is None:
dst_wav_path_ = join(args.output_path, "{}_gen.wav".format(name))
target_wav_path_ = join(args.output_path, "{}_ref.wav".format(name))
else:
dst_wav_path_ = join(args.output_path, "speaker{}_{}_gen.wav".format(g_fp, name))
target_wav_path_ = join(args.output_path, "speaker{}_{}_ref.wav".format(g_fp, name))
return dst_wav_path_, target_wav_path_
def save_ref_audio(hparam, ref, length, target_wav_path_):
"""
save reference audio
"""
if is_mulaw_quantize(hparam.input_type):
ref = np.reshape(np.argmax(ref, 0), (-1))[:length]
ref = ref.astype(np.float32)
else:
ref = np.reshape(ref, (-1))[:length]
if is_mulaw_quantize(hparam.input_type):
ref = P.inv_mulaw_quantize(ref, hparam.quantize_channels - 1)
elif is_mulaw(hparam.input_type):
ref = P.inv_mulaw(ref, hparam.quantize_channels - 1)
if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]:
ref = getattr(audio, hparam.postprocess)(ref)
if hparam.global_gain_scale > 0:
ref /= hparam.global_gain_scale
ref = np.clip(ref, -1.0, 1.0)
wavfile.write(target_wav_path_, hparam.sample_rate, to_int16(ref))
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
if args.preset is not None:
with open(args.preset) as f:
hparams.parse_json(f.read())
assert hparams.name == "wavenet_vocoder"
print(hparams_debug_string())
fs = hparams.sample_rate
hparams.batch_size = 10
hparams.max_time_sec = None
hparams.max_time_steps = None
data_loaders, source_dataset = get_data_loader(hparam=hparams, data_dir=args.data_path)
upsample_params = hparams.upsample_params
upsample_params["cin_channels"] = hparams.cin_channels
upsample_params["cin_pad"] = hparams.cin_pad
model = WaveNet(
out_channels=hparams.out_channels,
layers=hparams.layers,
stacks=hparams.stacks,
residual_channels=hparams.residual_channels,
gate_channels=hparams.gate_channels,
skip_out_channels=hparams.skip_out_channels,
cin_channels=hparams.cin_channels,
gin_channels=hparams.gin_channels,
n_speakers=hparams.n_speakers,
dropout=hparams.dropout,
kernel_size=hparams.kernel_size,
cin_pad=hparams.cin_pad,
upsample_conditional_features=hparams.upsample_conditional_features,
upsample_params=upsample_params,
scalar_input=is_scalar_input(hparams.input_type),
output_distribution=hparams.output_distribution,
)
param_dict = load_checkpoint(args.pretrain_ckpt)
load_param_into_net(model, param_dict)
print('Successfully loaded the pre-trained model')
os.makedirs(args.output_path, exist_ok=True)
cin_pad = hparams.cin_pad
file_idx = 0
for data in data_loaders.create_dict_iterator():
x, y, c, g, input_lengths = data['x_batch'], data['y_batch'], data['c_batch'], data['g_batch'], data[
'input_lengths']
if cin_pad > 0:
c = c.asnumpy()
c = np.pad(c, pad_width=((0, 0), (0, 0), (cin_pad, cin_pad)), mode="edge")  # pad only the time axis
c = Tensor(c)
ref_files, ref_feats, file_idx = get_reference_file(hparams, source_dataset, file_idx)
# Generate
y_hats = batch_wavegen(hparams, model, c_input=c, tqdm_=tqdm, is_numpy=args.is_numpy)
x = x.asnumpy()
input_lengths = input_lengths.asnumpy()
# Save each utt.
has_ref_file = bool(ref_files)
for i, (ref_, gen_, length_) in enumerate(zip(x, y_hats, input_lengths)):
dst_wav_path, target_wav_path = get_saved_audio_name(has_ref_file_=has_ref_file, ref_file=ref_files[i],
ref_feat=ref_feats[i], g_fp=g)
save_ref_audio(hparams, ref_, length_, target_wav_path)
gen = gen_[:length_]
gen = np.clip(gen, -1.0, 1.0)
wavfile.write(dst_wav_path, hparams.sample_rate, to_int16(gen))

View File

@ -0,0 +1,95 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export mindir."""
import json
from os.path import join
import argparse
from warnings import warn
from hparams import hparams, hparams_debug_string
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_mulaw_quantize, is_scalar_input
import numpy as np
from src.loss import PredictNet
parser = argparse.ArgumentParser(description='TTS export')
parser.add_argument('--preset', type=str, default='', help='Path of preset parameters (json).')
parser.add_argument('--speaker_id', type=str, default='',
help=' Use specific speaker of data in case for multi-speaker datasets.')
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
args = parser.parse_args()
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
if args.preset is not None:
with open(args.preset) as f:
hparams.parse_json(f.read())
assert hparams.name == "wavenet_vocoder"
print(hparams_debug_string())
fs = hparams.sample_rate
if is_mulaw_quantize(hparams.input_type):
if hparams.out_channels != hparams.quantize_channels:
raise RuntimeError(
"out_channels must equal quantize_channels if input_type is 'mulaw-quantize'")
if hparams.upsample_conditional_features and hparams.cin_channels < 0:
s = "Upsample conv layers were specified while local conditioning disabled. "
s += "Notice that upsample conv layers will never be used."
warn(s)
upsample_params = hparams.upsample_params
upsample_params["cin_channels"] = hparams.cin_channels
upsample_params["cin_pad"] = hparams.cin_pad
model = WaveNet(
out_channels=hparams.out_channels,
layers=hparams.layers,
stacks=hparams.stacks,
residual_channels=hparams.residual_channels,
gate_channels=hparams.gate_channels,
skip_out_channels=hparams.skip_out_channels,
cin_channels=hparams.cin_channels,
gin_channels=hparams.gin_channels,
n_speakers=hparams.n_speakers,
dropout=hparams.dropout,
kernel_size=hparams.kernel_size,
cin_pad=hparams.cin_pad,
upsample_conditional_features=hparams.upsample_conditional_features,
upsample_params=upsample_params,
scalar_input=is_scalar_input(hparams.input_type),
output_distribution=hparams.output_distribution,
)
Net = PredictNet(model)
Net.set_train(False)
receptive_field = model.receptive_field
print("Receptive field (samples / ms): {} / {}".format(receptive_field, receptive_field / fs * 1000))
param_dict = load_checkpoint(args.pretrain_ckpt)
load_param_into_net(model, param_dict)
print('Successfully loaded the pre-trained model')
x = np.array(np.random.random((2, 256, 10240)), dtype=np.float32)
c = np.array(np.random.random((2, 80, 44)), dtype=np.float32)
g = np.array([0, 0], dtype=np.int64)
export(Net, Tensor(x), Tensor(c), Tensor(g), file_name="WaveNet", file_format='MINDIR')

View File

@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,103 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Callbacks defined for WaveNet.
"""
import time
from mindspore.train.callback import Callback
from mindspore import Tensor
import numpy as np
class TimeMonitor(Callback):
"""
Time monitor for calculating cost of each epoch.
Args:
data_size (int): step size of an epoch.
"""
def __init__(self, data_size):
super(TimeMonitor, self).__init__()
self.data_size = data_size
def epoch_begin(self, run_context):
self.epoch_time = time.time()
def epoch_end(self, run_context):
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / self.data_size
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
step_mseconds = (time.time() - self.step_time) * 1000
print(f"step time {step_mseconds}", flush=True)
class Monitor(Callback):
"""
Monitor loss and time.
Args:
lr_init (numpy array): train lr
Returns:
None
Examples:
>>> Monitor(lr_init=Tensor([0.05]*100))
"""
def __init__(self, lr_init=None):
super(Monitor, self).__init__()
self.lr_init = lr_init
self.lr_init_len = len(lr_init)
def epoch_begin(self, run_context):
self.losses = []
self.epoch_time = time.time()
def epoch_end(self, run_context):
cb_params = run_context.original_args()
epoch_seconds = time.time() - self.epoch_time
per_step_seconds = epoch_seconds / cb_params.batch_num
print("epoch time: {:5.3f} s, per step time: {:5.3f} s, avg loss: {:5.6f}".format(epoch_seconds,
per_step_seconds,
np.mean(self.losses)))
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
"""step end"""
cb_params = run_context.original_args()
step_seconds = time.time() - self.step_time
step_loss = cb_params.net_outputs
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())
self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.6f}/{:5.6f}], time:[{:5.3f}], lr:[{:.9f}]".format(
cb_params.cur_epoch_num -
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1].asnumpy()))

View File

@ -0,0 +1,258 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create train dataset.
"""
import os
import math
import numpy as np
import audio
from nnmnkwii.datasets import FileSourceDataset
from nnmnkwii import preprocessing as P
from wavenet_vocoder.util import is_mulaw_quantize
from train_pytorch import _pad, _pad_2d, to_categorical, ensure_divisible, RawAudioDataSource, MelSpecDataSource, assert_ready_for_upsampling
import mindspore.dataset.engine as de
def sequence_mask(sequence_length, max_len=None):
"""make sequence mask for loss"""
if max_len is None:
max_len = np.max(sequence_length)
batch_size = len(sequence_length)
seq_range = np.linspace(0, max_len - 1, max_len, dtype=np.int32)
seq_range_expand = np.tile(np.expand_dims(seq_range, 0), (batch_size, 1))
seq_length_expand = np.tile(np.expand_dims(sequence_length, 1), (1, max_len))
seq_length_expand = np.expand_dims(np.array(seq_range_expand < seq_length_expand, dtype=np.float32), -1)
return seq_length_expand
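# Example (illustrative): sequence_mask([2, 4]) returns a float32 array of shape
# (2, 4, 1) whose time axis reads [1, 1, 0, 0] for the first item and [1, 1, 1, 1]
# for the second; the loss uses it to zero out padded positions.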
class DistributedSampler():
"""function to distribute and shuffle sample
"""
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
self.dataset = dataset
self.rank = rank
self.group_size = group_size
self.dataset_len = len(self.dataset)
self.num_samples = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
self.total_size = self.num_samples * self.group_size
self.shuffle = shuffle
self.seed = seed
def __iter__(self):
if self.shuffle:
self.seed = (self.seed + 1) & 0xffffffff
np.random.seed(self.seed)
indices = np.random.permutation(self.dataset_len).tolist()
else:
indices = list(range(self.dataset_len))
indices += indices[:(self.total_size - len(indices))]
indices = indices[self.rank::self.group_size]
return iter(indices)
def __len__(self):
return self.num_samples
def process_condition_batch(max_time_steps, hparams, batch):
"""process condition batch"""
cin_pad = hparams.cin_pad
new_batch = []
for batch_ in batch:
x, c, g = batch_
if hparams.upsample_conditional_features:
assert_ready_for_upsampling(x, c, cin_pad=0)
if max_time_steps is not None:
max_steps = ensure_divisible(max_time_steps, audio.get_hop_size(), True)
if len(x) > max_steps:
max_time_frames = max_steps // audio.get_hop_size()
s = np.random.randint(cin_pad, len(c) - max_time_frames - cin_pad)
ts = s * audio.get_hop_size()
x = x[ts:ts + audio.get_hop_size() * max_time_frames]
c = c[s - cin_pad:s + max_time_frames + cin_pad, :]
assert_ready_for_upsampling(x, c, cin_pad=cin_pad)
else:
x, c = audio.adjust_time_resolution(x, c)
if max_time_steps is not None and len(x) > max_time_steps:
s = np.random.randint(cin_pad, len(x) - max_time_steps - cin_pad)
x = x[s:s + max_time_steps]
c = c[s - cin_pad:s + max_time_steps + cin_pad, :]
assert len(x) == len(c)
new_batch.append((x, c, g))
return new_batch
def process_no_condition_batch(max_time_steps, batch):
"""process no condition batch"""
new_batch = []
for batch_ in batch:
x, c, g = batch_
x = audio.trim(x)
if max_time_steps is not None and len(x) > max_time_steps:
s = np.random.randint(0, len(x) - max_time_steps)
x = x[s:s + max_time_steps]
new_batch.append((x, c, g))
return new_batch
def collate_fn(batch, hparams):
"""
Create batch
"""
local_conditioning = len(batch[0]) >= 2 and hparams.cin_channels > 0
global_conditioning = len(batch[0]) >= 3 and hparams.gin_channels > 0
if hparams.max_time_sec is not None:
max_time_steps = int(hparams.max_time_sec * hparams.sample_rate)
elif hparams.max_time_steps is not None:
max_time_steps = hparams.max_time_steps
else:
max_time_steps = None
if local_conditioning:
new_batch = process_condition_batch(max_time_steps, hparams, batch)
else:
new_batch = process_no_condition_batch(max_time_steps, batch)
batch = new_batch
# Lengths
input_lengths = [len(x[0]) for x in batch]
max_input_len = max(input_lengths)
# (B, T, C)
# pad for time-axis
if is_mulaw_quantize(hparams.input_type):
padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1)
x_batch = np.array(
[_pad_2d(to_categorical(x[0], num_classes=hparams.quantize_channels), max_input_len, 0, padding_value) for x
in batch], dtype=np.float32)
else:
x_batch = np.array([_pad_2d(x[0].reshape(-1, 1), max_input_len)
for x in batch], dtype=np.float32)
assert len(x_batch.shape) == 3
# (B, T)
if is_mulaw_quantize(hparams.input_type):
padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1)
y_batch = np.array([_pad(x[0], max_input_len, constant_values=padding_value)
for x in batch], dtype=np.int32)
else:
y_batch = np.array([_pad(x[0], max_input_len) for x in batch], dtype=np.float32)
assert len(y_batch.shape) == 2
# (B, T, D)
if local_conditioning:
max_len = max([len(x[1]) for x in batch])
c_batch = np.array([_pad_2d(x[1], max_len) for x in batch], dtype=np.float32)
assert len(c_batch.shape) == 3
# (B x C x T)
c_batch = c_batch.transpose((0, 2, 1))
else:
c_batch = np.zeros(hparams.batch_size, dtype=np.float32)
if global_conditioning:
g_batch = [x[2] for x in batch]
else:
# g_batch = None # MindSpore does not support None input
g_batch = np.zeros(hparams.batch_size, dtype=np.int64)
# Convert to channel first (B, C, T)
x_batch = x_batch.transpose((0, 2, 1))
# Add extra axis
y_batch = np.expand_dims(y_batch, axis=-1)  # same for both input types, so no branch needed
mask = sequence_mask(input_lengths, max_len=x_batch.shape[-1])
return x_batch, y_batch, c_batch, g_batch, input_lengths, mask
class DualDataset():
"""Create Dataset loader for audio Mel and Audio"""
def __init__(self, X, Mel, length, batch_size, hparams):
self.multi_speaker = X.file_data_source.multi_speaker
self.X = X
self.Mel = Mel
self.length = length
self.hparams = hparams
self.sorted_index = list(np.argsort(length))
self.bins = [self.sorted_index[i:i + batch_size] for i in range(0, len(self.sorted_index), batch_size)]
if len(self.sorted_index) % batch_size != 0:
self.bins.append(self.sorted_index[-batch_size:])
self.size = len(self.bins)
def __getitem__(self, idx):
if self.multi_speaker:
speaker_id = self.X.file_data_source.speaker_ids[idx]
else:
speaker_id = None
combined_data = []
mel_len, audio_len = [], []
for i in self.bins[idx]:
if self.Mel is not None:
mel = self.Mel[i]
raw_audio = self.X[i]
length_mel, length_x = mel.shape[0], raw_audio.shape[0]
combined_data.append((raw_audio, mel, speaker_id))
mel_len.append(length_mel)
audio_len.append(length_x)
else:
raw_audio = self.X[i]
length_x = raw_audio.shape[0]
combined_data.append((raw_audio, speaker_id))
audio_len.append(length_x)
x_batch, y_batch, c_batch, g_batch, input_lengths, mask = collate_fn(combined_data, self.hparams)
return x_batch, y_batch, c_batch, g_batch, input_lengths, mask
def __len__(self):
return self.size
def get_data_loaders(dump_root, speaker_id, hparams=None, rank_id=None, group_size=None):
"""create train dataset"""
local_conditioning = hparams.cin_channels > 0
if hparams.max_time_steps is not None:
max_steps = ensure_divisible(hparams.max_time_steps, audio.get_hop_size(), True)
else:
max_steps = None
X = FileSourceDataset(
RawAudioDataSource(os.path.join(dump_root, 'train_no_dev'), speaker_id=speaker_id,
max_steps=max_steps, cin_pad=hparams.cin_pad,
hop_size=audio.get_hop_size()))
if local_conditioning:
Mel = FileSourceDataset(
MelSpecDataSource(os.path.join(dump_root, 'train_no_dev'), speaker_id=speaker_id,
max_steps=max_steps, cin_pad=hparams.cin_pad,
hop_size=audio.get_hop_size()))
assert len(X) == len(Mel)
print("Local conditioning enabled. Shape of a sample: {}.".format(Mel[0].shape))
else:
Mel = None
print("length of the dataset is {}".format(len(X)))
length_x = np.array(X.file_data_source.lengths)
dataset = DualDataset(X, Mel, length_x, batch_size=hparams.batch_size, hparams=hparams)
sampler = DistributedSampler(dataset, rank_id, group_size, shuffle=True, seed=0)
data_loaders = de.GeneratorDataset(dataset, ["x_batch", "y_batch", "c_batch", "g_batch", "input_lengths", "mask"],
sampler=sampler)
return data_loaders

View File

@ -0,0 +1,238 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""loss function definition"""
import os
import numpy as np
import matplotlib
matplotlib.use('Agg')  # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from nnmnkwii import preprocessing as P1
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw
from wavenet_vocoder.mixture import discretized_mix_logistic_loss
from wavenet_vocoder.mixture import mix_gaussian_loss
from train_pytorch import to_categorical
from tqdm import tqdm
import audio
import librosa
import librosa.display
def sequence_mask(sequence_length, max_len=None):
"""make sequence mask"""
sequence_length = sequence_length.asnumpy()
if max_len is None:
max_len = np.max(sequence_length)
batch_size = sequence_length.shape[0]
seq_range = np.linspace(0, max_len-1, max_len, dtype=np.int32)
seq_range_expand = np.tile(np.expand_dims(seq_range, 0), (batch_size, 1))
seq_length_expand = np.tile(np.expand_dims(sequence_length, 1), (1, max_len))
seq_length_expand = np.expand_dims(np.array(seq_range_expand < seq_length_expand, dtype=np.float32), -1)
return Tensor(seq_length_expand)
class MaskedCrossEntropyLoss(nn.Cell):
"""MaskedCrossEntropyLoss"""
def __init__(self):
super(MaskedCrossEntropyLoss, self).__init__()
self.criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
def construct(self, inputs, target):
losses = self.criterion(inputs, target)
return losses
class DiscretizedMixturelogisticLoss(nn.Cell):
"""DiscretizedMixturelogisticLoss"""
def __init__(self, hparams):
super(DiscretizedMixturelogisticLoss, self).__init__()
self.quantize_channels = hparams.quantize_channels
self.log_scale_min = hparams.log_scale_min
self.discretized_mix_logistic_loss = discretized_mix_logistic_loss(num_classes=hparams.quantize_channels,
log_scale_min=hparams.log_scale_min,
reduce=False)
self.reduce_sum_op = P.ReduceSum()
self.reduce_mean_op = P.ReduceMean()
def construct(self, inputs, target, mask=None):
losses = self.discretized_mix_logistic_loss(inputs, target)
return self.reduce_sum_op(losses * mask) / self.reduce_sum_op(mask)
class MixtureGaussianLoss(nn.Cell):
"""MixtureGaussianLoss"""
def __init__(self, hparams):
super(MixtureGaussianLoss, self).__init__()
self.quantize_channels = hparams.quantize_channels
self.log_scale_min = hparams.log_scale_min
self.mix_gaussian_loss = mix_gaussian_loss(log_scale_min=hparams.log_scale_min, reduce=False)
self.reduce_sum_op = P.ReduceSum()
self.reduce_mean_op = P.ReduceMean()
def construct(self, inputs, target, mask=None):
"""
Args:
inputs (Tensor): Predicted distribution
target (Tensor): Target
mask (Tensor): Mask
Returns:
Tensor: Loss tensor
"""
losses = self.mix_gaussian_loss(inputs, target)
return self.reduce_sum_op(losses * mask) / self.reduce_sum_op(mask)
def save_waveplot(path, y_hat, y_target, sample_rate):
sr = sample_rate
plt.figure(figsize=(16, 6))
plt.subplot(2, 1, 1)
librosa.display.waveplot(y_target, sr=sr)
plt.subplot(2, 1, 2)
librosa.display.waveplot(y_hat, sr=sr)
plt.tight_layout()
plt.savefig(path, format="png")
plt.close()
def eval_model(hparams, global_step, model, x, y, c, g, input_lengths, eval_dir):
"""
Function for model evaluation. This function is used for debugging in this project.
"""
model.set_train(False)
idx = np.random.randint(0, len(y))
length = input_lengths.asnumpy()[idx]
y_target = np.reshape(y.asnumpy()[idx], (-1))
y_target = y_target[:length]
if c is not None:
expand_op = P.ExpandDims()
if hparams.upsample_conditional_features:
c = expand_op(c[idx, :, :int(length // audio.get_hop_size() + hparams.cin_pad * 2)], 0)
else:
c = expand_op(c[idx, :, :length], 0)
assert len(c.shape) == 3
print("Shape of local conditioning features: {}".format(c.shape))
if g is not None:
g = g[idx]
print("Shape of global conditioning features: {}".format(g.size()))
# Dummy silence
if is_mulaw_quantize(hparams.input_type):
initial_value = P1.mulaw_quantize(0, hparams.quantize_channels - 1)
elif is_mulaw(hparams.input_type):
initial_value = P1.mulaw(0.0, hparams.quantize_channels)
else:
initial_value = 0.0
# (C,)
if is_mulaw_quantize(hparams.input_type):
initial_input = to_categorical(
initial_value, num_classes=hparams.quantize_channels).astype(np.float32)
initial_input = Tensor(np.reshape(initial_input, (1, 1, hparams.quantize_channels)))
else:
initial_input = np.ones((1, 1, 1)) * initial_value
initial_input = Tensor(initial_input)
# Run the model in fast eval mode
y_hat = model.incremental_forward(initial_input, c=c, g=g, T=length, softmax=True, quantize=True, tqdm=tqdm,
log_scale_min=hparams.log_scale_min)
if is_mulaw_quantize(hparams.input_type):
y_hat = np.reshape(np.argmax(y_hat, 1), (-1))
y_hat = P1.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1)
y_target = P1.inv_mulaw_quantize(y_target, hparams.quantize_channels - 1)
elif is_mulaw(hparams.input_type):
y_hat = P1.inv_mulaw(np.reshape(y_hat, (-1)), hparams.quantize_channels)
y_target = P1.inv_mulaw(y_target, hparams.quantize_channels)
else:
y_hat = np.reshape(y_hat, (-1))
# Save audio
os.makedirs(eval_dir, exist_ok=True)
path = os.path.join(eval_dir, "step{:09d}_predicted.wav".format(global_step))
librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
path = os.path.join(eval_dir, "step{:09d}_target.wav".format(global_step))
librosa.output.write_wav(path, y_target, sr=hparams.sample_rate)
# Save figure
path = os.path.join(eval_dir, "step{:09d}_waveplots.png".format(global_step))
save_waveplot(path, y_hat, y_target, hparams.sample_rate)
class PredictNet(nn.Cell):
"""
NetWithLossClass definition
"""
def __init__(self, network):
super(PredictNet, self).__init__(auto_prefix=False)
self.network = network
def construct(self, x, c, g):
y_hat = self.network(x, c, g, False)
return y_hat
class NetWithLossClass(nn.Cell):
"""
NetWithLossClass definition
Args:
network (Cell): Pre-defined WaveNet.
hparams (optional): Parameters.
Returns:
Tensor, loss tensor.
"""
def __init__(self, network, hparams):
super(NetWithLossClass, self).__init__(auto_prefix=False)
self.network = network
self.hparams = hparams
self.ReduceMean_false = P.ReduceMean(keep_dims=False)
self.expand_op = P.ExpandDims()
self.transpose_op = P.Transpose()
self.reshape_op = P.Reshape()
self.is_mulaw_quant = is_mulaw_quantize(hparams.input_type)
if self.is_mulaw_quant:
self.criterion = MaskedCrossEntropyLoss()
else:
if hparams.output_distribution == "Logistic":
self.criterion = DiscretizedMixturelogisticLoss(hparams)
elif hparams.output_distribution == "Normal":
self.criterion = MixtureGaussianLoss(hparams)
else:
self.criterion = None
raise RuntimeError(
"Not supported output distribution type: {}".format(hparams.output_distribution))
def construct(self, x, y, c, g, input_lengths, mask):
y_hat = self.network(x, c, g, False)
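# One-step shift for teacher forcing: the prediction at time t (y_hat[..., :-1])
# is scored against the next ground-truth sample (y[:, 1:]), so the network is
# always trained to predict the following sample.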
if self.is_mulaw_quant:
y_hat = self.transpose_op(y_hat[:, :, :-1], (0, 2, 1))
y_hat = self.reshape_op(y_hat, (-1, y_hat.shape[-1]))
y = self.reshape_op(y[:, 1:, 0], (-1,))
loss = self.criterion(y_hat, y)
else:
loss = self.criterion(y_hat[:, :, :-1], y[:, 1:, :], mask[:, 1:, :])
return loss

View File

@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import numpy as np
def get_lr(init_lr, total_epoch, step_per_epoch,
anneal_rate=0.5,
anneal_interval=200000):
"""
Learning rate generating
Args:
init_lr (float): Initial learning rate
total_epoch (int): Total epoch
step_per_epoch (int): Step per epoch
anneal_rate (float): anneal rate
anneal_interval (int ): anneal interval
Returns:
ndarray: learning rate
"""
total_step = total_epoch * step_per_epoch
lr_step = []
for i in range(total_step):
lr_step.append(init_lr * anneal_rate ** (i // anneal_interval))
learning_rate = np.array(lr_step).astype(np.float32)
return learning_rate
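# Illustrative usage: with the defaults anneal_rate=0.5 and anneal_interval=200000,
# get_lr(1e-3, total_epoch=600, step_per_epoch=1635) yields 981000 values where the
# rate halves every 200000 steps (1e-3, then 5e-4 from step 200000, and so on).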

View File

@ -0,0 +1,135 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_criteo."""
import os
from os.path import join
import json
import argparse
from warnings import warn
from hparams import hparams, hparams_debug_string
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore.nn import TrainOneStepCell
from mindspore.train import Model
from src.lr_generator import get_lr
from src.dataset import get_data_loaders
from src.loss import NetWithLossClass
from src.callback import Monitor
from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_mulaw_quantize, is_scalar_input
parser = argparse.ArgumentParser(description='TTS training')
parser.add_argument('--data_path', type=str, required=True, default='',
help='Directory containing preprocessed features.')
parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).')
parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_test',
help='Directory where to save model checkpoints [default: checkpoints].')
parser.add_argument('--checkpoint', type=str, default='', help='Restore model from checkpoint path if given.')
parser.add_argument('--speaker_id', type=str, default='',
help=' Use specific speaker of data in case for multi-speaker datasets.')
parser.add_argument('--is_distributed', action="store_true", default=False, help='Distributed training')
args = parser.parse_args()
if __name__ == '__main__':
if args.is_distributed:
init('nccl')
rank_id = get_rank()
group_size = get_group_size()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
rank_id = 0
group_size = 1
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
if args.preset is not None:
with open(args.preset) as f:
hparams.parse_json(f.read())
assert hparams.name == "wavenet_vocoder"
print(hparams_debug_string())
fs = hparams.sample_rate
os.makedirs(args.checkpoint_dir, exist_ok=True)
output_json_path = join(args.checkpoint_dir, "hparams.json")
with open(output_json_path, "w") as f:
json.dump(hparams.values(), f, indent=2)
data_loaders = get_data_loaders(args.data_path, args.speaker_id, hparams=hparams, rank_id=rank_id,
group_size=group_size)
step_size_per_epoch = data_loaders.get_dataset_size()
if is_mulaw_quantize(hparams.input_type):
if hparams.out_channels != hparams.quantize_channels:
raise RuntimeError(
"out_channels must equal quantize_channels if input_type is 'mulaw-quantize'")
if hparams.upsample_conditional_features and hparams.cin_channels < 0:
s = "Upsample conv layers were specified while local conditioning disabled. "
s += "Notice that upsample conv layers will never be used."
warn(s)
upsample_params = hparams.upsample_params
upsample_params["cin_channels"] = hparams.cin_channels
upsample_params["cin_pad"] = hparams.cin_pad
model = WaveNet(
out_channels=hparams.out_channels,
layers=hparams.layers,
stacks=hparams.stacks,
residual_channels=hparams.residual_channels,
gate_channels=hparams.gate_channels,
skip_out_channels=hparams.skip_out_channels,
cin_channels=hparams.cin_channels,
gin_channels=hparams.gin_channels,
n_speakers=hparams.n_speakers,
dropout=hparams.dropout,
kernel_size=hparams.kernel_size,
cin_pad=hparams.cin_pad,
upsample_conditional_features=hparams.upsample_conditional_features,
upsample_params=upsample_params,
scalar_input=is_scalar_input(hparams.input_type),
output_distribution=hparams.output_distribution,
)
loss_net = NetWithLossClass(model, hparams)
lr = get_lr(hparams.optimizer_params["lr"], hparams.nepochs, step_size_per_epoch)
lr = Tensor(lr)
if args.checkpoint != '':
param_dict = load_checkpoint(args.checkpoint)
load_param_into_net(model, param_dict)
print('Successfully loaded the pre-trained model')
weights = model.trainable_params()
optimizer = Adam(weights, learning_rate=lr, loss_scale=1024.)
train_net = TrainOneStepCell(loss_net, optimizer)
model = Model(train_net)
lr_cb = Monitor(lr)
callback_list = [lr_cb]
if args.is_distributed:
ckpt_path = os.path.join(args.checkpoint_dir, 'ckpt_' + str(get_rank()) + '/')
else:
ckpt_path = args.checkpoint_dir
config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch, keep_checkpoint_max=10)
ckpt_cb = ModelCheckpoint(prefix='wavenet', directory=ckpt_path, config=config_ck)
callback_list.append(ckpt_cb)
model.train(hparams.nepochs, data_loaders, callbacks=callback_list)

View File

@ -0,0 +1,17 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init"""
from __future__ import with_statement, print_function, absolute_import
from .wavenet import WaveNet

View File

@ -0,0 +1,176 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Extended Conv1D."""
import math
from mindspore import nn, Tensor
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
import numpy as np
class Conv1d(nn.Conv1d):
"""
Extended nn.Conv1d to adapt to incremental dilated convolutions.
During training, the ordinary Conv1d forward is used; during evaluation, incremental_forward is called.
To improve the inference speed, tensors are converted to numpy and the subsequent calculation is done in numpy.
These operations will be replaced with MindSpore ops in the future. Currently, some operations are not supported by
MindSpore, and a mixed use of numpy and MindSpore ops would take a long time.
"""
def __init__(self, *args, **kwargs):
super(Conv1d, self).__init__(*args, **kwargs)
self.clear_buffer()
self._linearized_weight = None
self.transpose_op = P.Transpose()
self.reshape_op = P.Reshape()
self.squeeze_op = P.Squeeze(-2)
self.zeros = P.Zeros()
self.concat_op = P.Concat(axis=1)
self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd()
self.get_weight = None
self.get_bias = None
def incremental_forward(self, inputs, is_numpy=True):
if is_numpy:
return self.incremental_forward_numpy(inputs)
return self.incremental_forward_pynative(inputs)
def incremental_forward_pynative(self, inputs):
"""
Incremental forward.
Args:
inputs: B x T x C
Returns:
ndarray
"""
# input: (B, T, C)
if self.training:
raise RuntimeError('incremental_forward only supports eval mode')
if self.get_weight is None:
self.get_weight = self._get_linearized_weight()
if self.get_bias is None and self.bias is not None:
self.get_bias = self.bias
# Note mindspore uses Conv2D to construct Conv1D
kw = self.kernel_size[1]
dilation = self.dilation[1]
bsz = inputs.shape[0] # input: bsz x len x dim
if kw > 1:
if self.input_buffer is None:
init_buffer = self.zeros((bsz, kw + (kw - 1) * (dilation - 1), inputs.shape[2]), mstype.float32)
self.input_buffer = self.concat_op((init_buffer[:, 1:, :], inputs[:, 0:1, :]))
else:
# shift buffer
self.input_buffer = self.concat_op((self.input_buffer[:, 1:, :], inputs[:, 0:1, :]))
inputs = self.input_buffer
if dilation > 1:
inputs = inputs[:, 0::dilation, :]
output = self.matmul(self.reshape_op(inputs, (bsz, -1)), self.get_weight)
if self.bias is not None:
output = self.bias_add(output, self.bias)
return self.reshape_op(output, (bsz, 1, -1))
def incremental_forward_numpy(self, inputs):
"""
Incremental forward.
Args:
inputs: B x T x C
Returns:
ndarray
"""
# input: (B, T, C)
if self.training:
raise RuntimeError('incremental_forward only supports eval mode')
if self.get_weight is None:
weight = self._get_linearized_weight()
self.get_weight = weight.asnumpy()
if self.get_bias is None and self.bias is not None:
bias = self.bias
self.get_bias = bias.asnumpy()
# Note mindspore uses Conv2D to construct Conv1D
kw = self.kernel_size[1]
dilation = self.dilation[1]
bsz = inputs.shape[0] # input: bsz x len x dim
if kw > 1:
if self.input_buffer is None:
self.input_buffer = np.zeros((bsz, kw + (kw - 1) * (dilation - 1), inputs.shape[2]), dtype=np.float32)
else:
# shift buffer
self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :]
# append next
self.input_buffer[:, -1, :] = inputs[:, -1, :]
inputs = self.input_buffer
if dilation > 1:
inputs = inputs[:, 0::dilation, :]
output = inputs.reshape(bsz, -1).dot(self.get_weight.T)
if self.bias is not None:
output = output + np.expand_dims(self.get_bias, 0)
return np.reshape(output, (bsz, 1, -1))
def clear_buffer(self):
self.input_buffer = None
def _get_linearized_weight(self):
"""
get linearized weight
"""
weight = self.squeeze_op(self.weight)
if self._linearized_weight is None:
# Note mindspore uses Conv2D to construct Conv1D
kw = self.kernel_size[1]
if weight.shape == (self.out_channels, self.in_channels, kw):
weight = self.transpose_op(weight, (0, 2, 1))
else:
weight = self.transpose_op(weight, (2, 0, 1))
self._linearized_weight = self.reshape_op(weight, (self.out_channels, -1))
return self._linearized_weight
def _clear_linearized_weight(self, *args):
self._linearized_weight = None
def _initialize_weights(self):
"""
weight initialization
"""
self.init_parameters_data()
std_mul = 4.0
for _, m in self.cells_and_names():
if isinstance(m, nn.Conv1d):
std = math.sqrt((std_mul * 0.1) / (m.kernel_size[1] * self.in_channels))
m.weight.set_data(Tensor(np.random.normal(0, std, m.weight.data.shape).astype("float32")))
if m.bias is not None:
m.bias.set_data(
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_data(
Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_data(
Tensor(np.zeros(m.beta.data.shape, dtype="float32")))

View File

@ -0,0 +1,323 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
loss function for training and sample function for testing
"""
import numpy as np
import mindspore as ms
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops as P
class log_sum_exp(nn.Cell):
"""Numerically stable log_sum_exp
"""
def __init__(self):
super(log_sum_exp, self).__init__()
self.maxi = P.ReduceMax()
self.maxi_dim = P.ReduceMax(keep_dims=True)
self.log = P.Log()
self.sums = P.ReduceSum()
self.exp = P.Exp()
def construct(self, x):
axis = len(x.shape) - 1
m = self.maxi(x, axis)
m2 = self.maxi_dim(x, axis)
return m + self.log(self.sums(self.exp(x - m2), axis))
class Stable_softplus(nn.Cell):
"""Numerically stable softplus
"""
def __init__(self):
super(Stable_softplus, self).__init__()
self.log_op = P.Log()
self.abs_op = P.Abs()
self.relu_op = P.ReLU()
self.exp_op = P.Exp()
def construct(self, x):
return self.log_op(1 + self.exp_op(- self.abs_op(x))) + self.relu_op(x)
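# Illustrative check (not in the original file): the identity
# softplus(x) = log(1 + exp(-|x|)) + max(x, 0) avoids overflow for large x,
# e.g. x = 1000.0 yields 1000.0 where a naive log(1 + exp(1000.0)) returns inf.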
class discretized_mix_logistic_loss(nn.Cell):
"""
Discretized_mix_logistic_loss
Args:
num_classes (int): Num_classes
log_scale_min (float): Log scale minimum value
"""
def __init__(self, num_classes=256, log_scale_min=-7.0, reduce=True):
super(discretized_mix_logistic_loss, self).__init__()
self.num_classes = num_classes
self.log_scale_min = log_scale_min
self.reduce = reduce
self.transpose_op = P.Transpose()
self.exp = P.Exp()
self.sigmoid = P.Sigmoid()
self.softplus = Stable_softplus()
self.log = P.Log()
self.cast = P.Cast()
self.logsoftmax = P.LogSoftmax(-1)
self.expand_dims = P.ExpandDims()
self.tile = P.Tile()
self.maximum = P.Maximum()
self.sums = P.ReduceSum()
self.lse = log_sum_exp()
self.reshape = P.Reshape()
self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32))
def construct(self, y_hat, y):
"""
Args:
y_hat (Tensor): Predicted distribution
y (Tensor): Target
Returns:
Tensor: Discretized_mix_logistic_loss
"""
nr_mix = y_hat.shape[1] // 3
# (B x T x C)
y_hat = self.transpose_op(y_hat, (0, 2, 1))
# (B, T, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix:2 * nr_mix]
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], self.log_scale_min)
# B x T x 1 -> B x T x num_mixtures
y = self.tile(y, (1, 1, nr_mix))
centered_y = y - means
inv_stdv = self.exp(-log_scales)
plus_in = inv_stdv * (centered_y + 1. / (self.num_classes - 1))
cdf_plus = self.sigmoid(plus_in)
min_in = inv_stdv * (centered_y - 1. / (self.num_classes - 1))
cdf_min = self.sigmoid(min_in)
log_cdf_plus = plus_in - self.softplus(plus_in)
log_one_minus_cdf_min = -self.softplus(min_in)
cdf_delta = cdf_plus - cdf_min
mid_in = inv_stdv * centered_y
log_pdf_mid = mid_in - log_scales - 2. * self.softplus(mid_in)
inner_inner_cond = self.cast(cdf_delta > 1e-5, ms.float32)
inner_inner_out = inner_inner_cond * \
self.log(self.maximum(cdf_delta, 1e-12)) + \
(1. - inner_inner_cond) * (log_pdf_mid - self.factor)
inner_cond = self.cast(y > 0.999, ms.float32)
inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
cond = self.cast(y < -0.999, ms.float32)
log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
a, b, c = logit_probs.shape[0], logit_probs.shape[1], logit_probs.shape[2]
logit_probs = self.logsoftmax(self.reshape(logit_probs, (-1, c)))
logit_probs = self.reshape(logit_probs, (a, b, c))
log_probs = log_probs + logit_probs
if self.reduce:
return -self.sums(self.lse(log_probs))
return self.expand_dims(-self.lse(log_probs), -1)
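# Usage sketch (illustrative, not part of the original file): with 10 logistic
# mixtures, y_hat carries 3 * 10 = 30 channels (mixture logits, means, log scales):
#   loss_fn = discretized_mix_logistic_loss(num_classes=256, log_scale_min=-7.0)
#   loss = loss_fn(y_hat, y)  # y_hat: (B, 30, T), y: (B, T, 1), values in [-1, 1]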
def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0):
"""
Sample from discretized mixture of logistic distributions
Args:
y (ndarray): B x C x T
log_scale_min (float): Log scale minimum value
Returns:
ndarray
"""
nr_mix = y.shape[1] // 3
# B x T x C
y = np.transpose(y, (0, 2, 1))
logit_probs = y[:, :, :nr_mix]
temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape)
temp = logit_probs - np.log(- np.log(temp))
argmax = np.argmax(temp, axis=-1)
# (B, T) -> (B, T, nr_mix)
one_hot = np.eye(nr_mix)[argmax]
means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
log_scales = np.clip(np.sum(
y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1), a_min=log_scale_min, a_max=None)
u = np.random.uniform(1e-5, 1.0 - 1e-5, means.shape)
x = means + np.exp(log_scales) * (np.log(u) - np.log(1. - u))
x = np.clip(x, -1., 1.)
return x.astype(np.float32)
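# Usage sketch (illustrative): for a network output y of shape (B, 3 * nr_mix, T),
#   x = sample_from_discretized_mix_logistic(y)  # -> (B, T) waveform in [-1, 1]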
class mix_gaussian_loss(nn.Cell):
"""
Mix gaussian loss
"""
def __init__(self, log_scale_min=-7.0, reduce=True):
super(mix_gaussian_loss, self).__init__()
self.log_scale_min = log_scale_min
self.reduce = reduce
self.transpose_op = P.Transpose()
self.maximum = P.Maximum()
self.tile = P.Tile()
self.exp = P.Exp()
self.logsoftmax = P.LogSoftmax(-1)
self.expand_dims = P.ExpandDims()
self.sums = P.ReduceSum()
self.lse = log_sum_exp()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.const = P.ScalarToArray()
self.log = P.Log()
def construct(self, y_hat, y):
"""
Args:
y_hat (Tensor): Predicted probability
y (Tensor): Target
Returns:
Tensor: Mix_gaussian_loss
"""
C = y_hat.shape[1]
if C == 2:
nr_mix = 1
else:
nr_mix = y_hat.shape[1] // 3
# (B x T x C)
y_hat = self.transpose_op(y_hat, (0, 2, 1))
if C == 2:
logit_probs = None
means = y_hat[:, :, 0:1]
log_scales = self.maximum(y_hat[:, :, 1:2], self.log_scale_min)
else:
# (B, T, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix:2 * nr_mix]
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], self.log_scale_min)
# B x T x 1 -> B x T x num_mixtures
y = self.tile(y, (1, 1, nr_mix))
centered_y = y - means
sd = self.exp(log_scales)
unnormalized_log_prob = -1. * (self.sq(centered_y - 0.)) / (2. * self.sq(sd))
neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
log_probs = unnormalized_log_prob + neg_normalization
if nr_mix > 1:
log_probs = log_probs + self.logsoftmax(logit_probs)
if self.reduce:
if nr_mix == 1:
return -self.sums(log_probs)
return -self.sums(self.lse(log_probs))
if nr_mix == 1:
return -log_probs
return self.expand_dims(-self.lse(log_probs), -1)
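# Note (illustrative): when C == 2 the network predicts a single Gaussian per
# step (mean and log scale), so there are no mixture logits and the
# log-sum-exp over components is skipped; otherwise the mixture weights are
# folded in via logsoftmax above.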
def sample_from_mix_gaussian(y, log_scale_min=-7.0):
"""
    Sample from a mixture of Gaussian distributions
    Args:
        y (ndarray): B x C x T
        log_scale_min (float): Log scale minimum value (not used in the current implementation)
Returns:
ndarray
"""
C = y.shape[1]
if C == 2:
nr_mix = 1
else:
nr_mix = y.shape[1] // 3
# B x T x C
y = np.transpose(y, (0, 2, 1))
if C == 2:
logit_probs = None
else:
logit_probs = y[:, :, :nr_mix]
if nr_mix > 1:
temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape)
temp = logit_probs - np.log(- np.log(temp))
argmax = np.argmax(temp, axis=-1)
# (B, T) -> (B, T, nr_mix)
one_hot = np.eye(nr_mix)[argmax]
means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
log_scales = np.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1)
else:
if C == 2:
means, log_scales = y[:, :, 0], y[:, :, 1]
elif C == 3:
means, log_scales = y[:, :, 1], y[:, :, 2]
else:
assert False, "shouldn't happen"
scales = np.exp(log_scales)
x = np.random.normal(loc=means, scale=scales)
x = np.clip(x, -1., 1.)
return x.astype(np.float32)
# self-implemented onehotcategorical distribution
# https://zhuanlan.zhihu.com/p/59550457
def sample_from_mix_onehotcategorical(x):
"""
    Sample from a mixture of one-hot categorical distributions via the Gumbel-max trick
Args:
x (ndarray): Predicted softmax probability
Returns:
ndarray
"""
pi = np.log(x)
u = np.random.uniform(0, 1, x.shape)
g = -np.log(-np.log(u))
c = np.argmax(pi + g, axis=1)
return np.array(np.eye(256)[c], dtype=np.float32)
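# Usage sketch (illustrative): for a softmax output x of shape (B, 256),
#   one_hot = sample_from_mix_onehotcategorical(x)  # Gumbel-max draw, (B, 256)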

View File

@ -0,0 +1,213 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
modules for wavenet
"""
from __future__ import with_statement, print_function, absolute_import
import math
import numpy as np
from wavenet_vocoder import conv
from mindspore import nn
from mindspore.ops import operations as P
def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs)
return m
def Conv1d1x1(in_channels, out_channels, has_bias=True):
return Conv1d(in_channels, out_channels, kernel_size=1, pad_mode='pad', padding=0, dilation=1, has_bias=has_bias)
def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
return m
def _conv1x1_forward(conv_, x, is_incremental, is_numpy=True):
"""
Conv1x1 forward
"""
if is_incremental:
x = conv_.incremental_forward(x, is_numpy=is_numpy)
else:
x = conv_(x)
return x
class ResidualConv1dGLU(nn.Cell):
"""Residual dilated conv1d with gated activation units
Args:
residual_channels (int): Residual input / output channels
gate_channels (int): Gated activation channels.
kernel_size (int): Kernel size
        skip_out_channels (int): Skip connection channels. If None, it will be set to the same value as residual_channels.
cin_channels (int): Local conditioning channels. If given negative value, local conditioning is disabled.
gin_channels (int): Global conditioning channels. If given negative value, global conditioning is disabled.
dropout (float): Dropout rate.
padding (int): Padding for convolution layers. If None, padding value will be computed according to dilation
and kernel_size.
dilation (int): Dilation factor.
"""
def __init__(self, residual_channels=None, gate_channels=None, kernel_size=None, skip_out_channels=None, bias=True,
dropout=1 - 0.95, dilation=1, cin_channels=-1, gin_channels=-1, padding=None, causal=True):
super(ResidualConv1dGLU, self).__init__()
self.dropout = dropout
self.dropout_op = nn.Dropout(keep_prob=1. - self.dropout)
self.eval_split_op = P.Split(axis=-1, output_num=2)
self.train_split_op = P.Split(axis=1, output_num=2)
self.tanh = P.Tanh()
self.sigmoid = P.Sigmoid()
self.mul = P.Mul()
self.add = P.TensorAdd()
if skip_out_channels is None:
skip_out_channels = residual_channels
if padding is None:
if causal:
padding = (kernel_size - 1) * dilation
else:
padding = (kernel_size - 1) // 2 * dilation
self.causal = causal
self.conv = Conv1d(residual_channels, gate_channels, kernel_size, pad_mode='pad',
padding=padding, dilation=dilation, has_bias=bias)
# local conditioning
if cin_channels > 0:
self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, has_bias=False)
else:
self.conv1x1c = None
# global conditioning
if gin_channels > 0:
self.conv1x1g = Conv1d(gin_channels, gate_channels, has_bias=False, kernel_size=1, dilation=1)
else:
self.conv1x1g = None
gate_out_channels = gate_channels // 2
self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, has_bias=bias)
self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, has_bias=bias)
self.factor = math.sqrt(0.5)
def construct(self, x, c=None, g=None):
"""
Args:
x(Tensor): One-hot audio signal, the shape is B x C x T
c(Tensor): local conditional feature, the shape is B x cin_channels x T
g(Tensor): global conditional feature, not used currently
Returns:
Tensor: Output tensor
"""
residual = x
x = self.dropout_op(x)
x = self.conv(x)
# remove future time steps
x = x[:, :, :residual.shape[-1]] if self.causal else x
split_op = self.train_split_op
a, b = split_op(x)
# local conditioning
if c is not None:
c = _conv1x1_forward(self.conv1x1c, c, is_incremental=False)
ca, cb = split_op(c)
a, b = a + ca, b + cb
# global conditioning
if g is not None:
g = _conv1x1_forward(self.conv1x1g, g, is_incremental=False)
            ga, gb = split_op(g)
a, b = a + ga, b + gb
x = self.mul(self.tanh(a), self.sigmoid(b))
# For skip connection
s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental=False)
# For residual connection
x = _conv1x1_forward(self.conv1x1_out, x, is_incremental=False)
x = self.add(x, residual) * self.factor
return x, s
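    # Shape sketch (illustrative): x (B, residual_channels, T) -> conv ->
    # (B, gate_channels, T) -> split into two (B, gate_channels // 2, T)
    # halves -> tanh(a) * sigmoid(b) -> 1x1 convs produce the residual
    # output x and the skip output s.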
def sigmoid_numpy(self, x):
return 1. / (1 + np.exp(-x))
def incremental_forward(self, x, c=None, g=None, is_numpy=True):
"""
Incremental forward. Used for inference stage
Args:
x (Tensor): One-hot audio signal, the shape is B x C x T
c (Tensor): local conditional feature, the shape is B x cin_channels x T
g (Tensor): global conditional feature, not used currently
Returns:
ndarray
"""
residual = x
x = self.conv.incremental_forward(x, is_numpy=is_numpy)
if is_numpy:
a, b = np.split(x, indices_or_sections=2, axis=-1)
else:
a, b = self.eval_split_op(x)
# local conditioning
if c is not None:
c = _conv1x1_forward(self.conv1x1c, c, is_incremental=True, is_numpy=is_numpy)
if is_numpy:
ca, cb = np.split(c, indices_or_sections=2, axis=-1)
else:
ca, cb = self.eval_split_op(c)
a, b = a + ca, b + cb
# global conditioning
if g is not None:
g = _conv1x1_forward(self.conv1x1g, g, is_incremental=True, is_numpy=is_numpy)
if is_numpy:
ga, gb = np.split(g, indices_or_sections=2, axis=-1)
else:
                ga, gb = self.eval_split_op(g)
a, b = a + ga, b + gb
if is_numpy:
x = np.tanh(a) * self.sigmoid_numpy(b)
else:
x = self.mul(self.tanh(a), self.sigmoid(b))
# For skip connection
s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental=True, is_numpy=is_numpy)
# For residual connection
x = _conv1x1_forward(self.conv1x1_out, x, is_incremental=True, is_numpy=is_numpy)
x = (x + residual) * self.factor
return x, s
def clear_buffer(self):
"""clear buffer"""
for c in [self.conv, self.conv1x1_out, self.conv1x1_skip,
self.conv1x1c, self.conv1x1g]:
if c is not None:
c.clear_buffer()
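# Usage sketch (hedged; shapes and values are assumptions, not part of the
# original file):
#   block = ResidualConv1dGLU(residual_channels=64, gate_channels=128,
#                             kernel_size=3, dilation=2, cin_channels=80)
#   x_out, skip = block(x, c)  # x: (B, 64, T), c: (B, 80, T); both outputs (B, 64, T)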

View File

@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Upsampling
"""
from __future__ import with_statement, print_function, absolute_import
import numpy as np
from mindspore import nn
from mindspore.ops import operations as P
class Resize(nn.Cell):
"""
Resize input Tensor
"""
def __init__(self, x_scale, y_scale, mode="nearest"):
super(Resize, self).__init__()
self.x_scale = x_scale
self.y_scale = y_scale
self.mode = mode
def construct(self, x):
_, _, h, w = x.shape
interpolate_op = P.ResizeNearestNeighbor((self.y_scale * h, self.x_scale * w))
return interpolate_op(x)
def _get_activation(upsample_activation):
"""get activation"""
nonlinear = getattr(nn, upsample_activation)
return nonlinear
class UpsampleNetwork(nn.Cell):
"""UpsampleNetwork"""
def __init__(self, upsample_scales, mode="nearest",
freq_axis_kernel_size=1, cin_pad=0, cin_channels=80):
super(UpsampleNetwork, self).__init__()
self.expand_op = P.ExpandDims()
self.squeeze_op = P.Squeeze(1)
up_layers = []
total_scale = np.prod(upsample_scales)
self.indent = cin_pad * total_scale
for scale in upsample_scales:
freq_axis_padding = (freq_axis_kernel_size - 1) // 2
k_size = (freq_axis_kernel_size, scale * 2 + 1)
# padding = (freq_axis_padding, scale)
padding = (freq_axis_padding, freq_axis_padding, scale, scale)
stretch = Resize(scale, 1, mode)
conv = nn.Conv2d(1, 1, kernel_size=k_size, has_bias=False, pad_mode='pad', padding=padding)
up_layers.append(stretch)
up_layers.append(conv)
# if upsample_activation != "none":
# nonlinear = _get_activation(upsample_activation)
# up_layers.append(nonlinear(**upsample_activation_params))
self.up_layers = nn.CellList(up_layers)
def construct(self, c):
"""
Args:
c (Tensor): Local conditioning feature
Returns:
Tensor: Upsampling feature
"""
# B x 1 x C x T
c = self.expand_op(c, 1)
for f in self.up_layers:
c = f(c)
# B x C x T
c = self.squeeze_op(c)
# if self.indent > 0:
# c = c[:, :, self.indent:-self.indent]
return c
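# Note (illustrative): the total upsampling factor is np.prod(upsample_scales),
# which must equal the STFT hop size, e.g. scales [4, 4, 4, 4] -> 256.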
class ConvInUpsampleNetwork(nn.Cell):
"""Upsample Network
Args:
upsample_scales (list): Upsample_scales list.
upsample_activation (str): Upsample_activation.
mode (str): Resize mode, default is NearestNeighbor.
cin_channels (int): Local conditioning channels.
freq_axis_kernel_size (int): Freq-axis kernel_size for the convolution layers after resize.
"""
def __init__(self, upsample_scales, mode="nearest",
freq_axis_kernel_size=1, cin_pad=0,
cin_channels=80):
super(ConvInUpsampleNetwork, self).__init__()
ks = 2 * cin_pad + 1
self.conv_in = nn.Conv1d(cin_channels, cin_channels, kernel_size=ks, has_bias=False, pad_mode='pad', padding=0)
self.upsample = UpsampleNetwork(upsample_scales, mode, freq_axis_kernel_size, cin_pad=0,
cin_channels=cin_channels)
def construct(self, c):
c = self.conv_in(c)
c_up = self.upsample(c)
return c_up
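# Usage sketch (hedged; the mel settings are assumptions): upsample 80-dim mel
# frames with hop size 256 to waveform resolution:
#   net = ConvInUpsampleNetwork(upsample_scales=[4, 4, 4, 4], cin_channels=80)
#   c_up = net(c)  # c: (B, 80, T_frames) -> c_up: (B, 80, T_frames * 256)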

View File

@ -0,0 +1,346 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""WaveNet construction"""
from __future__ import with_statement, print_function, absolute_import
import math
import numpy as np
from mindspore import nn, Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from wavenet_vocoder import upsample
from .modules import Embedding
from .modules import Conv1d1x1
from .modules import ResidualConv1dGLU
from .mixture import sample_from_discretized_mix_logistic
from .mixture import sample_from_mix_gaussian
from .mixture import sample_from_mix_onehotcategorical
class WaveNet(nn.Cell):
"""
    WaveNet model definition. Only local conditioning is supported currently.
Args:
        out_channels (int): Output channels. If the input type is a mu-law quantized one-hot vector, it should equal
            the number of quantize channels. Otherwise, it equals num_mixtures x 3. Default: 256.
layers (int): Number of ResidualConv1dGLU layers
stacks (int): Number of dilation cycles
residual_channels (int): Residual input / output channels
gate_channels (int): Gated activation channels.
skip_out_channels (int): Skip connection channels.
        kernel_size (int): Kernel size.
dropout (float): Dropout rate.
cin_channels (int): Local conditioning channels. If given negative value, local conditioning is disabled.
gin_channels (int): Global conditioning channels. If given negative value, global conditioning is disabled.
n_speakers (int): Number of speakers. This is used when global conditioning is enabled.
        upsample_conditional_features (bool): Whether to upsample local conditioning features by nearest-neighbor
            resize and convolution.
        scalar_input (bool): If True, a scalar input in [-1, 1] is expected; otherwise, a quantized one-hot
            vector is expected.
        use_speaker_embedding (bool): Whether to use speaker embedding.
"""
def __init__(self, out_channels=256, layers=20, stacks=2,
residual_channels=512,
gate_channels=512,
skip_out_channels=512,
kernel_size=3, dropout=1 - 0.95,
cin_channels=-1, gin_channels=-1, n_speakers=None,
upsample_conditional_features=False,
upsample_net="ConvInUpsampleNetwork",
upsample_params=None,
scalar_input=False,
use_speaker_embedding=False,
output_distribution="Logistic",
cin_pad=0,
):
super(WaveNet, self).__init__()
self.transpose_op = P.Transpose()
self.softmax = P.Softmax(axis=1)
self.reshape_op = P.Reshape()
self.zeros_op = P.Zeros()
self.ones_op = P.Ones()
self.relu_op = P.ReLU()
self.squeeze_op = P.Squeeze()
self.expandim_op = P.ExpandDims()
self.transpose_op = P.Transpose()
self.tile_op = P.Tile()
self.scalar_input = scalar_input
self.out_channels = out_channels
self.cin_channels = cin_channels
self.output_distribution = output_distribution
        self.fake_data = P.Zeros()
assert layers % stacks == 0
layers_per_stack = layers // stacks
if scalar_input:
self.first_conv = Conv1d1x1(1, residual_channels)
else:
self.first_conv = Conv1d1x1(out_channels, residual_channels)
conv_layers = []
for layer in range(layers):
dilation = 2 ** (layer % layers_per_stack)
conv = ResidualConv1dGLU(
residual_channels, gate_channels,
kernel_size=kernel_size,
skip_out_channels=skip_out_channels,
bias=True,
dropout=dropout,
dilation=dilation,
cin_channels=cin_channels,
gin_channels=gin_channels)
conv_layers.append(conv)
self.conv_layers = nn.CellList(conv_layers)
self.last_conv_layers = nn.CellList([
nn.ReLU(),
Conv1d1x1(skip_out_channels, skip_out_channels),
nn.ReLU(),
Conv1d1x1(skip_out_channels, out_channels)])
if gin_channels > 0 and use_speaker_embedding:
assert n_speakers is not None
self.embed_speakers = Embedding(
n_speakers, gin_channels, padding_idx=None, std=0.1)
else:
self.embed_speakers = None
if upsample_conditional_features:
self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
else:
self.upsample_net = None
self.factor = math.sqrt(1.0 / len(self.conv_layers))
def _expand_global_features(self, batch_size, time_step, g_fp, is_expand=True):
"""Expand global conditioning features to all time steps
Args:
batch_size (int): Batch size.
time_step (int): Time length.
g_fp (Tensor): Global features, (B x C) or (B x C x 1).
            is_expand (bool): If True, return features as B x C x T; otherwise, return B x T x C
Returns:
Tensor: B x C x T or B x T x C or None
"""
if g_fp is None:
return None
        if len(g_fp.shape) == 2:
            g_fp = self.expandim_op(g_fp, -1)
if is_expand:
expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step))
return expand_fp
expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step))
expand_fp = self.transpose_op(expand_fp, (0, 2, 1))
return expand_fp
def construct(self, x, c=None, g=None, softmax=False):
"""
Args:
x (Tensor): One-hot encoded audio signal
c (Tensor): Local conditioning feature
g (Tensor): Global conditioning feature
            softmax (bool): Whether to apply softmax to the output
Returns:
Tensor: Net output
"""
        # Global conditioning is not supported yet; override g with None
        g = None
B, _, T = x.shape
if g is not None:
if self.embed_speakers is not None:
g = self.embed_speakers(self.reshape_op(g, (B, -1)))
g = self.transpose_op(g, (0, 2, 1))
g_bct = self._expand_global_features(B, T, g, is_expand=True)
if c is not None and self.upsample_net is not None:
c = self.upsample_net(c)
x = self.first_conv(x)
skips = 0
for f in self.conv_layers:
x, h = f(x, c, g_bct)
skips += h
skips *= self.factor
x = skips
for f in self.last_conv_layers:
x = f(x)
x = self.softmax(x) if softmax else x
return x
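    # Shape sketch (illustrative): x (B, out_channels, T) one-hot audio and
    # the upsampled condition c (B, cin_channels, T) pass through the gated
    # residual blocks; the summed skip connections are scaled by
    # sqrt(1 / len(conv_layers)) and projected back to (B, out_channels, T).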
def relu_numpy(self, inX):
"""numpy relu function"""
return np.maximum(0, inX)
def softmax_numpy(self, x):
""" numpy softmax function """
x -= np.max(x, axis=1, keepdims=True)
return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
def incremental_forward(self, initial_input=None, c=None, g=None,
T=100, test_inputs=None,
tqdm=lambda x: x, softmax=True, quantize=True,
log_scale_min=-50.0, is_numpy=True):
"""
Incremental forward. Current output depends on last output.
Args:
initial_input (Tensor): Initial input, the shape is B x C x 1
c (Tensor): Local conditioning feature, the shape is B x C x T
g (Tensor): Global conditioning feature, the shape is B x C or B x C x 1
            T (int): Number of decoding time steps.
            test_inputs: Teacher-forcing inputs (for debugging)
            tqdm (lambda): tqdm progress-bar wrapper
            softmax (bool): Whether to apply softmax or not
            quantize (bool): Whether to quantize the softmax output of the last step when decoding the current step
            log_scale_min (float): Log scale minimum value
        Returns:
            Tensor: Predicted one-hot encoded samples or a scalar vector depending on the loss type
"""
self.clear_buffer()
B = 1
if test_inputs is not None:
if self.scalar_input:
if test_inputs.shape[1] == 1:
test_inputs = self.transpose_op(test_inputs, (0, 2, 1))
else:
if test_inputs.shape[1] == self.out_channels:
test_inputs = self.transpose_op(test_inputs, (0, 2, 1))
B = test_inputs.shape[0]
if T is None:
T = test_inputs.shape[1]
else:
T = max(T, test_inputs.shape[1])
T = int(T)
# Global conditioning
if g is not None:
if self.embed_speakers is not None:
g = self.embed_speakers(self.reshape_op(g, (B, -1)))
g = self.transpose_op(g, (0, 2, 1))
                # MindSpore tensors have no dim(); check the rank via the shape
                assert len(g.shape) == 3
g_btc = self._expand_global_features(B, T, g, is_expand=False)
# Local conditioning
if c is not None:
B = c.shape[0]
if self.upsample_net is not None:
c = self.upsample_net(c)
assert c.shape[-1] == T
if c.shape[-1] == T:
c = self.transpose_op(c, (0, 2, 1))
outputs = []
if initial_input is None:
if self.scalar_input:
initial_input = self.zeros_op((B, 1, 1), mstype.float32)
else:
initial_input = np.zeros((B, 1, self.out_channels), np.float32)
initial_input[:, :, 127] = 1
initial_input = Tensor(initial_input)
else:
if initial_input.shape[1] == self.out_channels:
initial_input = self.transpose_op(initial_input, (0, 2, 1))
if is_numpy:
current_input = initial_input.asnumpy()
else:
current_input = initial_input
for t in tqdm(range(T)):
if test_inputs is not None and t < test_inputs.shape[1]:
current_input = self.expandim_op(test_inputs[:, t, :], 1)
else:
if t > 0:
if not is_numpy:
current_input = Tensor(outputs[-1])
else:
current_input = outputs[-1]
# Conditioning features for single time step
ct = None if c is None else self.expandim_op(c[:, t, :], 1)
gt = None if g is None else self.expandim_op(g_btc[:, t, :], 1)
x = current_input
            if is_numpy and ct is not None:
                ct = ct.asnumpy()
x = self.first_conv.incremental_forward(x, is_numpy=is_numpy)
skips = 0
for f in self.conv_layers:
x, h = f.incremental_forward(x, ct, gt, is_numpy=is_numpy)
skips += h
skips *= self.factor
x = skips
for f in self.last_conv_layers:
try:
x = f.incremental_forward(x, is_numpy=is_numpy)
except AttributeError:
if is_numpy:
x = self.relu_numpy(x)
else:
x = self.relu_op(x)
# Generate next input by sampling
if not is_numpy:
x = x.asnumpy()
if self.scalar_input:
if self.output_distribution == "Logistic":
x = sample_from_discretized_mix_logistic(x.reshape((B, -1, 1)), log_scale_min=log_scale_min)
elif self.output_distribution == "Normal":
x = sample_from_mix_gaussian(x.reshape((B, -1, 1)), log_scale_min=log_scale_min)
else:
assert False
else:
x = self.softmax_numpy(np.reshape(x, (B, -1))) if softmax else np.reshape(x, (B, -1))
if quantize:
x = sample_from_mix_onehotcategorical(x)
outputs += [x]
# T x B x C
outputs = np.stack(outputs, 0)
# B x C x T
outputs = np.transpose(outputs, (1, 2, 0))
self.clear_buffer()
return outputs
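    # Usage sketch (hedged; `mel`, `num_frames` and `hop_size` are assumed
    # caller-side names): generation is autoregressive, drawing one sample per
    # step and feeding it back as the next input:
    #   audio = net.incremental_forward(c=mel, T=num_frames * hop_size, is_numpy=True)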
def clear_buffer(self):
"""clear buffer"""
self.first_conv.clear_buffer()
for f in self.conv_layers:
f.clear_buffer()
for f in self.last_conv_layers:
try:
f.clear_buffer()
except AttributeError:
pass