added Wavenet in CPU mode

This commit is contained in:
huangbo77 2021-01-30 09:40:18 +08:00
parent b82df95b43
commit 79e93845c9
8 changed files with 176 additions and 59 deletions

View File

@ -77,6 +77,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
- [DeepSpeech2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2/README.md)
- [Wavenet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/wavenet/README.md)
- [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
- [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md)
- [Molecular_Dynamics](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md)

View File

@ -18,7 +18,7 @@
# [WaveNet Description](#contents)
WaveNet is a deep neural network for generating raw audio waveforms. The model is fully probabilistic and autoregressive, with the predictive distribution for each audio sample conditioned on all previous ones. We support training and evaluation on GPU.
WaveNet is a deep neural network for generating raw audio waveforms. The model is fully probabilistic and autoregressive, with the predictive distribution for each audio sample conditioned on all previous ones. We support training and evaluation on both GPU and CPU.
[Paper](https://arxiv.org/pdf/1609.03499.pdf): ord A, Dieleman S, Zen H, et al. Wavenet: A generative model for raw audio
@ -47,8 +47,8 @@ Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
# [Environment Requirements](#contents)
- HardwareGPU
- Prepare hardware environment with GPU processor.
- HardwareGPU/CPU
- Prepare hardware environment with GPU/CPU processor.
- Framework
- [MindSpore](https://cmc-szv.clouddragon.huawei.com/cmcversion/index/search?searchKey=Do-MindSpore%20V100R001C00B622)
- For more information, please check the resources below
@ -65,37 +65,38 @@ Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
.
├── audio
└──wavenet
├──datasets // Note the datasets folder should be download from the above link
├──egs // Note the egs folder should be download from the above link
├──utils // Note the utils folder should be download from the above link
├── audio.py // audio utils. Note this script should be download from a third party
├── compute-meanvar-stats.py // Compute mean-variance normalization stats. Note this script should be download from the above link
├── evaluate.py // evaluation
├── export.py // convert mindspore model to air model
├── hparams.py // hyper-parameter configuration. Note this script should be download from the above link
├── mksubset.py // Make subset of dataset. Note this script should be download from the above link
├── preprocess.py // Preprocess dataset. Note this script should be download from the above link
├── preprocess_normalize.py // Perform meanvar normalization to preprocessed features. Note this script should be download from the above link
├── README.md // descriptions about WaveNet
├── train.py // training scripts
├── train_pytorch.py // Note this script should be download from the above link. The initial name of this script is train.py in the project from the link
├──datasets // Note the datasets folder should be downloaded from the above link
├──egs // Note the egs folder should be downloaded from the above link
├──utils // Note the utils folder should be downloaded from the above link
├── audio.py // Audio utils. Note this script should be downloaded from a third party
├── compute-meanvar-stats.py // Compute mean-variance normalization stats. Note this script should be downloaded from the above link
├── evaluate.py // Evaluation
├── export.py // Convert mindspore model to air model
├── hparams.py // Hyper-parameter configuration. Note this script should be downloaded from the above link
├── lrschedule.py // Learning rate scheduler. Note this script should be downloaded from the above link
├── mksubset.py // Make subset of dataset. Note this script should be downloaded from the above link
├── preprocess.py // Preprocess dataset. Note this script should be downloaded from the above link
├── preprocess_normalize.py // Perform meanvar normalization to preprocessed features. Note this script should be downloaded from the above link
├── README.md // Descriptions about WaveNet
├── train.py // Training scripts
├── train_pytorch.py // Note this script should be downloaded from the above link. The initial name of this script is train.py in the project from the link
├── src
│ ├──__init__.py
│ ├──dataset.py // generate dataloader and data processing entry
│ ├──callback.py // callbacks to monitor the training
│ ├──lr_generator.py // learning rate generator
│ └──loss.py // loss function definition
│ ├──dataset.py // Generate dataloader and data processing entry
│ ├──callback.py // Callbacks to monitor the training
│ ├──lr_generator.py // Learning rate generator
│ └──loss.py // Loss function definition
└── wavenet_vocoder
├──__init__.py
├──conv.py // extended 1D convolution
├──mixture.py // loss function for training and sample function for testing
├──modules.py // modules for Wavenet construction
├──upsample.py // upsample layer definition
├──util.py // utils. Note this script should be download from the above link
├──conv.py // Extended 1D convolution
├──mixture.py // Loss function for training and sample function for testing
├──modules.py // Modules for Wavenet construction
├──upsample.py // Upsample layer definition
├──util.py // Utils. Note this script should be downloaded from the above link
├──wavenet.py // WaveNet networks
└──tfcompat // Note this script should be download from the above link
└──tfcompat // Note this script should be downloaded from the above link
├──__init__.py
└──hparam.py // param management tools
└──hparam.py // Param management tools
```
## [Script Parameters](#contents)
@ -105,13 +106,15 @@ Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
```text
usage: train.py [--data_path DATA_PATH] [--preset PRESET]
[--checkpoint_dir CHECKPOINT_DIR] [--checkpoint CHECKPOINT]
[--speaker_id SPEAKER_ID] [--is_distributed IS_DISTRIBUTED]
[--speaker_id SPEAKER_ID] [--platform PLATFORM]
[--is_distributed IS_DISTRIBUTED]
options:
--data_path dataset path
--preset path of preset parameters (json)
--checkpoint_dir directory of saving model checkpoints
--checkpoint pre-trained ckpt path, default is "./checkpoints"
--speaker_id specific speaker of data in case for multi-speaker datasets, not used currently
--platform specify platform to be used, defeault is "GPU"
--is_distributed whether distributed training or not
```
@ -120,8 +123,9 @@ options:
```text
usage: evaluate.py [--data_path DATA_PATH] [--preset PRESET]
[--pretrain_ckpt PRETRAIN_CKPT] [--output_path OUTPUT_PATH]
[--speaker_id SPEAKER_ID]
[--pretrain_ckpt PRETRAIN_CKPT] [--is_numpy]
[--output_path OUTPUT_PATH] [--speaker_id SPEAKER_ID]
[--platform PLATFORM]
options:
--data_path dataset path
--preset path of preset parameters (json)
@ -129,6 +133,7 @@ options:
--is_numpy whether using numpy for inference or not
--output_path path to save synthesized audio
--speaker_id specific speaker of data in case for multi-speaker datasets, not used currently
--platform specify platform to be used, defeault is "GPU"
```
More parameters for training and evaluation can be set in file `hparams.py`.
@ -194,18 +199,19 @@ After the processing, the directory of gaussian will be as follows:
└──eval
```
The train_no_dev folder contains the final training data. For mulaw256 and mol, the process is the same. When the training data is prepared,
The train_no_dev folder contains the final training data. For mol and gaussian, the process is the same. When the training data is prepared,
you can run the following command to train the network:
```bash
# standalone training
python train.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt
Standalone training
GPU:
python train.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt
distributed training
CPU:
python train.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt --platform=CPU
Distributed training (on GPU only)
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' mpirun --allow-run-as-root -n 8 python train.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt --is_distributed=True
eval
python evaluate.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --output_path=path_to_save_audio
```
## [Evaluation Process](#contents)
@ -214,21 +220,29 @@ WaveNet has a process of auto-regression and this process currently cannot be ru
this [link](https://bbs.huaweicloud.com/forum/thread-94852-1-1.html)
```bash
eval
python evaluate.py --data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --is_numpy --output_path=path_to_save_audio
Evaluation
GPU:
python evaluate.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --is_numpy --output_path=path_to_save_audio
CPU:
python evaluate.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --is_numpy --output_path=path_to_save_audio --platform=CPU
```
## [Convert Process](#contents)
```bash
python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt
GPU:
python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_dump_hparams --pretrain_ckpt=path_to_load_ckpt
CPU:
python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_dump_hparams --pretrain_ckpt=path_to_load_ckpt --platform=CPU
```
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
### Training Performance on GPU
| Parameters | WaveNet |
| -------------------------- | ---------------------------------------------------------------|

View File

@ -36,10 +36,12 @@ parser.add_argument('--data_path', type=str, required=True, default='',
help='Directory contains preprocessed features.')
parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).')
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
parser.add_argument('--is_numpy', action="store_false", default=True, help='Using numpy for inference or not')
parser.add_argument('--is_numpy', action="store_true", default=False, help='Using numpy for inference or not')
parser.add_argument('--output_path', type=str, default='./out_wave/', help='Path to save generated audios')
parser.add_argument('--speaker_id', type=str, default='',
help=' Use specific speaker of data in case for multi-speaker datasets.')
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
help='run platform, support GPU and CPU. Default: GPU')
args = parser.parse_args()
@ -183,7 +185,7 @@ def save_ref_audio(hparam, ref, length, target_wav_path_):
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
if args.preset is not None:
with open(args.preset) as f:

View File

@ -27,14 +27,18 @@ from src.loss import PredictNet
parser = argparse.ArgumentParser(description='TTS training')
parser.add_argument('--preset', type=str, default='', help='Path of preset parameters (json).')
parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_test',
help='Directory where to save model checkpoints [default: checkpoints].')
parser.add_argument('--speaker_id', type=str, default='',
help=' Use specific speaker of data in case for multi-speaker datasets.')
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
help='run platform, support GPU and CPU. Default: GPU')
args = parser.parse_args()
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
if args.preset is not None:
@ -82,13 +86,14 @@ if __name__ == '__main__':
Net = PredictNet(model)
Net.set_train(False)
receptive_field = model.receptive_field
print("Receptive field (samples / ms): {} / {}".format(receptive_field, receptive_field / fs * 1000))
param_dict = load_checkpoint(args.pretrain_ckpt)
load_param_into_net(model, param_dict)
print('Successfully loading the pre-trained model')
x = np.array(np.random.random((2, 256, 10240)), dtype=np.float32)
if is_mulaw_quantize(hparams.input_type):
x = np.array(np.random.random((2, 256, 10240)), dtype=np.float32)
else:
x = np.array(np.random.random((2, 1, 10240)), dtype=np.float32)
c = np.array(np.random.random((2, 80, 44)), dtype=np.float32)
g = np.array([0, 0], dtype=np.int64)

