forked from mindspore-Ecosystem/mindspore
!11773 Add WaveNet to Model Zoo
From: @wanyiming Reviewed-by: Signed-off-by:
This commit is contained in:
commit
b82df95b43
|
@ -0,0 +1,254 @@
|
|||
# Contents
|
||||
|
||||
- [WaveNet Description](#WaveNet-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Convert Process](#convert-process)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Training Performance](#training-performance)
|
||||
- [Inference Performance](#inference-performance)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [WaveNet Description](#contents)
|
||||
|
||||
WaveNet is a deep neural network for generating raw audio waveforms. The model is fully probabilistic and autoregressive, with the predictive distribution for each audio sample conditioned on all previous ones. We support training and evaluation on GPU.
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1609.03499.pdf): ord A, Dieleman S, Zen H, et al. Wavenet: A generative model for raw audio
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
The current model consists of a pre-convolution layer, followed by several residual block which has residual and skip connection with gated activation units.
|
||||
Finally, post convolution layers are added to predict the distribution.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||
|
||||
Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
|
||||
|
||||
- Dataset size:2.6G
|
||||
- Data format:audio clips(13100) and transcription
|
||||
|
||||
- The dataset structure is as follows:
|
||||
|
||||
```path
|
||||
.
|
||||
└── LJSpeech-1.1
|
||||
├─ wavs //audio clips files
|
||||
└─ metadata.csv //transcripts
|
||||
```
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(GPU)
|
||||
- Prepare hardware environment with GPU processor.
|
||||
- Framework
|
||||
- [MindSpore](https://cmc-szv.clouddragon.huawei.com/cmcversion/index/search?searchKey=Do-MindSpore%20V100R001C00B622)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
**Note that some of the scripts described below are not included our code**. These scripts should first be download them from [r9y9](https://github.com/r9y9/wavenet_vocoder) and added into this project.
|
||||
|
||||
```path
|
||||
.
|
||||
├── audio
|
||||
└──wavenet
|
||||
├──datasets // Note the datasets folder should be download from the above link
|
||||
├──egs // Note the egs folder should be download from the above link
|
||||
├──utils // Note the utils folder should be download from the above link
|
||||
├── audio.py // audio utils. Note this script should be download from a third party
|
||||
├── compute-meanvar-stats.py // Compute mean-variance normalization stats. Note this script should be download from the above link
|
||||
├── evaluate.py // evaluation
|
||||
├── export.py // convert mindspore model to air model
|
||||
├── hparams.py // hyper-parameter configuration. Note this script should be download from the above link
|
||||
├── mksubset.py // Make subset of dataset. Note this script should be download from the above link
|
||||
├── preprocess.py // Preprocess dataset. Note this script should be download from the above link
|
||||
├── preprocess_normalize.py // Perform meanvar normalization to preprocessed features. Note this script should be download from the above link
|
||||
├── README.md // descriptions about WaveNet
|
||||
├── train.py // training scripts
|
||||
├── train_pytorch.py // Note this script should be download from the above link. The initial name of this script is train.py in the project from the link
|
||||
├── src
|
||||
│ ├──__init__.py
|
||||
│ ├──dataset.py // generate dataloader and data processing entry
|
||||
│ ├──callback.py // callbacks to monitor the training
|
||||
│ ├──lr_generator.py // learning rate generator
|
||||
│ └──loss.py // loss function definition
|
||||
└── wavenet_vocoder
|
||||
├──__init__.py
|
||||
├──conv.py // extended 1D convolution
|
||||
├──mixture.py // loss function for training and sample function for testing
|
||||
├──modules.py // modules for Wavenet construction
|
||||
├──upsample.py // upsample layer definition
|
||||
├──util.py // utils. Note this script should be download from the above link
|
||||
├──wavenet.py // WaveNet networks
|
||||
└──tfcompat // Note this script should be download from the above link
|
||||
├──__init__.py
|
||||
└──hparam.py // param management tools
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
### Training
|
||||
|
||||
```text
|
||||
usage: train.py [--data_path DATA_PATH] [--preset PRESET]
|
||||
[--checkpoint_dir CHECKPOINT_DIR] [--checkpoint CHECKPOINT]
|
||||
[--speaker_id SPEAKER_ID] [--is_distributed IS_DISTRIBUTED]
|
||||
options:
|
||||
--data_path dataset path
|
||||
--preset path of preset parameters (json)
|
||||
--checkpoint_dir directory of saving model checkpoints
|
||||
--checkpoint pre-trained ckpt path, default is "./checkpoints"
|
||||
--speaker_id specific speaker of data in case for multi-speaker datasets, not used currently
|
||||
--is_distributed whether distributed training or not
|
||||
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
```text
|
||||
usage: evaluate.py [--data_path DATA_PATH] [--preset PRESET]
|
||||
[--pretrain_ckpt PRETRAIN_CKPT] [--output_path OUTPUT_PATH]
|
||||
[--speaker_id SPEAKER_ID]
|
||||
options:
|
||||
--data_path dataset path
|
||||
--preset path of preset parameters (json)
|
||||
--pretrain_ckpt pre-trained ckpt path
|
||||
--is_numpy whether using numpy for inference or not
|
||||
--output_path path to save synthesized audio
|
||||
--speaker_id specific speaker of data in case for multi-speaker datasets, not used currently
|
||||
```
|
||||
|
||||
More parameters for training and evaluation can be set in file `hparams.py`.
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
Before your first training, some dependency scripts should be downloaded and placed in correct directory as described in [Script and Sample Code].
|
||||
After that, raw data should be pre-processed by using the scripts in `egs`. The directory of egs is as follows:
|
||||
|
||||
```path
|
||||
.
|
||||
├── egs
|
||||
├──gaussian
|
||||
│ ├──conf
|
||||
│ │ ├──gaussian_wavenet.json
|
||||
│ │ └──gaussian_wavenet_demo.json
|
||||
│ └──run.sh
|
||||
├──mol
|
||||
│ ├──conf
|
||||
│ │ ├──mol_wavenet.json
|
||||
│ │ └──mol_wavenet_demo.json
|
||||
│ └──run.sh
|
||||
├──mulaw256
|
||||
│ ├──conf
|
||||
│ │ ├──mulaw_wavenet.json
|
||||
│ │ └──mulaw_wavenet_demo.json
|
||||
│ └──run.sh
|
||||
└──README.md
|
||||
```
|
||||
|
||||
In this project, three different losses are implemented to train the network:
|
||||
|
||||
- mulaw256: categorical output distribution. The input is 8-bit mulaw quantized waveform.
|
||||
- mol: discretized mix logistic loss. The input is 16-bit raw audio.
|
||||
- gaussian: mix gaussian loss. The input is 16-bit raw audio.
|
||||
|
||||
The three folder gaussian, mol, mulaw is used to generate corresponding training data respectively. For example, To generate the training data for
|
||||
mix gaussian loss, you should first modify the `run.sh` in line 28. Change `conf/gaussian_wavenet_demo.json` to
|
||||
`conf/gaussian_wavenet.json`. We use the default parameter in `gaussian_wavenet.json`. By this setting, data will be generated to adapt to mix gaussian loss and
|
||||
some parameters in `hparams.py` will be covered by that in `gaussian_wavenet.json`. You can also define your own hyper-parameter json here. After the modification,
|
||||
The following command can be ran for data generation. Note that if you want to change values of some parameters, you may need to modify in `gaussian_wavenet.json` instead of `hparams.py` since `gaussian_wavenet.json` may cover that in`hparams.py`.
|
||||
|
||||
```bash
|
||||
bash run.sh --stage 0 --stop-stage 0 --db-root /path_to_dataset/LJSpeech-1.1/wavs
|
||||
bash run.sh --stage 1 --stop-stage 1
|
||||
```
|
||||
|
||||
After the processing, the directory of gaussian will be as follows:
|
||||
|
||||
```path
|
||||
.
|
||||
├── gaussian
|
||||
├──conf
|
||||
├──data
|
||||
├──exp
|
||||
└──dump
|
||||
└──lj
|
||||
└──logmelspectrogram
|
||||
├──org
|
||||
└──norm
|
||||
├──train_no_dev
|
||||
├──dev
|
||||
└──eval
|
||||
```
|
||||
|
||||
The train_no_dev folder contains the final training data. For mulaw256 and mol, the process is the same. When the training data is prepared,
|
||||
you can run the following command to train the network:
|
||||
|
||||
```bash
|
||||
# standalone training
|
||||
python train.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt
|
||||
|
||||
distributed training
|
||||
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' mpirun --allow-run-as-root -n 8 python train.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt --is_distributed=True
|
||||
|
||||
eval
|
||||
python evaluate.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --output_path=path_to_save_audio
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
WaveNet has a process of auto-regression and this process currently cannot be run in Graph mode(place the auto-regression into `construct`). Therefore, we implement the process in a common function. Here, we provide two kinds of ways to realize the function: using Numpy or using MindSpore ops. One can set `is_numpy` to determine which mode is used. We recommend using numpy since it is much faster than using MindSpore ops. This is because the auto-regression process only calls some simple operation like Matmul and Bias_add. Unlike Graph mode, there will exist some fixed cost each step and this leads to a lower speed. For more information, please refer to
|
||||
this [link](https://bbs.huaweicloud.com/forum/thread-94852-1-1.html)
|
||||
|
||||
```bash
|
||||
eval
|
||||
python evaluate.py --data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --is_numpy --output_path=path_to_save_audio
|
||||
```
|
||||
|
||||
## [Convert Process](#contents)
|
||||
|
||||
```bash
|
||||
python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Training Performance
|
||||
|
||||
| Parameters | WaveNet |
|
||||
| -------------------------- | ---------------------------------------------------------------|
|
||||
| Resource | NV SMX2 V100-32G |
|
||||
| uploaded Date | 01/14/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | LJSpeech-1.1 |
|
||||
| Training Parameters | 1p, epoch=600(max), steps=1635 * epoch, batch_size = 8, lr=1e-3 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | SoftmaxCrossEntropyWithLogits/discretized_mix_logistic/mix_gaussian |
|
||||
| Loss | around 2.0(mulaw256)/around 4.5(mol)/around -6.0(gaussian) |
|
||||
| Speed | 1p 1.467s/step |
|
||||
| Total time: training | 1p(mol/gaussian): around 4 days; 2p(mulaw256):around 1 week |
|
||||
| Checkpoint | 59.79MM/54.87M/54.83M (.ckpt file) |
|
||||
| Scripts | [WaveNet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/wavenet) |
|
||||
|
||||
### Inference Performance On GPU
|
||||
|
||||
Audio samples will be demonstrated online soon.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,253 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""evaluation"""
|
||||
import os
|
||||
from os.path import join
|
||||
import argparse
|
||||
import glob
|
||||
from hparams import hparams, hparams_debug_string
|
||||
import audio
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
from tqdm import tqdm
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
from nnmnkwii import preprocessing as P
|
||||
from nnmnkwii.datasets import FileSourceDataset
|
||||
from wavenet_vocoder import WaveNet
|
||||
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_scalar_input
|
||||
from src.dataset import RawAudioDataSource, MelSpecDataSource, DualDataset
|
||||
|
||||
parser = argparse.ArgumentParser(description='TTS training')
|
||||
parser.add_argument('--data_path', type=str, required=True, default='',
|
||||
help='Directory contains preprocessed features.')
|
||||
parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).')
|
||||
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
|
||||
parser.add_argument('--is_numpy', action="store_false", default=True, help='Using numpy for inference or not')
|
||||
parser.add_argument('--output_path', type=str, default='./out_wave/', help='Path to save generated audios')
|
||||
parser.add_argument('--speaker_id', type=str, default='',
|
||||
help=' Use specific speaker of data in case for multi-speaker datasets.')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def get_data_loader(hparam, data_dir):
|
||||
"""
|
||||
test data loader
|
||||
"""
|
||||
wav_paths = glob.glob(os.path.join(data_dir, "*-wave.npy"))
|
||||
if wav_paths:
|
||||
X = FileSourceDataset(RawAudioDataSource(data_dir,
|
||||
hop_size=audio.get_hop_size(),
|
||||
max_steps=None, cin_pad=hparam.cin_pad))
|
||||
else:
|
||||
X = None
|
||||
C = FileSourceDataset(MelSpecDataSource(data_dir,
|
||||
hop_size=audio.get_hop_size(),
|
||||
max_steps=None, cin_pad=hparam.cin_pad))
|
||||
|
||||
length_x = np.array(C.file_data_source.lengths)
|
||||
if C[0].shape[-1] != hparam.cin_channels:
|
||||
raise RuntimeError("Invalid cin_channnels {}. Expected to be {}.".format(hparam.cin_channels, C[0].shape[-1]))
|
||||
|
||||
dataset = DualDataset(X, C, length_x, batch_size=hparam.batch_size, hparams=hparam)
|
||||
|
||||
data_loader = de.GeneratorDataset(dataset, ["x_batch", "y_batch", "c_batch", "g_batch", "input_lengths", "mask"])
|
||||
|
||||
return data_loader, dataset
|
||||
|
||||
|
||||
def batch_wavegen(hparam, net, c_input=None, g_input=None, tqdm_=None, is_numpy=True):
|
||||
"""
|
||||
generate audio
|
||||
"""
|
||||
assert c_input is not None
|
||||
B = c_input.shape[0]
|
||||
net.set_train(False)
|
||||
|
||||
if hparam.upsample_conditional_features:
|
||||
length = (c_input.shape[-1] - hparam.cin_pad * 2) * audio.get_hop_size()
|
||||
else:
|
||||
# already dupulicated
|
||||
length = c_input.shape[-1]
|
||||
|
||||
y_hat = net.incremental_forward(c=c_input, g=g_input, T=length, tqdm=tqdm_, softmax=True, quantize=True,
|
||||
log_scale_min=hparam.log_scale_min, is_numpy=is_numpy)
|
||||
|
||||
if is_mulaw_quantize(hparam.input_type):
|
||||
# needs to be float since mulaw_inv returns in range of [-1, 1]
|
||||
y_hat = np.reshape(np.argmax(y_hat, 1), (B, -1))
|
||||
y_hat = y_hat.astype(np.float32)
|
||||
for k in range(B):
|
||||
y_hat[k] = P.inv_mulaw_quantize(y_hat[k], hparam.quantize_channels - 1)
|
||||
elif is_mulaw(hparam.input_type):
|
||||
y_hat = np.reshape(y_hat, (B, -1))
|
||||
for k in range(B):
|
||||
y_hat[k] = P.inv_mulaw(y_hat[k], hparam.quantize_channels - 1)
|
||||
else:
|
||||
y_hat = np.reshape(y_hat, (B, -1))
|
||||
|
||||
if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]:
|
||||
for k in range(B):
|
||||
y_hat[k] = getattr(audio, hparam.postprocess)(y_hat[k])
|
||||
|
||||
if hparam.global_gain_scale > 0:
|
||||
for k in range(B):
|
||||
y_hat[k] /= hparam.global_gain_scale
|
||||
|
||||
return y_hat
|
||||
|
||||
|
||||
def to_int16(x_):
|
||||
"""
|
||||
convert datatype to int16
|
||||
"""
|
||||
if x_.dtype == np.int16:
|
||||
return x_
|
||||
assert x_.dtype == np.float32
|
||||
assert x_.min() >= -1 and x_.max() <= 1.0
|
||||
return (x_ * 32767).astype(np.int16)
|
||||
|
||||
|
||||
def get_reference_file(hparam, dataset_source, idx):
|
||||
"""
|
||||
get reference files
|
||||
"""
|
||||
reference_files = []
|
||||
reference_feats = []
|
||||
for _ in range(hparam.batch_size):
|
||||
if hasattr(dataset_source, "X"):
|
||||
reference_files.append(dataset_source.X.collected_files[idx][0])
|
||||
else:
|
||||
pass
|
||||
if hasattr(dataset_source, "Mel"):
|
||||
reference_feats.append(dataset_source.Mel.collected_files[idx][0])
|
||||
else:
|
||||
reference_feats.append(dataset_source.collected_files[idx][0])
|
||||
idx += 1
|
||||
return reference_files, reference_feats, idx
|
||||
|
||||
|
||||
def get_saved_audio_name(has_ref_file_, ref_file, ref_feat, g_fp):
|
||||
"""get path to save reference audio"""
|
||||
if has_ref_file_:
|
||||
target_audio_path = ref_file
|
||||
name = os.path.splitext(os.path.basename(target_audio_path))[0].replace("-wave", "")
|
||||
else:
|
||||
target_feat_path = ref_feat
|
||||
name = os.path.splitext(os.path.basename(target_feat_path))[0].replace("-feats", "")
|
||||
# Paths
|
||||
if g_fp is None:
|
||||
dst_wav_path_ = join(args.output_path, "{}_gen.wav".format(name))
|
||||
target_wav_path_ = join(args.output_path, "{}_ref.wav".format(name))
|
||||
else:
|
||||
dst_wav_path_ = join(args.output_path, "speaker{}_{}_gen.wav".format(g, name))
|
||||
target_wav_path_ = join(args.output_path, "speaker{}_{}_ref.wav".format(g, name))
|
||||
return dst_wav_path_, target_wav_path_
|
||||
|
||||
|
||||
def save_ref_audio(hparam, ref, length, target_wav_path_):
|
||||
"""
|
||||
save reference audio
|
||||
"""
|
||||
if is_mulaw_quantize(hparam.input_type):
|
||||
ref = np.reshape(np.argmax(ref, 0), (-1))[:length]
|
||||
ref = ref.astype(np.float32)
|
||||
else:
|
||||
ref = np.reshape(ref, (-1))[:length]
|
||||
|
||||
if is_mulaw_quantize(hparam.input_type):
|
||||
ref = P.inv_mulaw_quantize(ref, hparam.quantize_channels - 1)
|
||||
elif is_mulaw(hparam.input_type):
|
||||
ref = P.inv_mulaw(ref, hparam.quantize_channels - 1)
|
||||
if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]:
|
||||
ref = getattr(audio, hparam.postprocess)(ref)
|
||||
if hparam.global_gain_scale > 0:
|
||||
ref /= hparam.global_gain_scale
|
||||
|
||||
ref = np.clip(ref, -1.0, 1.0)
|
||||
|
||||
wavfile.write(target_wav_path_, hparam.sample_rate, to_int16(ref))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
|
||||
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
|
||||
if args.preset is not None:
|
||||
with open(args.preset) as f:
|
||||
hparams.parse_json(f.read())
|
||||
|
||||
assert hparams.name == "wavenet_vocoder"
|
||||
print(hparams_debug_string())
|
||||
|
||||
fs = hparams.sample_rate
|
||||
hparams.batch_size = 10
|
||||
hparams.max_time_sec = None
|
||||
hparams.max_time_steps = None
|
||||
data_loaders, source_dataset = get_data_loader(hparam=hparams, data_dir=args.data_path)
|
||||
|
||||
upsample_params = hparams.upsample_params
|
||||
upsample_params["cin_channels"] = hparams.cin_channels
|
||||
upsample_params["cin_pad"] = hparams.cin_pad
|
||||
model = WaveNet(
|
||||
out_channels=hparams.out_channels,
|
||||
layers=hparams.layers,
|
||||
stacks=hparams.stacks,
|
||||
residual_channels=hparams.residual_channels,
|
||||
gate_channels=hparams.gate_channels,
|
||||
skip_out_channels=hparams.skip_out_channels,
|
||||
cin_channels=hparams.cin_channels,
|
||||
gin_channels=hparams.gin_channels,
|
||||
n_speakers=hparams.n_speakers,
|
||||
dropout=hparams.dropout,
|
||||
kernel_size=hparams.kernel_size,
|
||||
cin_pad=hparams.cin_pad,
|
||||
upsample_conditional_features=hparams.upsample_conditional_features,
|
||||
upsample_params=upsample_params,
|
||||
scalar_input=is_scalar_input(hparams.input_type),
|
||||
output_distribution=hparams.output_distribution,
|
||||
)
|
||||
|
||||
param_dict = load_checkpoint(args.pretrain_ckpt)
|
||||
load_param_into_net(model, param_dict)
|
||||
print('Successfully loading the pre-trained model')
|
||||
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
cin_pad = hparams.cin_pad
|
||||
|
||||
file_idx = 0
|
||||
for data in data_loaders.create_dict_iterator():
|
||||
x, y, c, g, input_lengths = data['x_batch'], data['y_batch'], data['c_batch'], data['g_batch'], data[
|
||||
'input_lengths']
|
||||
if cin_pad > 0:
|
||||
c = c.asnumpy()
|
||||
c = np.pad(c, pad_width=(cin_pad, cin_pad), mode="edge")
|
||||
c = Tensor(c)
|
||||
|
||||
ref_files, ref_feats, file_idx = get_reference_file(hparams, source_dataset, file_idx)
|
||||
# Generate
|
||||
y_hats = batch_wavegen(hparams, model, data['c_batch'], tqdm_=tqdm, is_numpy=args.is_numpy)
|
||||
x = x.asnumpy()
|
||||
input_lengths = input_lengths.asnumpy()
|
||||
# Save each utt.
|
||||
has_ref_file = bool(ref_files)
|
||||
for i, (ref_, gen_, length_) in enumerate(zip(x, y_hats, input_lengths)):
|
||||
dst_wav_path, target_wav_path = get_saved_audio_name(has_ref_file_=has_ref_file, ref_file=ref_files[i],
|
||||
ref_feat=ref_feats[i], g_fp=g)
|
||||
save_ref_audio(hparams, ref_, length_, target_wav_path)
|
||||
|
||||
gen = gen_[:length_]
|
||||
gen = np.clip(gen, -1.0, 1.0)
|
||||
wavfile.write(dst_wav_path, hparams.sample_rate, to_int16(gen))
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export mindir."""
|
||||
import json
|
||||
from os.path import join
|
||||
import argparse
|
||||
from warnings import warn
|
||||
from hparams import hparams, hparams_debug_string
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from wavenet_vocoder import WaveNet
|
||||
from wavenet_vocoder.util import is_mulaw_quantize, is_scalar_input
|
||||
import numpy as np
|
||||
from src.loss import PredictNet
|
||||
|
||||
parser = argparse.ArgumentParser(description='TTS training')
|
||||
parser.add_argument('--preset', type=str, default='', help='Path of preset parameters (json).')
|
||||
parser.add_argument('--speaker_id', type=str, default='',
|
||||
help=' Use specific speaker of data in case for multi-speaker datasets.')
|
||||
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
|
||||
args = parser.parse_args()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
|
||||
|
||||
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
|
||||
if args.preset is not None:
|
||||
with open(args.preset) as f:
|
||||
hparams.parse_json(f.read())
|
||||
|
||||
assert hparams.name == "wavenet_vocoder"
|
||||
print(hparams_debug_string())
|
||||
|
||||
fs = hparams.sample_rate
|
||||
output_json_path = join(args.checkpoint_dir, "hparams.json")
|
||||
with open(output_json_path, "w") as f:
|
||||
json.dump(hparams.values(), f, indent=2)
|
||||
|
||||
if is_mulaw_quantize(hparams.input_type):
|
||||
if hparams.out_channels != hparams.quantize_channels:
|
||||
raise RuntimeError(
|
||||
"out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'")
|
||||
if hparams.upsample_conditional_features and hparams.cin_channels < 0:
|
||||
s = "Upsample conv layers were specified while local conditioning disabled. "
|
||||
s += "Notice that upsample conv layers will never be used."
|
||||
warn(s)
|
||||
|
||||
upsample_params = hparams.upsample_params
|
||||
upsample_params["cin_channels"] = hparams.cin_channels
|
||||
upsample_params["cin_pad"] = hparams.cin_pad
|
||||
model = WaveNet(
|
||||
out_channels=hparams.out_channels,
|
||||
layers=hparams.layers,
|
||||
stacks=hparams.stacks,
|
||||
residual_channels=hparams.residual_channels,
|
||||
gate_channels=hparams.gate_channels,
|
||||
skip_out_channels=hparams.skip_out_channels,
|
||||
cin_channels=hparams.cin_channels,
|
||||
gin_channels=hparams.gin_channels,
|
||||
n_speakers=hparams.n_speakers,
|
||||
dropout=hparams.dropout,
|
||||
kernel_size=hparams.kernel_size,
|
||||
cin_pad=hparams.cin_pad,
|
||||
upsample_conditional_features=hparams.upsample_conditional_features,
|
||||
upsample_params=upsample_params,
|
||||
scalar_input=is_scalar_input(hparams.input_type),
|
||||
output_distribution=hparams.output_distribution,
|
||||
)
|
||||
|
||||
Net = PredictNet(model)
|
||||
Net.set_train(False)
|
||||
receptive_field = model.receptive_field
|
||||
print("Receptive field (samples / ms): {} / {}".format(receptive_field, receptive_field / fs * 1000))
|
||||
param_dict = load_checkpoint(args.pretrain_ckpt)
|
||||
load_param_into_net(model, param_dict)
|
||||
print('Successfully loading the pre-trained model')
|
||||
|
||||
x = np.array(np.random.random((2, 256, 10240)), dtype=np.float32)
|
||||
c = np.array(np.random.random((2, 80, 44)), dtype=np.float32)
|
||||
g = np.array([0, 0], dtype=np.int64)
|
||||
|
||||
export(Net, Tensor(x), Tensor(c), Tensor(g), file_name="WaveNet", file_format='MINDIR')
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the License);
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# httpwww.apache.orglicensesLICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an AS IS BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Defined callback for DeepFM.
|
||||
"""
|
||||
import time
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore import Tensor
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TimeMonitor(Callback):
|
||||
"""
|
||||
Time monitor for calculating cost of each epoch.
|
||||
|
||||
Args:
|
||||
data_size (int): step size of an epoch.
|
||||
"""
|
||||
|
||||
def __init__(self, data_size):
|
||||
super(TimeMonitor, self).__init__()
|
||||
self.data_size = data_size
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
self.epoch_time = time.time()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
epoch_mseconds = (time.time() - self.epoch_time) * 1000
|
||||
per_step_mseconds = epoch_mseconds / self.data_size
|
||||
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
|
||||
|
||||
def step_begin(self, run_context):
|
||||
self.step_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
step_mseconds = (time.time() - self.step_time) * 1000
|
||||
print(f"step time {step_mseconds}", flush=True)
|
||||
|
||||
|
||||
class Monitor(Callback):
|
||||
"""
|
||||
Monitor loss and time.
|
||||
|
||||
Args:
|
||||
lr_init (numpy array): train lr
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Examples:
|
||||
>>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
|
||||
"""
|
||||
|
||||
def __init__(self, lr_init=None):
|
||||
super(Monitor, self).__init__()
|
||||
self.lr_init = lr_init
|
||||
self.lr_init_len = len(lr_init)
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
self.losses = []
|
||||
self.epoch_time = time.time()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
|
||||
epoch_mseconds = (time.time() - self.epoch_time)
|
||||
per_step_mseconds = epoch_mseconds / cb_params.batch_num
|
||||
print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.6f}".format(epoch_mseconds,
|
||||
per_step_mseconds,
|
||||
np.mean(self.losses)))
|
||||
|
||||
def step_begin(self, run_context):
|
||||
self.step_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""step end"""
|
||||
cb_params = run_context.original_args()
|
||||
step_mseconds = (time.time() - self.step_time)
|
||||
step_loss = cb_params.net_outputs
|
||||
|
||||
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
|
||||
step_loss = step_loss[0]
|
||||
if isinstance(step_loss, Tensor):
|
||||
step_loss = np.mean(step_loss.asnumpy())
|
||||
|
||||
self.losses.append(step_loss)
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
|
||||
|
||||
print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.6f}/{:5.6f}], time:[{:5.3f}], lr:[{:.9f}]".format(
|
||||
cb_params.cur_epoch_num -
|
||||
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
|
||||
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1].asnumpy()))
|
|
@ -0,0 +1,258 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Create train dataset.
|
||||
"""
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import audio
|
||||
from nnmnkwii.datasets import FileSourceDataset
|
||||
from nnmnkwii import preprocessing as P
|
||||
from wavenet_vocoder.util import is_mulaw_quantize
|
||||
from train_pytorch import _pad, _pad_2d, to_categorical, ensure_divisible, RawAudioDataSource, MelSpecDataSource, assert_ready_for_upsampling
|
||||
import mindspore.dataset.engine as de
|
||||
|
||||
|
||||
def sequence_mask(sequence_length, max_len=None):
|
||||
"""make sequence mask for loss"""
|
||||
if max_len is None:
|
||||
max_len = np.max(sequence_length)
|
||||
batch_size = len(sequence_length)
|
||||
seq_range = np.linspace(0, max_len - 1, max_len, dtype=np.int32)
|
||||
seq_range_expand = np.tile(np.expand_dims(seq_range, 0), (batch_size, 1))
|
||||
seq_length_expand = np.tile(np.expand_dims(sequence_length, 1), (1, max_len))
|
||||
seq_length_expand = np.expand_dims(np.array(seq_range_expand < seq_length_expand, dtype=np.float32), -1)
|
||||
return seq_length_expand
|
||||
|
||||
class DistributedSampler():
|
||||
"""function to distribute and shuffle sample
|
||||
"""
|
||||
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
|
||||
self.dataset = dataset
|
||||
self.rank = rank
|
||||
self.group_size = group_size
|
||||
self.dataset_len = len(self.dataset)
|
||||
self.num_samplers = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
|
||||
self.total_size = self.num_samplers * self.group_size
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
self.seed = (self.seed + 1) & 0xffffffff
|
||||
np.random.seed(self.seed)
|
||||
indices = np.random.permutation(self.dataset_len).tolist()
|
||||
else:
|
||||
indices = list(range(self.dataset_len))
|
||||
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
indices = indices[self.rank::self.group_size]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samplers
|
||||
|
||||
|
||||
def process_condition_batch(max_time_steps, hparams, batch):
|
||||
"""process condition batch"""
|
||||
cin_pad = hparams.cin_pad
|
||||
new_batch = []
|
||||
for batch_ in batch:
|
||||
x, c, g = batch_
|
||||
if hparams.upsample_conditional_features:
|
||||
assert_ready_for_upsampling(x, c, cin_pad=0)
|
||||
if max_time_steps is not None:
|
||||
max_steps = ensure_divisible(max_time_steps, audio.get_hop_size(), True)
|
||||
if len(x) > max_steps:
|
||||
max_time_frames = max_steps // audio.get_hop_size()
|
||||
s = np.random.randint(cin_pad, len(c) - max_time_frames - cin_pad)
|
||||
ts = s * audio.get_hop_size()
|
||||
x = x[ts:ts + audio.get_hop_size() * max_time_frames]
|
||||
c = c[s - cin_pad:s + max_time_frames + cin_pad, :]
|
||||
assert_ready_for_upsampling(x, c, cin_pad=cin_pad)
|
||||
else:
|
||||
x, c = audio.adjust_time_resolution(x, c)
|
||||
if max_time_steps is not None and len(x) > max_time_steps:
|
||||
s = np.random.randint(cin_pad, len(x) - max_time_steps - cin_pad)
|
||||
x = x[s:s + max_time_steps]
|
||||
c = c[s - cin_pad:s + max_time_steps + cin_pad, :]
|
||||
assert len(x) == len(c)
|
||||
new_batch.append((x, c, g))
|
||||
return new_batch
|
||||
|
||||
|
||||
def process_no_condition_batch(max_time_steps, batch):
|
||||
"""process no condition batch"""
|
||||
new_batch = []
|
||||
for batch_ in batch:
|
||||
x, c, g = batch_
|
||||
x = audio.trim(x)
|
||||
if max_time_steps is not None and len(x) > max_time_steps:
|
||||
s = np.random.randint(0, len(x) - max_time_steps)
|
||||
x = x[s:s + max_time_steps]
|
||||
new_batch.append((x, c, g))
|
||||
return new_batch
|
||||
|
||||
|
||||
|
||||
def collate_fn(batch, hparams):
|
||||
"""
|
||||
Create batch
|
||||
"""
|
||||
local_conditioning = len(batch[0]) >= 2 and hparams.cin_channels > 0
|
||||
global_conditioning = len(batch[0]) >= 3 and hparams.gin_channels > 0
|
||||
|
||||
if hparams.max_time_sec is not None:
|
||||
max_time_steps = int(hparams.max_time_sec * hparams.sample_rate)
|
||||
elif hparams.max_time_steps is not None:
|
||||
max_time_steps = hparams.max_time_steps
|
||||
else:
|
||||
max_time_steps = None
|
||||
|
||||
if local_conditioning:
|
||||
new_batch = process_condition_batch(max_time_steps, hparams, batch)
|
||||
else:
|
||||
new_batch = process_no_condition_batch(max_time_steps, batch)
|
||||
batch = new_batch
|
||||
# Lengths
|
||||
input_lengths = [len(x[0]) for x in batch]
|
||||
max_input_len = max(input_lengths)
|
||||
# (B, T, C)
|
||||
# pad for time-axis
|
||||
if is_mulaw_quantize(hparams.input_type):
|
||||
padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1)
|
||||
x_batch = np.array(
|
||||
[_pad_2d(to_categorical(x[0], num_classes=hparams.quantize_channels), max_input_len, 0, padding_value) for x
|
||||
in batch], dtype=np.float32)
|
||||
else:
|
||||
x_batch = np.array([_pad_2d(x[0].reshape(-1, 1), max_input_len)
|
||||
for x in batch], dtype=np.float32)
|
||||
assert len(x_batch.shape) == 3
|
||||
|
||||
# (B, T)
|
||||
if is_mulaw_quantize(hparams.input_type):
|
||||
padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1)
|
||||
y_batch = np.array([_pad(x[0], max_input_len, constant_values=padding_value)
|
||||
for x in batch], dtype=np.int32)
|
||||
else:
|
||||
y_batch = np.array([_pad(x[0], max_input_len) for x in batch], dtype=np.float32)
|
||||
assert len(y_batch.shape) == 2
|
||||
|
||||
# (B, T, D)
|
||||
if local_conditioning:
|
||||
max_len = max([len(x[1]) for x in batch])
|
||||
c_batch = np.array([_pad_2d(x[1], max_len) for x in batch], dtype=np.float32)
|
||||
assert len(c_batch.shape) == 3
|
||||
# (B x C x T)
|
||||
c_batch = c_batch.transpose((0, 2, 1))
|
||||
else:
|
||||
c_batch = np.zeros(hparams.batch_size, dtype=np.float32)
|
||||
|
||||
if global_conditioning:
|
||||
g_batch = [x[2] for x in batch]
|
||||
else:
|
||||
# g_batch = None # MindSpore does not support None input
|
||||
g_batch = np.zeros(hparams.batch_size, dtype=np.int64)
|
||||
|
||||
# Convert to channel first (B, C, T)
|
||||
x_batch = x_batch.transpose((0, 2, 1))
|
||||
# Add extra axis
|
||||
if is_mulaw_quantize(hparams.input_type):
|
||||
y_batch = np.expand_dims(y_batch, axis=-1)
|
||||
else:
|
||||
y_batch = np.expand_dims(y_batch, axis=-1)
|
||||
|
||||
input_lengths = input_lengths
|
||||
|
||||
mask = sequence_mask(input_lengths, max_len=x_batch.shape[-1])
|
||||
|
||||
return x_batch, y_batch, c_batch, g_batch, input_lengths, mask
|
||||
|
||||
|
||||
class DualDataset():
|
||||
"""Create Dataset loader for audio Mel and Audio"""
|
||||
def __init__(self, X, Mel, length, batch_size, hparams):
|
||||
self.multi_speaker = X.file_data_source.multi_speaker
|
||||
self.X = X
|
||||
self.Mel = Mel
|
||||
self.length = length
|
||||
self.hparams = hparams
|
||||
self.sorted_index = list(np.argsort(length))
|
||||
self.bins = [self.sorted_index[i:i + batch_size] for i in range(0, len(self.sorted_index), batch_size)]
|
||||
if len(self.sorted_index) / batch_size != 0:
|
||||
self.bins.append(self.sorted_index[-batch_size:])
|
||||
self.size = len(self.bins)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.multi_speaker:
|
||||
speaker_id = self.X.file_data_source.speaker_ids[idx]
|
||||
else:
|
||||
speaker_id = None
|
||||
|
||||
combined_data = []
|
||||
mel_len, audio_len = [], []
|
||||
for i in self.bins[idx]:
|
||||
if self.Mel is not None:
|
||||
mel = self.Mel[i]
|
||||
raw_audio = self.X[i]
|
||||
length_mel, length_x = mel.shape[0], raw_audio.shape[0]
|
||||
combined_data.append((raw_audio, mel, speaker_id))
|
||||
mel_len.append(length_mel)
|
||||
audio_len.append(length_x)
|
||||
else:
|
||||
raw_audio = self.X[i]
|
||||
length_x = raw_audio.shape[0]
|
||||
combined_data.append((raw_audio, speaker_id))
|
||||
audio_len.append(length_x)
|
||||
|
||||
x_batch, y_batch, c_batch, g_batch, input_lengths, mask = collate_fn(combined_data, self.hparams)
|
||||
|
||||
return x_batch, y_batch, c_batch, g_batch, input_lengths, mask
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
|
||||
def get_data_loaders(dump_root, speaker_id, hparams=None, rank_id=None, group_size=None):
|
||||
"""create train dataset"""
|
||||
local_conditioning = hparams.cin_channels > 0
|
||||
|
||||
if hparams.max_time_steps is not None:
|
||||
max_steps = ensure_divisible(hparams.max_time_steps, audio.get_hop_size(), True)
|
||||
else:
|
||||
max_steps = None
|
||||
|
||||
X = FileSourceDataset(
|
||||
RawAudioDataSource(os.path.join(dump_root, 'train_no_dev'), speaker_id=speaker_id,
|
||||
max_steps=max_steps, cin_pad=hparams.cin_pad,
|
||||
hop_size=audio.get_hop_size()))
|
||||
|
||||
if local_conditioning:
|
||||
Mel = FileSourceDataset(
|
||||
MelSpecDataSource(os.path.join(dump_root, 'train_no_dev'), speaker_id=speaker_id,
|
||||
max_steps=max_steps, cin_pad=hparams.cin_pad,
|
||||
hop_size=audio.get_hop_size()))
|
||||
assert len(X) == len(Mel)
|
||||
print("Local conditioning enabled. Shape of a sample: {}.".format(Mel[0].shape))
|
||||
else:
|
||||
Mel = None
|
||||
print("length of the dataset is {}".format(len(X)))
|
||||
length_x = np.array(X.file_data_source.lengths)
|
||||
dataset = DualDataset(X, Mel, length_x, batch_size=hparams.batch_size, hparams=hparams)
|
||||
sampler = DistributedSampler(dataset, rank_id, group_size, shuffle=True, seed=0)
|
||||
data_loaders = de.GeneratorDataset(dataset, ["x_batch", "y_batch", "c_batch", "g_batch", "input_lengths", "mask"],
|
||||
sampler=sampler)
|
||||
|
||||
return data_loaders
|
|
@ -0,0 +1,238 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""loss function definition"""
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from nnmnkwii import preprocessing as P1
|
||||
|
||||
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw
|
||||
from wavenet_vocoder.mixture import discretized_mix_logistic_loss
|
||||
from wavenet_vocoder.mixture import mix_gaussian_loss
|
||||
from train_pytorch import to_categorical
|
||||
from tqdm import tqdm
|
||||
import audio
|
||||
import librosa
|
||||
import librosa.display
|
||||
matplotlib.use('Agg')
|
||||
|
||||
def sequence_mask(sequence_length, max_len=None):
|
||||
"""make sequence mask"""
|
||||
sequence_length = sequence_length.asnumpy()
|
||||
if max_len is None:
|
||||
max_len = np.max(sequence_length)
|
||||
batch_size = sequence_length.shape[0]
|
||||
seq_range = np.linspace(0, max_len-1, max_len, dtype=np.int32)
|
||||
seq_range_expand = np.tile(np.expand_dims(seq_range, 0), (batch_size, 1))
|
||||
seq_length_expand = np.tile(np.expand_dims(sequence_length, 1), (1, max_len))
|
||||
seq_length_expand = np.expand_dims(np.array(seq_range_expand < seq_length_expand, dtype=np.float32), -1)
|
||||
return Tensor(seq_length_expand)
|
||||
|
||||
class MaskedCrossEntropyLoss(nn.Cell):
|
||||
"""MaskedCrossEntropyLoss"""
|
||||
def __init__(self):
|
||||
super(MaskedCrossEntropyLoss, self).__init__()
|
||||
self.criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
|
||||
def construct(self, inputs, target):
|
||||
losses = self.criterion(inputs, target)
|
||||
return losses
|
||||
|
||||
|
||||
class DiscretizedMixturelogisticLoss(nn.Cell):
|
||||
"""DiscretizedMixturelogisticLoss"""
|
||||
def __init__(self, hparams):
|
||||
super(DiscretizedMixturelogisticLoss, self).__init__()
|
||||
self.quantize_channels = hparams.quantize_channels
|
||||
self.log_scale_min = hparams.log_scale_min
|
||||
self.discretized_mix_logistic_loss = discretized_mix_logistic_loss(num_classes=hparams.quantize_channels,
|
||||
log_scale_min=hparams.log_scale_min,
|
||||
reduce=False)
|
||||
self.reduce_sum_op = P.ReduceSum()
|
||||
self.reduce_mean_op = P.ReduceMean()
|
||||
|
||||
def construct(self, inputs, target, mask=None):
|
||||
losses = self.discretized_mix_logistic_loss(inputs, target)
|
||||
return self.reduce_sum_op(losses * mask) / self.reduce_sum_op(mask)
|
||||
|
||||
|
||||
class MixtureGaussianLoss(nn.Cell):
|
||||
"""MixtureGaussianLoss"""
|
||||
def __init__(self, hparams):
|
||||
super(MixtureGaussianLoss, self).__init__()
|
||||
self.quantize_channels = hparams.quantize_channels
|
||||
self.log_scale_min = hparams.log_scale_min
|
||||
self.mix_gaussian_loss = mix_gaussian_loss(log_scale_min=hparams.log_scale_min, reduce=False)
|
||||
self.reduce_sum_op = P.ReduceSum()
|
||||
self.reduce_mean_op = P.ReduceMean()
|
||||
|
||||
def construct(self, inputs, target, mask=None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
inputs (Tensor): Predicted distribution
|
||||
target (Tensor): Target
|
||||
mask (Tensor): Mask
|
||||
|
||||
Returns:
|
||||
Tensor: Loss tensor
|
||||
|
||||
"""
|
||||
losses = self.mix_gaussian_loss(inputs, target)
|
||||
return self.reduce_sum_op(losses * mask) / self.reduce_sum_op(mask)
|
||||
|
||||
|
||||
def save_waveplot(path, y_hat, y_target, sample_rate):
|
||||
sr = sample_rate
|
||||
plt.figure(figsize=(16, 6))
|
||||
plt.subplot(2, 1, 1)
|
||||
librosa.display.waveplot(y_target, sr=sr)
|
||||
plt.subplot(2, 1, 2)
|
||||
librosa.display.waveplot(y_hat, sr=sr)
|
||||
plt.tight_layout()
|
||||
plt.savefig(path, format="png")
|
||||
plt.close()
|
||||
|
||||
|
||||
def eval_model(hparams, global_step, model, x, y, c, g, input_lengths, eval_dir):
|
||||
"""
|
||||
Function for model evaluation. This function is used for debugging in this project.
|
||||
"""
|
||||
|
||||
model.set_train(False)
|
||||
idx = np.random.randint(0, len(y))
|
||||
length = input_lengths.asnumpy()[idx]
|
||||
y_target = np.reshape(y.asnumpy()[idx], (-1))
|
||||
y_target = y_target[:length]
|
||||
|
||||
if c is not None:
|
||||
expand_op = P.ExpandDims()
|
||||
if hparams.upsample_conditional_features:
|
||||
c = expand_op(c[idx, :, :int(length // audio.get_hop_size() + hparams.cin_pad * 2)], 0)
|
||||
else:
|
||||
c = expand_op(c[idx, :, :length], 0)
|
||||
assert c.dim() == 3
|
||||
print("Shape of local conditioning features: {}".format(c.size()))
|
||||
|
||||
if g is not None:
|
||||
g = g[idx]
|
||||
print("Shape of global conditioning features: {}".format(g.size()))
|
||||
|
||||
# Dummy silence
|
||||
if is_mulaw_quantize(hparams.input_type):
|
||||
initial_value = P1.mulaw_quantize(0, hparams.quantize_channels - 1)
|
||||
elif is_mulaw(hparams.input_type):
|
||||
initial_value = P1.mulaw(0.0, hparams.quantize_channels)
|
||||
else:
|
||||
initial_value = 0.0
|
||||
|
||||
# (C,)
|
||||
if is_mulaw_quantize(hparams.input_type):
|
||||
initial_input = to_categorical(
|
||||
initial_value, num_classes=hparams.quantize_channels).astype(np.float32)
|
||||
initial_input = Tensor(np.reshape(initial_input, (1, 1, hparams.quantize_channels)))
|
||||
|
||||
else:
|
||||
initial_input = np.ones((1, 1, 1)) * initial_value
|
||||
initial_input = Tensor(initial_input)
|
||||
|
||||
# Run the model in fast eval mode
|
||||
y_hat = model.incremental_forward(initial_input, c=c, g=g, T=length, softmax=True, quantize=True, tqdm=tqdm,
|
||||
log_scale_min=hparams.log_scale_min)
|
||||
|
||||
if is_mulaw_quantize(hparams.input_type):
|
||||
y_hat = np.reshape(np.argmax(y_hat, 1), (-1))
|
||||
y_hat = P1.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1)
|
||||
y_target = P1.inv_mulaw_quantize(y_target, hparams.quantize_channels - 1)
|
||||
elif is_mulaw(hparams.input_type):
|
||||
y_hat = P1.inv_mulaw(np.reshape(y_hat, (-1)), hparams.quantize_channels)
|
||||
y_target = P1.inv_mulaw(y_target, hparams.quantize_channels)
|
||||
else:
|
||||
y_hat = np.reshape(y_hat, (-1))
|
||||
|
||||
# Save audio
|
||||
os.makedirs(eval_dir, exist_ok=True)
|
||||
path = os.path.join(eval_dir, "step{:09d}_predicted.wav".format(global_step))
|
||||
librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
|
||||
|
||||
path = os.path.join(eval_dir, "step{:09d}_target.wav".format(global_step))
|
||||
librosa.output.write_wav(path, y_target, sr=hparams.sample_rate)
|
||||
|
||||
# Save figure
|
||||
path = os.path.join(eval_dir, "step{:09d}_waveplots.png".format(global_step))
|
||||
save_waveplot(path, y_hat, y_target, hparams.sample_rate)
|
||||
|
||||
|
||||
class PredictNet(nn.Cell):
|
||||
"""
|
||||
NetWithLossClass definition
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(PredictNet, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, c, g):
|
||||
y_hat = self.network(x, c, g, False)
|
||||
return y_hat
|
||||
|
||||
|
||||
class NetWithLossClass(nn.Cell):
|
||||
"""
|
||||
NetWithLossClass definition
|
||||
|
||||
Args:
|
||||
network (Cell): Pre-defined WaveNet.
|
||||
hparams (optional): Parameters.
|
||||
|
||||
Returns:
|
||||
Tensor, loss tensor.
|
||||
"""
|
||||
def __init__(self, network, hparams):
|
||||
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.hparams = hparams
|
||||
self.ReduceMean_false = P.ReduceMean(keep_dims=False)
|
||||
self.expand_op = P.ExpandDims()
|
||||
self.transpose_op = P.Transpose()
|
||||
self.reshape_op = P.Reshape()
|
||||
self.is_mulaw_quant = is_mulaw_quantize(hparams.input_type)
|
||||
|
||||
if self.is_mulaw_quant:
|
||||
self.criterion = MaskedCrossEntropyLoss()
|
||||
else:
|
||||
if hparams.output_distribution == "Logistic":
|
||||
self.criterion = DiscretizedMixturelogisticLoss(hparams)
|
||||
elif hparams.output_distribution == "Normal":
|
||||
self.criterion = MixtureGaussianLoss(hparams)
|
||||
else:
|
||||
self.criterion = None
|
||||
raise RuntimeError(
|
||||
"Not supported output distribution type: {}".format(hparams.output_distribution))
|
||||
|
||||
def construct(self, x, y, c, g, input_lengths, mask):
|
||||
y_hat = self.network(x, c, g, False)
|
||||
if self.is_mulaw_quant:
|
||||
y_hat = self.transpose_op(y_hat[:, :, :-1], (0, 2, 1))
|
||||
y_hat = self.reshape_op(y_hat, (-1, y_hat.shape[-1]))
|
||||
y = self.reshape_op(y[:, 1:, 0], (-1,))
|
||||
loss = self.criterion(y_hat, y)
|
||||
else:
|
||||
loss = self.criterion(y_hat[:, :, :-1], y[:, 1:, :], mask[:, 1:, :])
|
||||
return loss
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate generator"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_lr(init_lr, total_epoch, step_per_epoch,
|
||||
anneal_rate=0.5,
|
||||
anneal_interval=200000):
|
||||
"""
|
||||
Learning rate generating
|
||||
|
||||
Args:
|
||||
init_lr (float): Initial learning rate
|
||||
total_epoch (int): Total epoch
|
||||
step_per_epoch (int): Step per epoch
|
||||
anneal_rate (float): anneal rate
|
||||
anneal_interval (int ): anneal interval
|
||||
|
||||
Returns:
|
||||
ndarray: learning rate
|
||||
|
||||
"""
|
||||
total_step = total_epoch * step_per_epoch
|
||||
lr_step = []
|
||||
for i in range(total_step):
|
||||
lr_step.append(init_lr * anneal_rate ** (i // anneal_interval))
|
||||
learning_rate = np.array(lr_step).astype(np.float32)
|
||||
return learning_rate
|
|
@ -0,0 +1,135 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_criteo."""
|
||||
import os
|
||||
from os.path import join
|
||||
import json
|
||||
import argparse
|
||||
from warnings import warn
|
||||
from hparams import hparams, hparams_debug_string
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.nn import TrainOneStepCell
|
||||
from mindspore.train import Model
|
||||
from src.lr_generator import get_lr
|
||||
from src.dataset import get_data_loaders
|
||||
from src.loss import NetWithLossClass
|
||||
from src.callback import Monitor
|
||||
from wavenet_vocoder import WaveNet
|
||||
from wavenet_vocoder.util import is_mulaw_quantize, is_scalar_input
|
||||
|
||||
parser = argparse.ArgumentParser(description='TTS training')
|
||||
parser.add_argument('--data_path', type=str, required=True, default='',
|
||||
help='Directory contains preprocessed features.')
|
||||
parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).')
|
||||
parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_test',
|
||||
help='Directory where to save model checkpoints [default: checkpoints].')
|
||||
parser.add_argument('--checkpoint', type=str, default='', help='Restore model from checkpoint path if given.')
|
||||
parser.add_argument('--speaker_id', type=str, default='',
|
||||
help=' Use specific speaker of data in case for multi-speaker datasets.')
|
||||
parser.add_argument('--is_distributed', action="store_true", default=False, help='Distributed training')
|
||||
args = parser.parse_args()
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args.is_distributed:
|
||||
init('nccl')
|
||||
rank_id = get_rank()
|
||||
group_size = get_group_size()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
|
||||
rank_id = 0
|
||||
group_size = 1
|
||||
|
||||
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
|
||||
if args.preset is not None:
|
||||
with open(args.preset) as f:
|
||||
hparams.parse_json(f.read())
|
||||
|
||||
assert hparams.name == "wavenet_vocoder"
|
||||
print(hparams_debug_string())
|
||||
fs = hparams.sample_rate
|
||||
os.makedirs(args.checkpoint_dir, exist_ok=True)
|
||||
|
||||
output_json_path = join(args.checkpoint_dir, "hparams.json")
|
||||
with open(output_json_path, "w") as f:
|
||||
json.dump(hparams.values(), f, indent=2)
|
||||
|
||||
data_loaders = get_data_loaders(args.data_path, args.speaker_id, hparams=hparams, rank_id=rank_id,
|
||||
group_size=group_size)
|
||||
step_size_per_epoch = data_loaders.get_dataset_size()
|
||||
|
||||
if is_mulaw_quantize(hparams.input_type):
|
||||
if hparams.out_channels != hparams.quantize_channels:
|
||||
raise RuntimeError(
|
||||
"out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'")
|
||||
if hparams.upsample_conditional_features and hparams.cin_channels < 0:
|
||||
s = "Upsample conv layers were specified while local conditioning disabled. "
|
||||
s += "Notice that upsample conv layers will never be used."
|
||||
warn(s)
|
||||
|
||||
upsample_params = hparams.upsample_params
|
||||
upsample_params["cin_channels"] = hparams.cin_channels
|
||||
upsample_params["cin_pad"] = hparams.cin_pad
|
||||
model = WaveNet(
|
||||
out_channels=hparams.out_channels,
|
||||
layers=hparams.layers,
|
||||
stacks=hparams.stacks,
|
||||
residual_channels=hparams.residual_channels,
|
||||
gate_channels=hparams.gate_channels,
|
||||
skip_out_channels=hparams.skip_out_channels,
|
||||
cin_channels=hparams.cin_channels,
|
||||
gin_channels=hparams.gin_channels,
|
||||
n_speakers=hparams.n_speakers,
|
||||
dropout=hparams.dropout,
|
||||
kernel_size=hparams.kernel_size,
|
||||
cin_pad=hparams.cin_pad,
|
||||
upsample_conditional_features=hparams.upsample_conditional_features,
|
||||
upsample_params=upsample_params,
|
||||
scalar_input=is_scalar_input(hparams.input_type),
|
||||
output_distribution=hparams.output_distribution,
|
||||
)
|
||||
loss_net = NetWithLossClass(model, hparams)
|
||||
lr = get_lr(hparams.optimizer_params["lr"], hparams.nepochs, step_size_per_epoch)
|
||||
lr = Tensor(lr)
|
||||
|
||||
if args.checkpoint != '':
|
||||
param_dict = load_checkpoint(args.pre_trained_model_path)
|
||||
load_param_into_net(model, param_dict)
|
||||
print('Successfully loading the pre-trained model')
|
||||
|
||||
weights = model.trainable_params()
|
||||
optimizer = Adam(weights, learning_rate=lr, loss_scale=1024.)
|
||||
train_net = TrainOneStepCell(loss_net, optimizer)
|
||||
|
||||
model = Model(train_net)
|
||||
lr_cb = Monitor(lr)
|
||||
callback_list = [lr_cb]
|
||||
if args.is_distributed:
|
||||
ckpt_path = os.path.join(args.checkpoint_dir, 'ckpt_' + str(get_rank()) + '/')
|
||||
else:
|
||||
ckpt_path = args.checkpoint_dir
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch, keep_checkpoint_max=10)
|
||||
ckpt_cb = ModelCheckpoint(prefix='wavenet', directory=ckpt_path, config=config_ck)
|
||||
callback_list.append(ckpt_cb)
|
||||
model.train(hparams.nepochs, data_loaders, callbacks=callback_list)
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""init"""
|
||||
from __future__ import with_statement, print_function, absolute_import
|
||||
from .wavenet import WaveNet
|
|
@ -0,0 +1,176 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Extended Conv1D."""
|
||||
|
||||
import math
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
import numpy as np
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
"""
|
||||
Extended nn.Conv1d to adapt to incremental dilated convolutions.
|
||||
During training, initial Conv1D is used and during evaluation, incremental_forward is called.
|
||||
To improve the inference speed, tensor will be converted as numpy and the following calculation is based on numpy.
|
||||
These operation will be replaced with MindSpore ops in the future. Currently, some operation is not supported by
|
||||
MindSpore and a mixed use of numpy and MindSpore will take a long time.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Conv1d, self).__init__(*args, **kwargs)
|
||||
self.clear_buffer()
|
||||
self._linearized_weight = None
|
||||
self.transpose_op = P.Transpose()
|
||||
self.reshape_op = P.Reshape()
|
||||
self.squeeze_op = P.Squeeze(-2)
|
||||
self.zeros = P.Zeros()
|
||||
self.concat_op = P.Concat(axis=1)
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.get_weight = None
|
||||
self.get_bias = None
|
||||
|
||||
def incremental_forward(self, inputs, is_numpy=True):
|
||||
if is_numpy:
|
||||
return self.incremental_forward_numpy(inputs)
|
||||
return self.incremental_forward_pynative(inputs)
|
||||
|
||||
def incremental_forward_pynative(self, inputs):
|
||||
"""
|
||||
Incremental forward.
|
||||
|
||||
Args:
|
||||
inputs: B x T x C
|
||||
|
||||
Returns:
|
||||
ndarray
|
||||
|
||||
"""
|
||||
# input: (B, T, C)
|
||||
if self.training:
|
||||
raise RuntimeError('incremental_forward only supports eval mode')
|
||||
|
||||
if self.get_weight is None:
|
||||
self.get_weight = self._get_linearized_weight()
|
||||
|
||||
if self.get_bias is None and self.bias is not None:
|
||||
self.get_bias = self.bias
|
||||
|
||||
# Note mindspore uses Conv2D to construct Conv1D
|
||||
kw = self.kernel_size[1]
|
||||
dilation = self.dilation[1]
|
||||
|
||||
bsz = inputs.shape[0] # input: bsz x len x dim
|
||||
if kw > 1:
|
||||
if self.input_buffer is None:
|
||||
init_buffer = self.zeros((bsz, kw + (kw - 1) * (dilation - 1), inputs.shape[2]), mstype.float32)
|
||||
self.input_buffer = self.concat_op((init_buffer[:, 1:, :], inputs[:, 0:1, :]))
|
||||
else:
|
||||
# shift buffer
|
||||
self.input_buffer = self.concat_op((self.input_buffer[:, 1:, :], inputs[:, 0:1, :]))
|
||||
inputs = self.input_buffer
|
||||
if dilation > 1:
|
||||
inputs = inputs[:, 0::dilation, :]
|
||||
|
||||
output = self.matmul(self.reshape_op(inputs, (bsz, -1)), self.get_weight)
|
||||
if self.bias is not None:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return self.reshape_op(output, (bsz, 1, -1))
|
||||
|
||||
def incremental_forward_numpy(self, inputs):
|
||||
"""
|
||||
Incremental forward.
|
||||
|
||||
Args:
|
||||
inputs: B x T x C
|
||||
|
||||
Returns:
|
||||
ndarray
|
||||
|
||||
"""
|
||||
# input: (B, T, C)
|
||||
if self.training:
|
||||
raise RuntimeError('incremental_forward only supports eval mode')
|
||||
|
||||
if self.get_weight is None:
|
||||
weight = self._get_linearized_weight()
|
||||
self.get_weight = weight.asnumpy()
|
||||
|
||||
if self.get_bias is None and self.bias is not None:
|
||||
bias = self.bias
|
||||
self.get_bias = bias.asnumpy()
|
||||
|
||||
# Note mindspore uses Conv2D to construct Conv1D
|
||||
kw = self.kernel_size[1]
|
||||
dilation = self.dilation[1]
|
||||
|
||||
bsz = inputs.shape[0] # input: bsz x len x dim
|
||||
if kw > 1:
|
||||
if self.input_buffer is None:
|
||||
self.input_buffer = np.zeros((bsz, kw + (kw - 1) * (dilation - 1), inputs.shape[2]), dtype=np.float32)
|
||||
else:
|
||||
# shift buffer
|
||||
self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :]
|
||||
# append next
|
||||
self.input_buffer[:, -1, :] = inputs[:, -1, :]
|
||||
inputs = self.input_buffer
|
||||
if dilation > 1:
|
||||
inputs = inputs[:, 0::dilation, :]
|
||||
output = inputs.reshape(bsz, -1).dot(self.get_weight.T)
|
||||
if self.bias is not None:
|
||||
output = output + np.expand_dims(self.get_bias, 0)
|
||||
return np.reshape(output, (bsz, 1, -1))
|
||||
|
||||
def clear_buffer(self):
|
||||
self.input_buffer = None
|
||||
|
||||
def _get_linearized_weight(self):
|
||||
"""
|
||||
get linearized weight
|
||||
"""
|
||||
weight = self.squeeze_op(self.weight)
|
||||
if self._linearized_weight is None:
|
||||
# Note mindspore uses Conv2D to construct Conv1D
|
||||
kw = self.kernel_size[1]
|
||||
if weight.shape == (self.out_channels, self.in_channels, kw):
|
||||
weight = self.transpose_op(weight, (0, 2, 1))
|
||||
else:
|
||||
weight = self.transpose_op(weight, (2, 0, 1))
|
||||
self._linearized_weight = self.reshape_op(weight, (self.out_channels, -1))
|
||||
return self._linearized_weight
|
||||
|
||||
def _clear_linearized_weight(self, *args):
|
||||
self._linearized_weight = None
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""
|
||||
weight initialization
|
||||
"""
|
||||
self.init_parameters_data()
|
||||
std_mul = 4.0
|
||||
for _, m in self.cells_and_names():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
std = math.sqrt((std_mul * 0.1) / (m.kernel_size[1] * self.in_channels))
|
||||
m.weight.set_data(Tensor(np.random.normal(0, std, m.weight.data.shape).astype("float32")))
|
||||
if m.bias is not None:
|
||||
m.bias.set_data(
|
||||
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.gamma.set_data(
|
||||
Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
|
||||
m.beta.set_data(
|
||||
Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
|
|
@ -0,0 +1,323 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
loss function for training and sample function for testing
|
||||
"""
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as P
|
||||
|
||||
|
||||
class log_sum_exp(nn.Cell):
|
||||
"""Numerically stable log_sum_exp
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(log_sum_exp, self).__init__()
|
||||
self.maxi = P.ReduceMax()
|
||||
self.maxi_dim = P.ReduceMax(keep_dims=True)
|
||||
self.log = P.Log()
|
||||
self.sums = P.ReduceSum()
|
||||
self.exp = P.Exp()
|
||||
|
||||
def construct(self, x):
|
||||
axis = len(x.shape) - 1
|
||||
m = self.maxi(x, axis)
|
||||
m2 = self.maxi_dim(x, axis)
|
||||
return m + self.log(self.sums(self.exp(x - m2), axis))
|
||||
|
||||
|
||||
class Stable_softplus(nn.Cell):
|
||||
"""Numerically stable softplus
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Stable_softplus, self).__init__()
|
||||
self.log_op = P.Log()
|
||||
self.abs_op = P.Abs()
|
||||
self.relu_op = P.ReLU()
|
||||
self.exp_op = P.Exp()
|
||||
|
||||
def construct(self, x):
|
||||
return self.log_op(1 + self.exp_op(- self.abs_op(x))) + self.relu_op(x)
|
||||
|
||||
|
||||
class discretized_mix_logistic_loss(nn.Cell):
|
||||
"""
|
||||
Discretized_mix_logistic_loss
|
||||
|
||||
Args:
|
||||
num_classes (int): Num_classes
|
||||
log_scale_min (float): Log scale minimum value
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=256, log_scale_min=-7.0, reduce=True):
|
||||
super(discretized_mix_logistic_loss, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.log_scale_min = log_scale_min
|
||||
self.reduce = reduce
|
||||
self.transpose_op = P.Transpose()
|
||||
self.exp = P.Exp()
|
||||
self.sigmoid = P.Sigmoid()
|
||||
self.softplus = Stable_softplus()
|
||||
self.log = P.Log()
|
||||
self.cast = P.Cast()
|
||||
self.logsoftmax = P.LogSoftmax(-1)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.tile = P.Tile()
|
||||
self.maximum = P.Maximum()
|
||||
self.sums = P.ReduceSum()
|
||||
self.lse = log_sum_exp()
|
||||
self.reshape = P.Reshape()
|
||||
self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32))
|
||||
|
||||
def construct(self, y_hat, y):
|
||||
"""
|
||||
|
||||
Args:
|
||||
y_hat (Tensor): Predicted distribution
|
||||
y (Tensor): Target
|
||||
|
||||
Returns:
|
||||
Tensor: Discretized_mix_logistic_loss
|
||||
|
||||
"""
|
||||
nr_mix = y_hat.shape[1] // 3
|
||||
|
||||
# (B x T x C)
|
||||
y_hat = self.transpose_op(y_hat, (0, 2, 1))
|
||||
|
||||
# (B, T, num_mixtures) x 3
|
||||
logit_probs = y_hat[:, :, :nr_mix]
|
||||
means = y_hat[:, :, nr_mix:2 * nr_mix]
|
||||
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], self.log_scale_min)
|
||||
|
||||
# B x T x 1 -> B x T x num_mixtures
|
||||
y = self.tile(y, (1, 1, nr_mix))
|
||||
|
||||
centered_y = y - means
|
||||
inv_stdv = self.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_y + 1. / (self.num_classes - 1))
|
||||
cdf_plus = self.sigmoid(plus_in)
|
||||
min_in = inv_stdv * (centered_y - 1. / (self.num_classes - 1))
|
||||
cdf_min = self.sigmoid(min_in)
|
||||
|
||||
log_cdf_plus = plus_in - self.softplus(plus_in)
|
||||
|
||||
log_one_minus_cdf_min = -self.softplus(min_in)
|
||||
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
|
||||
mid_in = inv_stdv * centered_y
|
||||
log_pdf_mid = mid_in - log_scales - 2. * self.softplus(mid_in)
|
||||
|
||||
inner_inner_cond = self.cast(cdf_delta > 1e-5, ms.float32)
|
||||
inner_inner_out = inner_inner_cond * \
|
||||
self.log(self.maximum(cdf_delta, 1e-12)) + \
|
||||
(1. - inner_inner_cond) * (log_pdf_mid - self.factor)
|
||||
inner_cond = self.cast(y > 0.999, ms.float32)
|
||||
inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
|
||||
cond = self.cast(y < -0.999, ms.float32)
|
||||
log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
|
||||
|
||||
a, b, c = logit_probs.shape[0], logit_probs.shape[1], logit_probs.shape[2]
|
||||
logit_probs = self.logsoftmax(self.reshape(logit_probs, (-1, c)))
|
||||
logit_probs = self.reshape(logit_probs, (a, b, c))
|
||||
|
||||
log_probs = log_probs + logit_probs
|
||||
if self.reduce:
|
||||
return -self.sums(self.lse(log_probs))
|
||||
return self.expand_dims(-self.lse(log_probs), -1)
|
||||
|
||||
|
||||
def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0):
|
||||
"""
|
||||
Sample from discretized mixture of logistic distributions
|
||||
|
||||
Args:
|
||||
y (ndarray): B x C x T
|
||||
log_scale_min (float): Log scale minimum value
|
||||
|
||||
Returns:
|
||||
ndarray
|
||||
"""
|
||||
nr_mix = y.shape[1] // 3
|
||||
|
||||
# B x T x C
|
||||
y = np.transpose(y, (0, 2, 1))
|
||||
logit_probs = y[:, :, :nr_mix]
|
||||
|
||||
temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape)
|
||||
temp = logit_probs - np.log(- np.log(temp))
|
||||
|
||||
argmax = np.argmax(temp, axis=-1)
|
||||
|
||||
# (B, T) -> (B, T, nr_mix)
|
||||
one_hot = np.eye(nr_mix)[argmax]
|
||||
means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
|
||||
log_scales = np.clip(np.sum(
|
||||
y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1), a_min=log_scale_min, a_max=None)
|
||||
|
||||
u = np.random.uniform(1e-5, 1.0 - 1e-5, means.shape)
|
||||
x = means + np.exp(log_scales) * (np.log(u) - np.log(1. - u))
|
||||
x = np.clip(x, -1., 1.)
|
||||
return x.astype(np.float32)
|
||||
|
||||
|
||||
class mix_gaussian_loss(nn.Cell):
|
||||
"""
|
||||
Mix gaussian loss
|
||||
"""
|
||||
|
||||
def __init__(self, log_scale_min=-7.0, reduce=True):
|
||||
super(mix_gaussian_loss, self).__init__()
|
||||
self.log_scale_min = log_scale_min
|
||||
self.reduce = reduce
|
||||
self.transpose_op = P.Transpose()
|
||||
self.maximum = P.Maximum()
|
||||
self.tile = P.Tile()
|
||||
self.exp = P.Exp()
|
||||
self.logsoftmax = P.LogSoftmax(-1)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.sums = P.ReduceSum()
|
||||
self.lse = log_sum_exp()
|
||||
|
||||
self.sq = P.Square()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.const = P.ScalarToArray()
|
||||
self.log = P.Log()
|
||||
|
||||
def construct(self, y_hat, y):
|
||||
"""
|
||||
|
||||
Args:
|
||||
y_hat (Tensor): Predicted probability
|
||||
y (Tensor): Target
|
||||
|
||||
Returns:
|
||||
Tensor: Mix_gaussian_loss
|
||||
|
||||
"""
|
||||
C = y_hat.shape[1]
|
||||
if C == 2:
|
||||
nr_mix = 1
|
||||
else:
|
||||
nr_mix = y_hat.shape[1] // 3
|
||||
|
||||
# (B x T x C)
|
||||
y_hat = self.transpose_op(y_hat, (0, 2, 1))
|
||||
|
||||
if C == 2:
|
||||
logit_probs = None
|
||||
means = y_hat[:, :, 0:1]
|
||||
log_scales = self.maximum(y_hat[:, :, 1:2], self.log_scale_min)
|
||||
else:
|
||||
# (B, T, num_mixtures) x 3
|
||||
logit_probs = y_hat[:, :, :nr_mix]
|
||||
means = y_hat[:, :, nr_mix:2 * nr_mix]
|
||||
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], self.log_scale_min)
|
||||
|
||||
# B x T x 1 -> B x T x num_mixtures
|
||||
y = self.tile(y, (1, 1, nr_mix))
|
||||
centered_y = y - means
|
||||
|
||||
sd = self.exp(log_scales)
|
||||
unnormalized_log_prob = -1. * (self.sq(centered_y - 0.)) / (2. * self.sq(sd))
|
||||
neg_normalization = -1. * self.log(self.const(2. * np.pi)) / 2. - self.log(sd)
|
||||
log_probs = unnormalized_log_prob + neg_normalization
|
||||
|
||||
if nr_mix > 1:
|
||||
log_probs = log_probs + self.logsoftmax(logit_probs)
|
||||
|
||||
if self.reduce:
|
||||
if nr_mix == 1:
|
||||
return -self.sums(log_probs)
|
||||
return -self.sums(self.lse(log_probs))
|
||||
if nr_mix == 1:
|
||||
return -log_probs
|
||||
return self.expand_dims(-self.lse(log_probs), -1)
|
||||
|
||||
|
||||
def sample_from_mix_gaussian(y, log_scale_min=-7.0):
|
||||
"""
|
||||
Sample_from_mix_gaussian
|
||||
|
||||
Args:
|
||||
y (ndarray): B x C x T
|
||||
|
||||
Returns:
|
||||
ndarray
|
||||
|
||||
"""
|
||||
C = y.shape[1]
|
||||
if C == 2:
|
||||
nr_mix = 1
|
||||
else:
|
||||
nr_mix = y.shape[1] // 3
|
||||
|
||||
# B x T x C
|
||||
y = np.transpose(y, (0, 2, 1))
|
||||
|
||||
if C == 2:
|
||||
logit_probs = None
|
||||
else:
|
||||
logit_probs = y[:, :, :nr_mix]
|
||||
|
||||
if nr_mix > 1:
|
||||
temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape)
|
||||
temp = logit_probs - np.log(- np.log(temp))
|
||||
argmax = np.argmax(temp, axis=-1)
|
||||
|
||||
# (B, T) -> (B, T, nr_mix)
|
||||
one_hot = np.eye(nr_mix)[argmax]
|
||||
|
||||
means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
|
||||
|
||||
log_scales = np.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1)
|
||||
else:
|
||||
if C == 2:
|
||||
means, log_scales = y[:, :, 0], y[:, :, 1]
|
||||
elif C == 3:
|
||||
means, log_scales = y[:, :, 1], y[:, :, 2]
|
||||
else:
|
||||
assert False, "shouldn't happen"
|
||||
|
||||
scales = np.exp(log_scales)
|
||||
x = np.random.normal(loc=means, scale=scales)
|
||||
x = np.clip(x, -1., 1.)
|
||||
return x.astype(np.float32)
|
||||
|
||||
|
||||
# self-implemented onehotcategorical distribution
|
||||
# https://zhuanlan.zhihu.com/p/59550457
|
||||
def sample_from_mix_onehotcategorical(x):
|
||||
"""
|
||||
Sample_from_mix_onehotcategorical
|
||||
|
||||
Args:
|
||||
x (ndarray): Predicted softmax probability
|
||||
|
||||
Returns:
|
||||
ndarray
|
||||
|
||||
"""
|
||||
pi = np.log(x)
|
||||
u = np.random.uniform(0, 1, x.shape)
|
||||
g = -np.log(-np.log(u))
|
||||
c = np.argmax(pi + g, axis=1)
|
||||
return np.array(np.eye(256)[c], dtype=np.float32)
|
|
@ -0,0 +1,213 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
modules for wavenet
|
||||
"""
|
||||
from __future__ import with_statement, print_function, absolute_import
|
||||
import math
|
||||
import numpy as np
|
||||
from wavenet_vocoder import conv
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
|
||||
m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs)
|
||||
return m
|
||||
|
||||
|
||||
def Conv1d1x1(in_channels, out_channels, has_bias=True):
|
||||
return Conv1d(in_channels, out_channels, kernel_size=1, pad_mode='pad', padding=0, dilation=1, has_bias=has_bias)
|
||||
|
||||
|
||||
def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01):
|
||||
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
return m
|
||||
|
||||
|
||||
def _conv1x1_forward(conv_, x, is_incremental, is_numpy=True):
|
||||
"""
|
||||
Conv1x1 forward
|
||||
"""
|
||||
if is_incremental:
|
||||
x = conv_.incremental_forward(x, is_numpy=is_numpy)
|
||||
else:
|
||||
x = conv_(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualConv1dGLU(nn.Cell):
|
||||
"""Residual dilated conv1d with gated activation units
|
||||
|
||||
Args:
|
||||
residual_channels (int): Residual input / output channels
|
||||
gate_channels (int): Gated activation channels.
|
||||
kernel_size (int): Kernel size
|
||||
skip_out_channels (int): Skip connection channels. If None, it will set to the same as residual_channels.
|
||||
cin_channels (int): Local conditioning channels. If given negative value, local conditioning is disabled.
|
||||
gin_channels (int): Global conditioning channels. If given negative value, global conditioning is disabled.
|
||||
dropout (float): Dropout rate.
|
||||
padding (int): Padding for convolution layers. If None, padding value will be computed according to dilation
|
||||
and kernel_size.
|
||||
dilation (int): Dilation factor.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, residual_channels=None, gate_channels=None, kernel_size=None, skip_out_channels=None, bias=True,
|
||||
dropout=1 - 0.95, dilation=1, cin_channels=-1, gin_channels=-1, padding=None, causal=True):
|
||||
super(ResidualConv1dGLU, self).__init__()
|
||||
self.dropout = dropout
|
||||
self.dropout_op = nn.Dropout(keep_prob=1. - self.dropout)
|
||||
self.eval_split_op = P.Split(axis=-1, output_num=2)
|
||||
self.train_split_op = P.Split(axis=1, output_num=2)
|
||||
self.tanh = P.Tanh()
|
||||
self.sigmoid = P.Sigmoid()
|
||||
self.mul = P.Mul()
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
if skip_out_channels is None:
|
||||
skip_out_channels = residual_channels
|
||||
if padding is None:
|
||||
if causal:
|
||||
padding = (kernel_size - 1) * dilation
|
||||
else:
|
||||
padding = (kernel_size - 1) // 2 * dilation
|
||||
self.causal = causal
|
||||
|
||||
self.conv = Conv1d(residual_channels, gate_channels, kernel_size, pad_mode='pad',
|
||||
padding=padding, dilation=dilation, has_bias=bias)
|
||||
|
||||
# local conditioning
|
||||
if cin_channels > 0:
|
||||
self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, has_bias=False)
|
||||
else:
|
||||
self.conv1x1c = None
|
||||
|
||||
# global conditioning
|
||||
if gin_channels > 0:
|
||||
self.conv1x1g = Conv1d(gin_channels, gate_channels, has_bias=False, kernel_size=1, dilation=1)
|
||||
else:
|
||||
self.conv1x1g = None
|
||||
|
||||
gate_out_channels = gate_channels // 2
|
||||
self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, has_bias=bias)
|
||||
self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, has_bias=bias)
|
||||
self.factor = math.sqrt(0.5)
|
||||
|
||||
def construct(self, x, c=None, g=None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
x(Tensor): One-hot audio signal, the shape is B x C x T
|
||||
c(Tensor): local conditional feature, the shape is B x cin_channels x T
|
||||
g(Tensor): global conditional feature, not used currently
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor
|
||||
|
||||
"""
|
||||
|
||||
residual = x
|
||||
x = self.dropout_op(x)
|
||||
x = self.conv(x)
|
||||
# remove future time steps
|
||||
x = x[:, :, :residual.shape[-1]] if self.causal else x
|
||||
split_op = self.train_split_op
|
||||
|
||||
a, b = split_op(x)
|
||||
|
||||
# local conditioning
|
||||
if c is not None:
|
||||
c = _conv1x1_forward(self.conv1x1c, c, is_incremental=False)
|
||||
ca, cb = split_op(c)
|
||||
a, b = a + ca, b + cb
|
||||
|
||||
# global conditioning
|
||||
if g is not None:
|
||||
g = _conv1x1_forward(self.conv1x1g, g, is_incremental=False)
|
||||
ga, gb = self.split(g)
|
||||
a, b = a + ga, b + gb
|
||||
|
||||
x = self.mul(self.tanh(a), self.sigmoid(b))
|
||||
|
||||
# For skip connection
|
||||
s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental=False)
|
||||
|
||||
# For residual connection
|
||||
x = _conv1x1_forward(self.conv1x1_out, x, is_incremental=False)
|
||||
|
||||
x = self.add(x, residual) * self.factor
|
||||
return x, s
|
||||
|
||||
def sigmoid_numpy(self, x):
|
||||
return 1. / (1 + np.exp(-x))
|
||||
|
||||
def incremental_forward(self, x, c=None, g=None, is_numpy=True):
|
||||
"""
|
||||
Incremental forward. Used for inference stage
|
||||
|
||||
Args:
|
||||
x (Tensor): One-hot audio signal, the shape is B x C x T
|
||||
c (Tensor): local conditional feature, the shape is B x cin_channels x T
|
||||
g (Tensor): global conditional feature, not used currently
|
||||
|
||||
Returns:
|
||||
ndarray
|
||||
"""
|
||||
residual = x
|
||||
x = self.conv.incremental_forward(x, is_numpy=is_numpy)
|
||||
if is_numpy:
|
||||
a, b = np.split(x, indices_or_sections=2, axis=-1)
|
||||
else:
|
||||
a, b = self.eval_split_op(x)
|
||||
|
||||
# local conditioning
|
||||
if c is not None:
|
||||
c = _conv1x1_forward(self.conv1x1c, c, is_incremental=True, is_numpy=is_numpy)
|
||||
if is_numpy:
|
||||
ca, cb = np.split(c, indices_or_sections=2, axis=-1)
|
||||
else:
|
||||
ca, cb = self.eval_split_op(c)
|
||||
a, b = a + ca, b + cb
|
||||
|
||||
# global conditioning
|
||||
if g is not None:
|
||||
g = _conv1x1_forward(self.conv1x1g, g, is_incremental=True, is_numpy=is_numpy)
|
||||
if is_numpy:
|
||||
ga, gb = np.split(g, indices_or_sections=2, axis=-1)
|
||||
else:
|
||||
ga, gb = self.eval_split_op(c)
|
||||
a, b = a + ga, b + gb
|
||||
|
||||
if is_numpy:
|
||||
x = np.tanh(a) * self.sigmoid_numpy(b)
|
||||
else:
|
||||
x = self.mul(self.tanh(a), self.sigmoid(b))
|
||||
|
||||
# For skip connection
|
||||
s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental=True, is_numpy=is_numpy)
|
||||
|
||||
# For residual connection
|
||||
x = _conv1x1_forward(self.conv1x1_out, x, is_incremental=True, is_numpy=is_numpy)
|
||||
|
||||
x = (x + residual) * self.factor
|
||||
return x, s
|
||||
|
||||
def clear_buffer(self):
|
||||
"""clear buffer"""
|
||||
for c in [self.conv, self.conv1x1_out, self.conv1x1_skip,
|
||||
self.conv1x1c, self.conv1x1g]:
|
||||
if c is not None:
|
||||
c.clear_buffer()
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Upsampling
|
||||
|
||||
"""
|
||||
from __future__ import with_statement, print_function, absolute_import
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Resize(nn.Cell):
|
||||
"""
|
||||
Resize input Tensor
|
||||
"""
|
||||
|
||||
def __init__(self, x_scale, y_scale, mode="nearest"):
|
||||
super(Resize, self).__init__()
|
||||
self.x_scale = x_scale
|
||||
self.y_scale = y_scale
|
||||
self.mode = mode
|
||||
|
||||
def construct(self, x):
|
||||
_, _, h, w = x.shape
|
||||
interpolate_op = P.ResizeNearestNeighbor((self.y_scale * h, self.x_scale * w))
|
||||
return interpolate_op(x)
|
||||
|
||||
|
||||
def _get_activation(upsample_activation):
|
||||
"""get activation"""
|
||||
nonlinear = getattr(nn, upsample_activation)
|
||||
return nonlinear
|
||||
|
||||
|
||||
class UpsampleNetwork(nn.Cell):
|
||||
"""UpsampleNetwork"""
|
||||
def __init__(self, upsample_scales, mode="nearest",
|
||||
freq_axis_kernel_size=1, cin_pad=0, cin_channels=80):
|
||||
super(UpsampleNetwork, self).__init__()
|
||||
self.expand_op = P.ExpandDims()
|
||||
self.squeeze_op = P.Squeeze(1)
|
||||
up_layers = []
|
||||
total_scale = np.prod(upsample_scales)
|
||||
self.indent = cin_pad * total_scale
|
||||
for scale in upsample_scales:
|
||||
freq_axis_padding = (freq_axis_kernel_size - 1) // 2
|
||||
k_size = (freq_axis_kernel_size, scale * 2 + 1)
|
||||
# padding = (freq_axis_padding, scale)
|
||||
padding = (freq_axis_padding, freq_axis_padding, scale, scale)
|
||||
stretch = Resize(scale, 1, mode)
|
||||
conv = nn.Conv2d(1, 1, kernel_size=k_size, has_bias=False, pad_mode='pad', padding=padding)
|
||||
up_layers.append(stretch)
|
||||
up_layers.append(conv)
|
||||
# if upsample_activation != "none":
|
||||
# nonlinear = _get_activation(upsample_activation)
|
||||
# up_layers.append(nonlinear(**upsample_activation_params))
|
||||
self.up_layers = nn.CellList(up_layers)
|
||||
|
||||
def construct(self, c):
|
||||
"""
|
||||
|
||||
Args:
|
||||
c (Tensor): Local conditioning feature
|
||||
|
||||
Returns:
|
||||
Tensor: Upsampling feature
|
||||
|
||||
"""
|
||||
# B x 1 x C x T
|
||||
c = self.expand_op(c, 1)
|
||||
for f in self.up_layers:
|
||||
c = f(c)
|
||||
# B x C x T
|
||||
c = self.squeeze_op(c)
|
||||
|
||||
# if self.indent > 0:
|
||||
# c = c[:, :, self.indent:-self.indent]
|
||||
return c
|
||||
|
||||
|
||||
class ConvInUpsampleNetwork(nn.Cell):
|
||||
"""Upsample Network
|
||||
|
||||
Args:
|
||||
upsample_scales (list): Upsample_scales list.
|
||||
upsample_activation (str): Upsample_activation.
|
||||
mode (str): Resize mode, default is NearestNeighbor.
|
||||
cin_channels (int): Local conditioning channels.
|
||||
freq_axis_kernel_size (int): Freq-axis kernel_size for the convolution layers after resize.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, upsample_scales, mode="nearest",
|
||||
freq_axis_kernel_size=1, cin_pad=0,
|
||||
cin_channels=80):
|
||||
super(ConvInUpsampleNetwork, self).__init__()
|
||||
ks = 2 * cin_pad + 1
|
||||
self.conv_in = nn.Conv1d(cin_channels, cin_channels, kernel_size=ks, has_bias=False, pad_mode='pad', padding=0)
|
||||
self.upsample = UpsampleNetwork(upsample_scales, mode, freq_axis_kernel_size, cin_pad=0,
|
||||
cin_channels=cin_channels)
|
||||
|
||||
def construct(self, c):
|
||||
c = self.conv_in(c)
|
||||
c_up = self.upsample(c)
|
||||
return c_up
|
|
@ -0,0 +1,346 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""WaveNet construction"""
|
||||
from __future__ import with_statement, print_function, absolute_import
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from mindspore import nn, Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from wavenet_vocoder import upsample
|
||||
from .modules import Embedding
|
||||
from .modules import Conv1d1x1
|
||||
from .modules import ResidualConv1dGLU
|
||||
from .mixture import sample_from_discretized_mix_logistic
|
||||
from .mixture import sample_from_mix_gaussian
|
||||
from .mixture import sample_from_mix_onehotcategorical
|
||||
|
||||
|
||||
class WaveNet(nn.Cell):
|
||||
"""
|
||||
WaveNet model definition. Only local condition is supported
|
||||
|
||||
Args:
|
||||
out_channels (int): Output channels. If input_type is mu-law quantized one-hot vecror, it should equal to the
|
||||
quantize channels. Otherwise, it equals to num_mixtures x 3. Default: 256.
|
||||
layers (int): Number of ResidualConv1dGLU layers
|
||||
stacks (int): Number of dilation cycles
|
||||
residual_channels (int): Residual input / output channels
|
||||
gate_channels (int): Gated activation channels.
|
||||
skip_out_channels (int): Skip connection channels.
|
||||
kernel_size (int): Kernel size .
|
||||
dropout (float): Dropout rate.
|
||||
cin_channels (int): Local conditioning channels. If given negative value, local conditioning is disabled.
|
||||
gin_channels (int): Global conditioning channels. If given negative value, global conditioning is disabled.
|
||||
n_speakers (int): Number of speakers. This is used when global conditioning is enabled.
|
||||
upsample_conditional_features (bool): Whether upsampling local conditioning features by resize_nearestneighbor
|
||||
and conv or not.
|
||||
scalar_input (Bool): If True, scalar input ([-1, 1]) is expected, otherwise, quantized one-hot vector
|
||||
is expected.
|
||||
use_speaker_embedding (Bool): Use speaker embedding or Not.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, out_channels=256, layers=20, stacks=2,
|
||||
residual_channels=512,
|
||||
gate_channels=512,
|
||||
skip_out_channels=512,
|
||||
kernel_size=3, dropout=1 - 0.95,
|
||||
cin_channels=-1, gin_channels=-1, n_speakers=None,
|
||||
upsample_conditional_features=False,
|
||||
upsample_net="ConvInUpsampleNetwork",
|
||||
upsample_params=None,
|
||||
scalar_input=False,
|
||||
use_speaker_embedding=False,
|
||||
output_distribution="Logistic",
|
||||
cin_pad=0,
|
||||
):
|
||||
super(WaveNet, self).__init__()
|
||||
self.transpose_op = P.Transpose()
|
||||
self.softmax = P.Softmax(axis=1)
|
||||
self.reshape_op = P.Reshape()
|
||||
self.zeros_op = P.Zeros()
|
||||
self.ones_op = P.Ones()
|
||||
self.relu_op = P.ReLU()
|
||||
self.squeeze_op = P.Squeeze()
|
||||
self.expandim_op = P.ExpandDims()
|
||||
self.transpose_op = P.Transpose()
|
||||
self.tile_op = P.Tile()
|
||||
self.scalar_input = scalar_input
|
||||
self.out_channels = out_channels
|
||||
self.cin_channels = cin_channels
|
||||
self.output_distribution = output_distribution
|
||||
self.fack_data = P.Zeros()
|
||||
assert layers % stacks == 0
|
||||
layers_per_stack = layers // stacks
|
||||
if scalar_input:
|
||||
self.first_conv = Conv1d1x1(1, residual_channels)
|
||||
else:
|
||||
self.first_conv = Conv1d1x1(out_channels, residual_channels)
|
||||
|
||||
conv_layers = []
|
||||
for layer in range(layers):
|
||||
dilation = 2 ** (layer % layers_per_stack)
|
||||
conv = ResidualConv1dGLU(
|
||||
residual_channels, gate_channels,
|
||||
kernel_size=kernel_size,
|
||||
skip_out_channels=skip_out_channels,
|
||||
bias=True,
|
||||
dropout=dropout,
|
||||
dilation=dilation,
|
||||
cin_channels=cin_channels,
|
||||
gin_channels=gin_channels)
|
||||
conv_layers.append(conv)
|
||||
self.conv_layers = nn.CellList(conv_layers)
|
||||
self.last_conv_layers = nn.CellList([
|
||||
nn.ReLU(),
|
||||
Conv1d1x1(skip_out_channels, skip_out_channels),
|
||||
nn.ReLU(),
|
||||
Conv1d1x1(skip_out_channels, out_channels)])
|
||||
|
||||
if gin_channels > 0 and use_speaker_embedding:
|
||||
assert n_speakers is not None
|
||||
self.embed_speakers = Embedding(
|
||||
n_speakers, gin_channels, padding_idx=None, std=0.1)
|
||||
else:
|
||||
self.embed_speakers = None
|
||||
|
||||
if upsample_conditional_features:
|
||||
self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
|
||||
else:
|
||||
self.upsample_net = None
|
||||
|
||||
self.factor = math.sqrt(1.0 / len(self.conv_layers))
|
||||
|
||||
def _expand_global_features(self, batch_size, time_step, g_fp, is_expand=True):
|
||||
"""Expand global conditioning features to all time steps
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size.
|
||||
time_step (int): Time length.
|
||||
g_fp (Tensor): Global features, (B x C) or (B x C x 1).
|
||||
is_expand (bool) : Expanded global conditioning features
|
||||
|
||||
Returns:
|
||||
Tensor: B x C x T or B x T x C or None
|
||||
"""
|
||||
if g_fp is None:
|
||||
return None
|
||||
if len(g_fp.shape) == 2:
|
||||
g_fp = self.expandim_op(g_fp, -1)
|
||||
else:
|
||||
g_fp = g_fp
|
||||
|
||||
if is_expand:
|
||||
expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step))
|
||||
return expand_fp
|
||||
expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step))
|
||||
expand_fp = self.transpose_op(expand_fp, (0, 2, 1))
|
||||
return expand_fp
|
||||
|
||||
def construct(self, x, c=None, g=None, softmax=False):
|
||||
"""
|
||||
|
||||
Args:
|
||||
x (Tensor): One-hot encoded audio signal
|
||||
c (Tensor): Local conditioning feature
|
||||
g (Tensor): Global conditioning feature
|
||||
softmax (bool): Whether use softmax or not
|
||||
|
||||
Returns:
|
||||
Tensor: Net output
|
||||
|
||||
"""
|
||||
g = None
|
||||
B, _, T = x.shape
|
||||
if g is not None:
|
||||
if self.embed_speakers is not None:
|
||||
g = self.embed_speakers(self.reshape_op(g, (B, -1)))
|
||||
g = self.transpose_op(g, (0, 2, 1))
|
||||
g_bct = self._expand_global_features(B, T, g, is_expand=True)
|
||||
|
||||
if c is not None and self.upsample_net is not None:
|
||||
c = self.upsample_net(c)
|
||||
|
||||
x = self.first_conv(x)
|
||||
skips = 0
|
||||
for f in self.conv_layers:
|
||||
x, h = f(x, c, g_bct)
|
||||
skips += h
|
||||
skips *= self.factor
|
||||
|
||||
x = skips
|
||||
for f in self.last_conv_layers:
|
||||
x = f(x)
|
||||
x = self.softmax(x) if softmax else x
|
||||
|
||||
return x
|
||||
|
||||
def relu_numpy(self, inX):
|
||||
"""numpy relu function"""
|
||||
return np.maximum(0, inX)
|
||||
|
||||
def softmax_numpy(self, x):
|
||||
""" numpy softmax function """
|
||||
x -= np.max(x, axis=1, keepdims=True)
|
||||
return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
|
||||
|
||||
def incremental_forward(self, initial_input=None, c=None, g=None,
|
||||
T=100, test_inputs=None,
|
||||
tqdm=lambda x: x, softmax=True, quantize=True,
|
||||
log_scale_min=-50.0, is_numpy=True):
|
||||
"""
|
||||
Incremental forward. Current output depends on last output.
|
||||
|
||||
Args:
|
||||
initial_input (Tensor): Initial input, the shape is B x C x 1
|
||||
c (Tensor): Local conditioning feature, the shape is B x C x T
|
||||
g (Tensor): Global conditioning feature, the shape is B x C or B x C x 1
|
||||
T (int): decoding time step.
|
||||
test_inputs: Teacher forcing inputs (for debugging)
|
||||
tqdm (lamda): tqmd
|
||||
softmax (bool): Whether use softmax or not
|
||||
quantize (bool): Whether quantize softmax output in last step when decoding current step
|
||||
log_scale_min (float): Log scale minimum value
|
||||
|
||||
Returns:
|
||||
Tensor: Predicted on-hot encoded samples or scalar vector depending on loss type
|
||||
|
||||
"""
|
||||
|
||||
self.clear_buffer()
|
||||
B = 1
|
||||
|
||||
if test_inputs is not None:
|
||||
if self.scalar_input:
|
||||
if test_inputs.shape[1] == 1:
|
||||
test_inputs = self.transpose_op(test_inputs, (0, 2, 1))
|
||||
else:
|
||||
if test_inputs.shape[1] == self.out_channels:
|
||||
test_inputs = self.transpose_op(test_inputs, (0, 2, 1))
|
||||
|
||||
B = test_inputs.shape[0]
|
||||
if T is None:
|
||||
T = test_inputs.shape[1]
|
||||
else:
|
||||
T = max(T, test_inputs.shape[1])
|
||||
T = int(T)
|
||||
|
||||
# Global conditioning
|
||||
if g is not None:
|
||||
if self.embed_speakers is not None:
|
||||
g = self.embed_speakers(self.reshape_op(g, (B, -1)))
|
||||
g = self.transpose_op(g, (0, 2, 1))
|
||||
assert g.dim() == 3
|
||||
g_btc = self._expand_global_features(B, T, g, is_expand=False)
|
||||
|
||||
# Local conditioning
|
||||
if c is not None:
|
||||
B = c.shape[0]
|
||||
if self.upsample_net is not None:
|
||||
c = self.upsample_net(c)
|
||||
assert c.shape[-1] == T
|
||||
if c.shape[-1] == T:
|
||||
c = self.transpose_op(c, (0, 2, 1))
|
||||
|
||||
outputs = []
|
||||
if initial_input is None:
|
||||
if self.scalar_input:
|
||||
initial_input = self.zeros_op((B, 1, 1), mstype.float32)
|
||||
else:
|
||||
initial_input = np.zeros((B, 1, self.out_channels), np.float32)
|
||||
initial_input[:, :, 127] = 1
|
||||
initial_input = Tensor(initial_input)
|
||||
else:
|
||||
if initial_input.shape[1] == self.out_channels:
|
||||
initial_input = self.transpose_op(initial_input, (0, 2, 1))
|
||||
|
||||
if is_numpy:
|
||||
current_input = initial_input.asnumpy()
|
||||
else:
|
||||
current_input = initial_input
|
||||
|
||||
for t in tqdm(range(T)):
|
||||
if test_inputs is not None and t < test_inputs.shape[1]:
|
||||
current_input = self.expandim_op(test_inputs[:, t, :], 1)
|
||||
else:
|
||||
if t > 0:
|
||||
if not is_numpy:
|
||||
current_input = Tensor(outputs[-1])
|
||||
else:
|
||||
current_input = outputs[-1]
|
||||
|
||||
# Conditioning features for single time step
|
||||
ct = None if c is None else self.expandim_op(c[:, t, :], 1)
|
||||
gt = None if g is None else self.expandim_op(g_btc[:, t, :], 1)
|
||||
|
||||
x = current_input
|
||||
|
||||
if is_numpy:
|
||||
ct = ct.asnumpy()
|
||||
x = self.first_conv.incremental_forward(x, is_numpy=is_numpy)
|
||||
|
||||
skips = 0
|
||||
for f in self.conv_layers:
|
||||
x, h = f.incremental_forward(x, ct, gt, is_numpy=is_numpy)
|
||||
skips += h
|
||||
skips *= self.factor
|
||||
x = skips
|
||||
|
||||
for f in self.last_conv_layers:
|
||||
try:
|
||||
x = f.incremental_forward(x, is_numpy=is_numpy)
|
||||
except AttributeError:
|
||||
if is_numpy:
|
||||
x = self.relu_numpy(x)
|
||||
else:
|
||||
x = self.relu_op(x)
|
||||
|
||||
# Generate next input by sampling
|
||||
if not is_numpy:
|
||||
x = x.asnumpy()
|
||||
if self.scalar_input:
|
||||
if self.output_distribution == "Logistic":
|
||||
x = sample_from_discretized_mix_logistic(x.reshape((B, -1, 1)), log_scale_min=log_scale_min)
|
||||
|
||||
elif self.output_distribution == "Normal":
|
||||
x = sample_from_mix_gaussian(x.reshape((B, -1, 1)), log_scale_min=log_scale_min)
|
||||
else:
|
||||
assert False
|
||||
else:
|
||||
x = self.softmax_numpy(np.reshape(x, (B, -1))) if softmax else np.reshape(x, (B, -1))
|
||||
if quantize:
|
||||
x = sample_from_mix_onehotcategorical(x)
|
||||
|
||||
outputs += [x]
|
||||
# T x B x C
|
||||
outputs = np.stack(outputs, 0)
|
||||
# B x C x T
|
||||
outputs = np.transpose(outputs, (1, 2, 0))
|
||||
self.clear_buffer()
|
||||
return outputs
|
||||
|
||||
def clear_buffer(self):
|
||||
"""clear buffer"""
|
||||
self.first_conv.clear_buffer()
|
||||
for f in self.conv_layers:
|
||||
f.clear_buffer()
|
||||
for f in self.last_conv_layers:
|
||||
try:
|
||||
f.clear_buffer()
|
||||
except AttributeError:
|
||||
pass
|
Loading…
Reference in New Issue