forked from mindspore-Ecosystem/mindspore
added Wavenet in CPU mode
This commit is contained in:
parent
b82df95b43
commit
79e93845c9
|
@ -77,6 +77,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
|
|||
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
|
||||
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
|
||||
- [DeepSpeech2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2/README.md)
|
||||
- [Wavenet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/wavenet/README.md)
|
||||
- [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
|
||||
- [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md)
|
||||
- [Molecular_Dynamics](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md)
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
# [WaveNet Description](#contents)
|
||||
|
||||
WaveNet is a deep neural network for generating raw audio waveforms. The model is fully probabilistic and autoregressive, with the predictive distribution for each audio sample conditioned on all previous ones. We support training and evaluation on GPU.
|
||||
WaveNet is a deep neural network for generating raw audio waveforms. The model is fully probabilistic and autoregressive, with the predictive distribution for each audio sample conditioned on all previous ones. We support training and evaluation on both GPU and CPU.
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1609.03499.pdf): ord A, Dieleman S, Zen H, et al. Wavenet: A generative model for raw audio
|
||||
|
||||
|
@ -47,8 +47,8 @@ Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
|
|||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(GPU)
|
||||
- Prepare hardware environment with GPU processor.
|
||||
- Hardware(GPU/CPU)
|
||||
- Prepare hardware environment with GPU/CPU processor.
|
||||
- Framework
|
||||
- [MindSpore](https://cmc-szv.clouddragon.huawei.com/cmcversion/index/search?searchKey=Do-MindSpore%20V100R001C00B622)
|
||||
- For more information, please check the resources below:
|
||||
|
@ -65,37 +65,38 @@ Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
|
|||
.
|
||||
├── audio
|
||||
└──wavenet
|
||||
├──datasets // Note the datasets folder should be download from the above link
|
||||
├──egs // Note the egs folder should be download from the above link
|
||||
├──utils // Note the utils folder should be download from the above link
|
||||
├── audio.py // audio utils. Note this script should be download from a third party
|
||||
├── compute-meanvar-stats.py // Compute mean-variance normalization stats. Note this script should be download from the above link
|
||||
├── evaluate.py // evaluation
|
||||
├── export.py // convert mindspore model to air model
|
||||
├── hparams.py // hyper-parameter configuration. Note this script should be download from the above link
|
||||
├── mksubset.py // Make subset of dataset. Note this script should be download from the above link
|
||||
├── preprocess.py // Preprocess dataset. Note this script should be download from the above link
|
||||
├── preprocess_normalize.py // Perform meanvar normalization to preprocessed features. Note this script should be download from the above link
|
||||
├── README.md // descriptions about WaveNet
|
||||
├── train.py // training scripts
|
||||
├── train_pytorch.py // Note this script should be download from the above link. The initial name of this script is train.py in the project from the link
|
||||
├──datasets // Note the datasets folder should be downloaded from the above link
|
||||
├──egs // Note the egs folder should be downloaded from the above link
|
||||
├──utils // Note the utils folder should be downloaded from the above link
|
||||
├── audio.py // Audio utils. Note this script should be downloaded from a third party
|
||||
├── compute-meanvar-stats.py // Compute mean-variance normalization stats. Note this script should be downloaded from the above link
|
||||
├── evaluate.py // Evaluation
|
||||
├── export.py // Convert mindspore model to air model
|
||||
├── hparams.py // Hyper-parameter configuration. Note this script should be downloaded from the above link
|
||||
├── lrschedule.py // Learning rate scheduler. Note this script should be downloaded from the above link
|
||||
├── mksubset.py // Make subset of dataset. Note this script should be downloaded from the above link
|
||||
├── preprocess.py // Preprocess dataset. Note this script should be downloaded from the above link
|
||||
├── preprocess_normalize.py // Perform meanvar normalization to preprocessed features. Note this script should be downloaded from the above link
|
||||
├── README.md // Descriptions about WaveNet
|
||||
├── train.py // Training scripts
|
||||
├── train_pytorch.py // Note this script should be downloaded from the above link. The initial name of this script is train.py in the project from the link
|
||||
├── src
|
||||
│ ├──__init__.py
|
||||
│ ├──dataset.py // generate dataloader and data processing entry
|
||||
│ ├──callback.py // callbacks to monitor the training
|
||||
│ ├──lr_generator.py // learning rate generator
|
||||
│ └──loss.py // loss function definition
|
||||
│ ├──dataset.py // Generate dataloader and data processing entry
|
||||
│ ├──callback.py // Callbacks to monitor the training
|
||||
│ ├──lr_generator.py // Learning rate generator
|
||||
│ └──loss.py // Loss function definition
|
||||
└── wavenet_vocoder
|
||||
├──__init__.py
|
||||
├──conv.py // extended 1D convolution
|
||||
├──mixture.py // loss function for training and sample function for testing
|
||||
├──modules.py // modules for Wavenet construction
|
||||
├──upsample.py // upsample layer definition
|
||||
├──util.py // utils. Note this script should be download from the above link
|
||||
├──conv.py // Extended 1D convolution
|
||||
├──mixture.py // Loss function for training and sample function for testing
|
||||
├──modules.py // Modules for Wavenet construction
|
||||
├──upsample.py // Upsample layer definition
|
||||
├──util.py // Utils. Note this script should be downloaded from the above link
|
||||
├──wavenet.py // WaveNet networks
|
||||
└──tfcompat // Note this script should be download from the above link
|
||||
└──tfcompat // Note this script should be downloaded from the above link
|
||||
├──__init__.py
|
||||
└──hparam.py // param management tools
|
||||
└──hparam.py // Param management tools
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
@ -105,13 +106,15 @@ Dataset used: [The LJ Speech Dataset](<https://keithito.com/LJ-Speech-Dataset>)
|
|||
```text
|
||||
usage: train.py [--data_path DATA_PATH] [--preset PRESET]
|
||||
[--checkpoint_dir CHECKPOINT_DIR] [--checkpoint CHECKPOINT]
|
||||
[--speaker_id SPEAKER_ID] [--is_distributed IS_DISTRIBUTED]
|
||||
[--speaker_id SPEAKER_ID] [--platform PLATFORM]
|
||||
[--is_distributed IS_DISTRIBUTED]
|
||||
options:
|
||||
--data_path dataset path
|
||||
--preset path of preset parameters (json)
|
||||
--checkpoint_dir directory of saving model checkpoints
|
||||
--checkpoint pre-trained ckpt path, default is "./checkpoints"
|
||||
--speaker_id specific speaker of data in case for multi-speaker datasets, not used currently
|
||||
--platform specify platform to be used, defeault is "GPU"
|
||||
--is_distributed whether distributed training or not
|
||||
|
||||
```
|
||||
|
@ -120,8 +123,9 @@ options:
|
|||
|
||||
```text
|
||||
usage: evaluate.py [--data_path DATA_PATH] [--preset PRESET]
|
||||
[--pretrain_ckpt PRETRAIN_CKPT] [--output_path OUTPUT_PATH]
|
||||
[--speaker_id SPEAKER_ID]
|
||||
[--pretrain_ckpt PRETRAIN_CKPT] [--is_numpy]
|
||||
[--output_path OUTPUT_PATH] [--speaker_id SPEAKER_ID]
|
||||
[--platform PLATFORM]
|
||||
options:
|
||||
--data_path dataset path
|
||||
--preset path of preset parameters (json)
|
||||
|
@ -129,6 +133,7 @@ options:
|
|||
--is_numpy whether using numpy for inference or not
|
||||
--output_path path to save synthesized audio
|
||||
--speaker_id specific speaker of data in case for multi-speaker datasets, not used currently
|
||||
--platform specify platform to be used, defeault is "GPU"
|
||||
```
|
||||
|
||||
More parameters for training and evaluation can be set in file `hparams.py`.
|
||||
|
@ -194,18 +199,19 @@ After the processing, the directory of gaussian will be as follows:
|
|||
└──eval
|
||||
```
|
||||
|
||||
The train_no_dev folder contains the final training data. For mulaw256 and mol, the process is the same. When the training data is prepared,
|
||||
The train_no_dev folder contains the final training data. For mol and gaussian, the process is the same. When the training data is prepared,
|
||||
you can run the following command to train the network:
|
||||
|
||||
```bash
|
||||
# standalone training
|
||||
python train.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt
|
||||
Standalone training
|
||||
GPU:
|
||||
python train.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt
|
||||
|
||||
distributed training
|
||||
CPU:
|
||||
python train.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt --platform=CPU
|
||||
|
||||
Distributed training (on GPU only)
|
||||
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' mpirun --allow-run-as-root -n 8 python train.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_save_ckpt --is_distributed=True
|
||||
|
||||
eval
|
||||
python evaluate.py ----data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/ --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --output_path=path_to_save_audio
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
@ -214,21 +220,29 @@ WaveNet has a process of auto-regression and this process currently cannot be ru
|
|||
this [link](https://bbs.huaweicloud.com/forum/thread-94852-1-1.html)
|
||||
|
||||
```bash
|
||||
eval
|
||||
python evaluate.py --data_path= /path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --is_numpy --output_path=path_to_save_audio
|
||||
Evaluation
|
||||
GPU:
|
||||
python evaluate.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --is_numpy --output_path=path_to_save_audio
|
||||
|
||||
CPU:
|
||||
python evaluate.py --data_path=/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt --is_numpy --output_path=path_to_save_audio --platform=CPU
|
||||
```
|
||||
|
||||
## [Convert Process](#contents)
|
||||
|
||||
```bash
|
||||
python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --pretrain_ckpt=path_to_load_ckpt
|
||||
GPU:
|
||||
python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_dump_hparams --pretrain_ckpt=path_to_load_ckpt
|
||||
|
||||
CPU:
|
||||
python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_dump_hparams --pretrain_ckpt=path_to_load_ckpt --platform=CPU
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Training Performance
|
||||
### Training Performance on GPU
|
||||
|
||||
| Parameters | WaveNet |
|
||||
| -------------------------- | ---------------------------------------------------------------|
|
||||
|
|
|
@ -36,10 +36,12 @@ parser.add_argument('--data_path', type=str, required=True, default='',
|
|||
help='Directory contains preprocessed features.')
|
||||
parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).')
|
||||
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
|
||||
parser.add_argument('--is_numpy', action="store_false", default=True, help='Using numpy for inference or not')
|
||||
parser.add_argument('--is_numpy', action="store_true", default=False, help='Using numpy for inference or not')
|
||||
parser.add_argument('--output_path', type=str, default='./out_wave/', help='Path to save generated audios')
|
||||
parser.add_argument('--speaker_id', type=str, default='',
|
||||
help=' Use specific speaker of data in case for multi-speaker datasets.')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
|
||||
help='run platform, support GPU and CPU. Default: GPU')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
|
@ -183,7 +185,7 @@ def save_ref_audio(hparam, ref, length, target_wav_path_):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
|
||||
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
|
||||
if args.preset is not None:
|
||||
with open(args.preset) as f:
|
||||
|
|
|
@ -27,14 +27,18 @@ from src.loss import PredictNet
|
|||
|
||||
parser = argparse.ArgumentParser(description='TTS training')
|
||||
parser.add_argument('--preset', type=str, default='', help='Path of preset parameters (json).')
|
||||
parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_test',
|
||||
help='Directory where to save model checkpoints [default: checkpoints].')
|
||||
parser.add_argument('--speaker_id', type=str, default='',
|
||||
help=' Use specific speaker of data in case for multi-speaker datasets.')
|
||||
parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
|
||||
help='run platform, support GPU and CPU. Default: GPU')
|
||||
args = parser.parse_args()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
|
||||
|
||||
speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
|
||||
if args.preset is not None:
|
||||
|
@ -82,13 +86,14 @@ if __name__ == '__main__':
|
|||
|
||||
Net = PredictNet(model)
|
||||
Net.set_train(False)
|
||||
receptive_field = model.receptive_field
|
||||
print("Receptive field (samples / ms): {} / {}".format(receptive_field, receptive_field / fs * 1000))
|
||||
param_dict = load_checkpoint(args.pretrain_ckpt)
|
||||
load_param_into_net(model, param_dict)
|
||||
print('Successfully loading the pre-trained model')
|
||||
|
||||
x = np.array(np.random.random((2, 256, 10240)), dtype=np.float32)
|
||||
if is_mulaw_quantize(hparams.input_type):
|
||||
x = np.array(np.random.random((2, 256, 10240)), dtype=np.float32)
|
||||
else:
|
||||
x = np.array(np.random.random((2, 1, 10240)), dtype=np.float32)
|
||||
c = np.array(np.random.random((2, 80, 44)), dtype=np.float32)
|
||||
g = np.array([0, 0], dtype=np.int64)
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@ import matplotlib.pyplot as plt
|
|||
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
from nnmnkwii import preprocessing as P1
|
||||
|
||||
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw
|
||||
|
@ -204,6 +206,7 @@ class NetWithLossClass(nn.Cell):
|
|||
Returns:
|
||||
Tensor, loss tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, network, hparams):
|
||||
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
@ -213,6 +216,7 @@ class NetWithLossClass(nn.Cell):
|
|||
self.transpose_op = P.Transpose()
|
||||
self.reshape_op = P.Reshape()
|
||||
self.is_mulaw_quant = is_mulaw_quantize(hparams.input_type)
|
||||
self.cast = P.Cast()
|
||||
|
||||
if self.is_mulaw_quant:
|
||||
self.criterion = MaskedCrossEntropyLoss()
|
||||
|
@ -225,13 +229,33 @@ class NetWithLossClass(nn.Cell):
|
|||
self.criterion = None
|
||||
raise RuntimeError(
|
||||
"Not supported output distribution type: {}".format(hparams.output_distribution))
|
||||
self.device_target = context.get_context("device_target")
|
||||
|
||||
def construct(self, x, y, c, g, input_lengths, mask):
|
||||
"""
|
||||
|
||||
Args:
|
||||
x (Tensor): input
|
||||
y (Tensor): predition
|
||||
c (Tensor): local_conditioning
|
||||
g (Tensor): global_conditioning
|
||||
input_lengths (Tensor): input_lengths
|
||||
mask (Tensor): Mask
|
||||
|
||||
Returns:
|
||||
Tensor: Loss tensor
|
||||
|
||||
"""
|
||||
y_hat = self.network(x, c, g, False)
|
||||
if self.is_mulaw_quant:
|
||||
y_hat = self.transpose_op(y_hat[:, :, :-1], (0, 2, 1))
|
||||
y_hat = self.reshape_op(y_hat, (-1, y_hat.shape[-1]))
|
||||
y = self.reshape_op(y[:, 1:, 0], (-1,))
|
||||
if self.device_target == "CPU":
|
||||
y = self.cast(y, mstype.float32)
|
||||
y = self.reshape_op(y[:, 1:, 0], (-1,))
|
||||
y = self.cast(y, mstype.int32)
|
||||
else:
|
||||
y = self.reshape_op(y[:, 1:, 0], (-1,))
|
||||
loss = self.criterion(y_hat, y)
|
||||
else:
|
||||
loss = self.criterion(y_hat[:, :, :-1], y[:, 1:, :], mask[:, 1:, :])
|
||||
|
|
|
@ -44,6 +44,8 @@ parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_test',
|
|||
parser.add_argument('--checkpoint', type=str, default='', help='Restore model from checkpoint path if given.')
|
||||
parser.add_argument('--speaker_id', type=str, default='',
|
||||
help=' Use specific speaker of data in case for multi-speaker datasets.')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('GPU', 'CPU'),
|
||||
help='run platform, support GPU and CPU. Default: GPU')
|
||||
parser.add_argument('--is_distributed', action="store_true", default=False, help='Distributed training')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -57,7 +59,7 @@ if __name__ == '__main__':
|
|||
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=False)
|
||||
rank_id = 0
|
||||
group_size = 1
|
||||
|
||||
|
@ -132,4 +134,4 @@ if __name__ == '__main__':
|
|||
config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch, keep_checkpoint_max=10)
|
||||
ckpt_cb = ModelCheckpoint(prefix='wavenet', directory=ckpt_path, config=config_ck)
|
||||
callback_list.append(ckpt_cb)
|
||||
model.train(hparams.nepochs, data_loaders, callbacks=callback_list)
|
||||
model.train(hparams.nepochs, data_loaders, callbacks=callback_list, dataset_sink_mode=False)
|
||||
|
|
|
@ -18,6 +18,7 @@ import math
|
|||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
import numpy as np
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
|
@ -84,7 +85,12 @@ class Conv1d(nn.Conv1d):
|
|||
self.input_buffer = self.concat_op((self.input_buffer[:, 1:, :], inputs[:, 0:1, :]))
|
||||
inputs = self.input_buffer
|
||||
if dilation > 1:
|
||||
inputs = inputs[:, 0::dilation, :]
|
||||
if context.get_context("device_target") == "CPU":
|
||||
inputs = self.transpose_op(inputs, (1, 0, 2))
|
||||
inputs = inputs[0::dilation, :, :]
|
||||
inputs = self.transpose_op(inputs, (1, 0, 2))
|
||||
else:
|
||||
inputs = inputs[:, 0::dilation, :]
|
||||
|
||||
output = self.matmul(self.reshape_op(inputs, (bsz, -1)), self.get_weight)
|
||||
if self.bias is not None:
|
||||
|
|
|
@ -20,6 +20,7 @@ import mindspore as ms
|
|||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as P
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class log_sum_exp(nn.Cell):
|
||||
|
@ -41,6 +42,55 @@ class log_sum_exp(nn.Cell):
|
|||
return m + self.log(self.sums(self.exp(x - m2), axis))
|
||||
|
||||
|
||||
class log_softmax(nn.Cell):
|
||||
"""
|
||||
replacement of P.LogSoftmax(-1) in CPU mode
|
||||
only support x.shape == 2 or 3
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(log_softmax, self).__init__()
|
||||
self.maxi = P.ReduceMax()
|
||||
self.log = P.Log()
|
||||
self.sums = P.ReduceSum()
|
||||
self.exp = P.Exp()
|
||||
self.axis = -1
|
||||
self.concat = P.Concat(-1)
|
||||
self.expanddims = P.ExpandDims()
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
|
||||
Args:
|
||||
x (Tensor): input
|
||||
|
||||
Returns:
|
||||
Tensor: log_softmax of input
|
||||
|
||||
"""
|
||||
c = self.maxi(x, self.axis)
|
||||
logs, lsm = None, None
|
||||
if len(x.shape) == 2:
|
||||
for j in range(x.shape[-1]):
|
||||
temp = self.expanddims(self.exp(x[:, j] - c), -1)
|
||||
logs = temp if j == 0 else self.concat((logs, temp))
|
||||
sums = self.sums(logs, -1)
|
||||
for i in range(x.shape[-1]):
|
||||
temp = self.expanddims(x[:, i] - c - self.log(sums), -1)
|
||||
lsm = temp if i == 0 else self.concat((lsm, temp))
|
||||
return lsm
|
||||
if len(x.shape) == 3:
|
||||
for j in range(x.shape[-1]):
|
||||
temp = self.expanddims(self.exp(x[:, :, j] - c), -1)
|
||||
logs = temp if j == 0 else self.concat((logs, temp))
|
||||
sums = self.sums(logs, -1)
|
||||
for i in range(x.shape[-1]):
|
||||
temp = self.expanddims(x[:, :, i] - c - self.log(sums), -1)
|
||||
lsm = temp if i == 0 else self.concat((lsm, temp))
|
||||
return lsm
|
||||
return None
|
||||
|
||||
|
||||
class Stable_softplus(nn.Cell):
|
||||
"""Numerically stable softplus
|
||||
"""
|
||||
|
@ -77,7 +127,6 @@ class discretized_mix_logistic_loss(nn.Cell):
|
|||
self.softplus = Stable_softplus()
|
||||
self.log = P.Log()
|
||||
self.cast = P.Cast()
|
||||
self.logsoftmax = P.LogSoftmax(-1)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.tile = P.Tile()
|
||||
self.maximum = P.Maximum()
|
||||
|
@ -85,6 +134,12 @@ class discretized_mix_logistic_loss(nn.Cell):
|
|||
self.lse = log_sum_exp()
|
||||
self.reshape = P.Reshape()
|
||||
self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32))
|
||||
self.tensor_one = Tensor(1., ms.float32)
|
||||
|
||||
if context.get_context("device_target") == "CPU":
|
||||
self.logsoftmax = log_softmax()
|
||||
else:
|
||||
self.logsoftmax = P.LogSoftmax(-1)
|
||||
|
||||
def construct(self, y_hat, y):
|
||||
"""
|
||||
|
@ -105,7 +160,8 @@ class discretized_mix_logistic_loss(nn.Cell):
|
|||
# (B, T, num_mixtures) x 3
|
||||
logit_probs = y_hat[:, :, :nr_mix]
|
||||
means = y_hat[:, :, nr_mix:2 * nr_mix]
|
||||
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], self.log_scale_min)
|
||||
min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix))
|
||||
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut)
|
||||
|
||||
# B x T x 1 -> B x T x num_mixtures
|
||||
y = self.tile(y, (1, 1, nr_mix))
|
||||
|
@ -127,8 +183,9 @@ class discretized_mix_logistic_loss(nn.Cell):
|
|||
log_pdf_mid = mid_in - log_scales - 2. * self.softplus(mid_in)
|
||||
|
||||
inner_inner_cond = self.cast(cdf_delta > 1e-5, ms.float32)
|
||||
min_cut2 = 1e-12 * self.tile(self.tensor_one, cdf_delta.shape)
|
||||
inner_inner_out = inner_inner_cond * \
|
||||
self.log(self.maximum(cdf_delta, 1e-12)) + \
|
||||
self.log(self.maximum(cdf_delta, min_cut2)) + \
|
||||
(1. - inner_inner_cond) * (log_pdf_mid - self.factor)
|
||||
inner_cond = self.cast(y > 0.999, ms.float32)
|
||||
inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
|
||||
|
@ -192,15 +249,19 @@ class mix_gaussian_loss(nn.Cell):
|
|||
self.maximum = P.Maximum()
|
||||
self.tile = P.Tile()
|
||||
self.exp = P.Exp()
|
||||
self.logsoftmax = P.LogSoftmax(-1)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.sums = P.ReduceSum()
|
||||
self.lse = log_sum_exp()
|
||||
|
||||
self.sq = P.Square()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.const = P.ScalarToArray()
|
||||
self.log = P.Log()
|
||||
self.tensor_one = Tensor(1., ms.float32)
|
||||
|
||||
if context.get_context("device_target") == "CPU":
|
||||
self.logsoftmax = log_softmax()
|
||||
else:
|
||||
self.logsoftmax = P.LogSoftmax(-1)
|
||||
|
||||
def construct(self, y_hat, y):
|
||||
"""
|
||||
|
@ -225,12 +286,14 @@ class mix_gaussian_loss(nn.Cell):
|
|||
if C == 2:
|
||||
logit_probs = None
|
||||
means = y_hat[:, :, 0:1]
|
||||
log_scales = self.maximum(y_hat[:, :, 1:2], self.log_scale_min)
|
||||
min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], 1))
|
||||
log_scales = self.maximum(y_hat[:, :, 1:2], min_cut)
|
||||
else:
|
||||
# (B, T, num_mixtures) x 3
|
||||
logit_probs = y_hat[:, :, :nr_mix]
|
||||
means = y_hat[:, :, nr_mix:2 * nr_mix]
|
||||
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], self.log_scale_min)
|
||||
min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix))
|
||||
log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut)
|
||||
|
||||
# B x T x 1 -> B x T x num_mixtures
|
||||
y = self.tile(y, (1, 1, nr_mix))
|
||||
|
|
Loading…
Reference in New Issue