View File

@ -20,6 +20,8 @@ import matplotlib.pyplot as plt
from mindspore import nn, Tensor
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore import context
from nnmnkwii import preprocessing as P1
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw
@ -204,6 +206,7 @@ class NetWithLossClass(nn.Cell):
Returns:
Tensor, loss tensor.
"""
def __init__(self, network, hparams):
super(NetWithLossClass, self).__init__(auto_prefix=False)
self.network = network
@ -213,6 +216,7 @@ class NetWithLossClass(nn.Cell):
self.transpose_op = P.Transpose()
self.reshape_op = P.Reshape()
self.is_mulaw_quant = is_mulaw_quantize(hparams.input_type)
self.cast = P.Cast()
if self.is_mulaw_quant:
self.criterion = MaskedCrossEntropyLoss()
@ -225,13 +229,33 @@ class NetWithLossClass(nn.Cell):
self.criterion = None
raise RuntimeError(
"Not supported output distribution type: {}".format(hparams.output_distribution))
self.device_target = context.get_context("device_target")
def construct(self, x, y, c, g, input_lengths, mask):
"""
Args:
x (Tensor): input
y (Tensor): predition
c (Tensor): local_conditioning
g (Tensor): global_conditioning
input_lengths (Tensor): input_lengths
mask (Tensor): Mask
Returns:
Tensor: Loss tensor
"""
y_hat = self.network(x, c, g, False)
if self.is_mulaw_quant:
y_hat = self.transpose_op(y_hat[:, :, :-1], (0, 2, 1))
y_hat = self.reshape_op(y_hat, (-1, y_hat.shape[-1]))
y = self.reshape_op(y[:, 1:, 0], (-1,))
if self.device_target == "CPU":
y = self.cast(y, mstype.float32)
y = self.reshape_op(y[:, 1:, 0], (-1,))
y = self.cast(y, mstype.int32)
else:
y = self.reshape_op(y[:, 1:, 0], (-1,))
loss = self.criterion(y_hat, y)
else:
loss = self.criterion(y_hat[:, :, :-1], y[:, 1:, :], mask[:, 1:, :])

View File

@ -44,6 +44,8 @@ parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_test',
parser.add_argument('--checkpoint', type=str, default='', help='Restore model from checkpoint path if given.')
parser.add_argument('--speaker_id', type=str, default='',
help=' Use specific speaker of data in case for multi-speaker datasets.')
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
help='run platform, support GPU and CPU. Default: GPU')
parser.add_argument('--is_distributed', action="store_true", default=False, help='Distributed training')
args = parser.parse_args()
@ -57,7 +59,7 @@ if __name__ == '__main__':
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
rank_id = 0
group_size = 1
@ -132,4 +134,4 @@ if __name__ == '__main__':
config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch, keep_checkpoint_max=10)
ckpt_cb = ModelCheckpoint(prefix='wavenet', directory=ckpt_path, config=config_ck)
callback_list.append(ckpt_cb)
model.train(hparams.nepochs, data_loaders, callbacks=callback_list)
model.train(hparams.nepochs, data_loaders, callbacks=callback_list, dataset_sink_mode=False)

View File

@ -18,6 +18,7 @@ import math
from mindspore import nn, Tensor
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore import context
import numpy as np
class Conv1d(nn.Conv1d):
@ -84,7 +85,12 @@ class Conv1d(nn.Conv1d):
self.input_buffer = self.concat_op((self.input_buffer[:, 1:, :], inputs[:, 0:1, :]))
inputs = self.input_buffer
if dilation > 1:
inputs = inputs[:, 0::dilation, :]
if context.get_context("device_target") == "CPU":
inputs = self.transpose_op(inputs, (1, 0, 2))
inputs = inputs[0::dilation, :, :]
inputs = self.transpose_op(inputs, (1, 0, 2))
else:
inputs = inputs[:, 0::dilation, :]
output = self.matmul(self.reshape_op(inputs, (bsz, -1)), self.get_weight)
if self.bias is not None:

View File

@ -20,6 +20,7 @@ import mindspore as ms
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops as P
from mindspore import context
class log_sum_exp(nn.Cell):
@ -41,6 +42,55 @@ class log_sum_exp(nn.Cell):
return m + self.log(self.sums(self.exp(x - m2), axis))
class log_softmax(nn.Cell):
"""
replacement of P.LogSoftmax(-1) in CPU mode
only support x.shape == 2 or 3
"""
def __init__(self):
super(log_softmax, self).__init__()
self.maxi = P.ReduceMax()
self.log = P.Log()
self.sums = P.ReduceSum()
self.exp = P.Exp()
self.axis = -1
self.concat = P.Concat(-1)
self.expanddims = P.ExpandDims()
def construct(self, x):
"""
Args:
x (Tensor): input
Returns:
Tensor: log_softmax of input
"""
c = self.maxi(x, self.axis)
logs, lsm = None, None
if len(x.shape) == 2:
for j in range(x.shape[-1]):
temp = self.expanddims(self.exp(x[:, j] - c), -1)
logs = temp if j == 0 else self.concat((logs, temp))
sums = self.sums(logs, -1)
for i in range(x.shape[-1]):
temp = self.expanddims(x[:, i] - c - self.log(sums), -1)
lsm = temp if i == 0 else self.concat((lsm, temp))
return lsm
if len(x.shape) == 3:
for j in range(x.shape[-1]):
temp = self.expanddims(self.exp(x[:, :, j] - c), -1)
logs = temp if j == 0 else self.concat((logs, temp))
sums = self.sums(logs, -1)
for i in range(x.shape[-1]):
temp = self.expanddims(x[:, :, i] - c - self.log(sums), -1)
lsm = temp if i == 0 else self.concat((lsm, temp))
return lsm
return None
class Stable_softplus(nn.Cell):
"""Numerically stable softplus
"""
@ -77,7 +127,6 @@ class discretized_mix_logistic_loss(nn.Cell):
self.softplus = Stable_softplus()
self.log = P.Log()
self.cast = P.Cast()
self.logsoftmax = P.LogSoftmax(-1)
self.expand_dims = P.ExpandDims()
self.tile = P.Tile()
self.maximum = P.Maximum()
@ -85,6 +134,12 @@ class discretized_mix_logistic_loss(nn.Cell):
self.lse = log_sum_exp()
self.reshape = P.Reshape()
self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32))
self.tensor_one = Tensor(1., ms.float32)
if context.get_context("device_target") == "CPU":
self.logsoftmax = log_softmax()
else:
self.logsoftmax = P.LogSoftmax(-1)
def construct(self, y_hat, y):
"""
@ -105,7 +160,8 @@ class discretized_mix_logistic_loss(nn.Cell):
# (B, T, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix:2 * nr_mix]
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], self.log_scale_min)
min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix))
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut)
# B x T x 1 -> B x T x num_mixtures
y = self.tile(y, (1, 1, nr_mix))
@ -127,8 +183,9 @@ class discretized_mix_logistic_loss(nn.Cell):
log_pdf_mid = mid_in - log_scales - 2. * self.softplus(mid_in)
inner_inner_cond = self.cast(cdf_delta > 1e-5, ms.float32)
min_cut2 = 1e-12 * self.tile(self.tensor_one, cdf_delta.shape)
inner_inner_out = inner_inner_cond * \
self.log(self.maximum(cdf_delta, 1e-12)) + \
self.log(self.maximum(cdf_delta, min_cut2)) + \
(1. - inner_inner_cond) * (log_pdf_mid - self.factor)
inner_cond = self.cast(y > 0.999, ms.float32)
inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
@ -192,15 +249,19 @@ class mix_gaussian_loss(nn.Cell):
self.maximum = P.Maximum()
self.tile = P.Tile()
self.exp = P.Exp()
self.logsoftmax = P.LogSoftmax(-1)
self.expand_dims = P.ExpandDims()
self.sums = P.ReduceSum()
self.lse = log_sum_exp()
self.sq = P.Square()
self.sqrt = P.Sqrt()
self.const = P.ScalarToArray()
self.log = P.Log()
self.tensor_one = Tensor(1., ms.float32)
if context.get_context("device_target") == "CPU":
self.logsoftmax = log_softmax()
else:
self.logsoftmax = P.LogSoftmax(-1)
def construct(self, y_hat, y):
"""
@ -225,12 +286,14 @@ class mix_gaussian_loss(nn.Cell):
if C == 2:
logit_probs = None
means = y_hat[:, :, 0:1]
log_scales = self.maximum(y_hat[:, :, 1:2], self.log_scale_min)
min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], 1))
log_scales = self.maximum(y_hat[:, :, 1:2], min_cut)
else:
# (B, T, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix:2 * nr_mix]
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], self.log_scale_min)
min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix))
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut)
# B x T x 1 -> B x T x num_mixtures
y = self.tile(y, (1, 1, nr_mix))