!4669 add toolbox and dpn

Merge pull request !4669 from zhangxinfeng3/master
This commit is contained in:
mindspore-ci-bot 2020-08-19 09:13:29 +08:00 committed by Gitee
commit 2aab14242c
8 changed files with 95 additions and 62 deletions
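In short: the uncertainty toolbox now takes separate training datasets for epistemic and aleatoric uncertainty and checks its arguments more strictly. A minimal sketch of the updated call, mirroring the docstring example further down in this diff (the dataset paths and LeNet are placeholders, and the epistemic eval method name is assumed to parallel the aleatoric one shown in the docstring):

network = LeNet()
epi_ds_train = create_dataset('workspace/mnist/train')
ale_ds_train = create_dataset('workspace/mnist/train')
evaluation = UncertaintyEvaluation(model=network,
                                   epi_train_dataset=epi_ds_train,
                                   ale_train_dataset=ale_ds_train,
                                   task_type='classification',
                                   num_classes=10,
                                   epochs=1)
epistemic = evaluation.eval_epistemic_uncertainty(eval_data)   # assumed name
aleatoric = evaluation.eval_aleatoric_uncertainty(eval_data)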

View File

@@ -52,8 +52,12 @@ class ConditionalVAE(Cell):
super(ConditionalVAE, self).__init__()
self.encoder = encoder
self.decoder = decoder
if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
raise TypeError('The encoder and decoder should be Cell type.')
self.hidden_size = check_int_positive(hidden_size)
self.latent_size = check_int_positive(latent_size)
if hidden_size < latent_size:
raise ValueError('The latent_size should be less than or equal to the hidden_size.')
self.num_classes = check_int_positive(num_classes)
self.normal = C.normal
self.exp = P.Exp()
@@ -78,6 +82,9 @@ class ConditionalVAE(Cell):
return recon_x
def construct(self, x, y):
"""
The inputs are x and y, so the WithLossCell class needs to be rewritten when using the CVAE interface.
"""
mu, log_var = self._encode(x, y)
std = self.exp(0.5 * log_var)
z = self.normal(self.shape(mu), mu, std, seed=0)
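Note: the lines above implement the reparameterization trick. With sigma = exp(0.5 * log_var), a sample from N(mu, sigma^2) is drawn as z = mu + sigma * eps, eps ~ N(0, I); C.normal performs that draw given the output shape, mean, and standard deviation.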

View File

@@ -49,8 +49,12 @@ class VAE(Cell):
super(VAE, self).__init__()
self.encoder = encoder
self.decoder = decoder
if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
raise TypeError('The encoder and decoder should be Cell type.')
self.hidden_size = check_int_positive(hidden_size)
self.latent_size = check_int_positive(latent_size)
if hidden_size < latent_size:
raise ValueError('The latent_size should be less than or equal to the hidden_size.')
self.normal = C.normal
self.exp = P.Exp()
self.reshape = P.Reshape()

View File

@@ -15,7 +15,10 @@
"""Stochastic Variational Inference(SVI)."""
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore._checkparam import check_int_positive
from ....cell import Cell
from ....wrap.cell_wrapper import TrainOneStepCell
from .elbo import ELBO
class SVI:
@@ -35,7 +38,12 @@ class SVI:
def __init__(self, net_with_loss, optimizer):
self.net_with_loss = net_with_loss
self.loss_fn = getattr(net_with_loss, '_loss_fn')
if not isinstance(self.loss_fn, ELBO):
raise TypeError('The loss function for variational inference should be ELBO.')
self.optimizer = optimizer
if not isinstance(optimizer, Cell):
raise TypeError('The optimizer should be Cell type.')
self._loss = 0.0
def run(self, train_dataset, epochs=10):
@@ -49,6 +57,7 @@ class SVI:
Outputs:
Cell, the trained probability network.
"""
epochs = check_int_positive(epochs)
train_net = TrainOneStepCell(self.net_with_loss, self.optimizer)
train_net.set_train()
for _ in range(1, epochs+1):
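For context, the intended call pattern for SVI, assembled from the test files later in this diff (encoder, decoder, and ds_train are defined there; hyperparameters are illustrative):

# Sketch: training a VAE by stochastic variational inference.
vae = VAE(encoder, decoder, hidden_size=400, latent_size=20)
net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)
net_with_loss = nn.WithLossCell(vae, net_loss)
vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
vae = vi.run(train_dataset=ds_train, epochs=5)  # epochs is now validated as a positive int
trained_loss = vi.get_train_loss()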

View File

@@ -15,7 +15,7 @@
"""Toolbox for Uncertainty Evaluation."""
import numpy as np
from mindspore._checkparam import check_int_positive
from mindspore._checkparam import check_int_positive, check_bool
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.train import Model
@@ -36,7 +36,8 @@ class UncertaintyEvaluation:
Args:
model (Cell): The model for uncertainty evaluation.
train_dataset (Dataset): A dataset iterator.
epi_train_dataset (Dataset): A dataset iterator used to train the model for obtaining epistemic uncertainty.
ale_train_dataset (Dataset): A dataset iterator used to train the model for obtaining aleatoric uncertainty.
task_type (str): Option for the task type of the model:
- regression: A regression model.
- classification: A classification model.
@@ -55,9 +56,11 @@ class UncertaintyEvaluation:
>>> network = LeNet()
>>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
>>> load_param_into_net(network, param_dict)
>>> ds_train = create_dataset('workspace/mnist/train')
>>> epi_ds_train = create_dataset('workspace/mnist/train')
>>> ale_ds_train = create_dataset('workspace/mnist/train')
>>> evaluation = UncertaintyEvaluation(model=network,
>>> train_dataset=ds_train,
>>> epi_train_dataset=epi_ds_train,
>>> ale_train_dataset=ale_ds_train,
>>> task_type='classification',
>>> num_classes=10,
>>> epochs=1,
@@ -68,28 +71,30 @@ class UncertaintyEvaluation:
>>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
"""
def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
def __init__(self, model, epi_train_dataset, ale_train_dataset, task_type, num_classes=None, epochs=1,
epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
self.model = model
self.train_dataset = train_dataset
self.epi_model = model
self.ale_model = model
self.epi_train_dataset = epi_train_dataset
self.ale_train_dataset = ale_train_dataset
self.task_type = task_type
self.num_classes = check_int_positive(num_classes)
self.epochs = epochs
self.epochs = check_int_positive(epochs)
self.epi_uncer_model_path = epi_uncer_model_path
self.ale_uncer_model_path = ale_uncer_model_path
self.save_model = save_model
self.save_model = check_bool(save_model)
self.epi_uncer_model = None
self.ale_uncer_model = None
self.concat = P.Concat(axis=0)
self.sum = P.ReduceSum()
self.pow = P.Pow()
if self.task_type not in ('regression', 'classification'):
if not isinstance(model, Cell):
raise TypeError('The model should be Cell type.')
if task_type not in ('regression', 'classification'):
raise ValueError('The task should be regression or classification.')
if self.task_type == 'classification':
if self.num_classes is None:
raise ValueError("Classification task needs to input labels.")
if self.save_model:
if self.epi_uncer_model_path is None or self.ale_uncer_model_path is None:
if task_type == 'classification':
self.num_classes = check_int_positive(num_classes)
if save_model:
if epi_uncer_model_path is None or ale_uncer_model_path is None:
raise ValueError("If save_model is True, the epi_uncer_model_path and "
"ale_uncer_model_path should not be None.")
@@ -102,7 +107,7 @@ class UncertaintyEvaluation:
Get the model which can obtain the epistemic uncertainty.
"""
if self.epi_uncer_model is None:
self.epi_uncer_model = EpistemicUncertaintyModel(self.model)
self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model)
if self.epi_uncer_model.drop_count == 0:
if self.task_type == 'classification':
net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@@ -117,9 +122,9 @@ class UncertaintyEvaluation:
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
directory=self.epi_uncer_model_path,
config=config_ck)
model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
model.train(self.epochs, self.epi_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
elif self.epi_uncer_model_path is None:
model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
model.train(self.epochs, self.epi_train_dataset, callbacks=[LossMonitor()])
else:
uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
load_param_into_net(self.epi_uncer_model, uncer_param_dict)
@@ -148,7 +153,7 @@ class UncertaintyEvaluation:
Get the model which can obtain the aleatoric uncertainty.
"""
if self.ale_uncer_model is None:
self.ale_uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
self.ale_uncer_model = AleatoricUncertaintyModel(self.ale_model, self.num_classes, self.task_type)
net_loss = AleatoricLoss(self.task_type)
net_opt = Adam(self.ale_uncer_model.trainable_params())
if self.task_type == 'classification':
@@ -160,9 +165,9 @@ class UncertaintyEvaluation:
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
directory=self.ale_uncer_model_path,
config=config_ck)
model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
model.train(self.epochs, self.ale_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
elif self.ale_uncer_model_path is None:
model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
model.train(self.epochs, self.ale_train_dataset, callbacks=[LossMonitor()])
else:
uncer_param_dict = load_checkpoint(self.ale_uncer_model_path)
load_param_into_net(self.ale_uncer_model, uncer_param_dict)
@@ -216,31 +221,31 @@ class EpistemicUncertaintyModel(Cell):
<https://arxiv.org/abs/1506.02142>`.
"""
def __init__(self, model):
def __init__(self, epi_model):
super(EpistemicUncertaintyModel, self).__init__()
self.drop_count = 0
self.model = self._make_epistemic(model)
self.epi_model = self._make_epistemic(epi_model)
def construct(self, x):
x = self.model(x)
x = self.epi_model(x)
return x
def _make_epistemic(self, model, dropout_rate=0.5):
def _make_epistemic(self, epi_model, dropout_rate=0.5):
"""
The dropout rate is set to 0.5 by default.
"""
for (name, layer) in model.name_cells().items():
for (name, layer) in epi_model.name_cells().items():
if isinstance(layer, Dropout):
self.drop_count += 1
return model
for (name, layer) in model.name_cells().items():
return epi_model
for (name, layer) in epi_model.name_cells().items():
if isinstance(layer, (Conv2d, Dense)):
uncertainty_layer = layer
uncertainty_name = name
drop = Dropout(keep_prob=dropout_rate)
bnn_drop = SequentialCell([uncertainty_layer, drop])
setattr(model, uncertainty_name, bnn_drop)
return model
setattr(epi_model, uncertainty_name, bnn_drop)
return epi_model
raise ValueError("The model has not Dense Layer or Convolution Layer, "
"it can not evaluate epistemic uncertainty so far.")
@@ -254,40 +259,40 @@ class AleatoricUncertaintyModel(Cell):
<https://arxiv.org/abs/1703.04977>`.
"""
def __init__(self, model, labels, task):
def __init__(self, ale_model, num_classes, task):
super(AleatoricUncertaintyModel, self).__init__()
self.task = task
if task == 'classification':
self.model = model
self.var_layer = Dense(labels, labels)
self.ale_model = ale_model
self.var_layer = Dense(num_classes, num_classes)
else:
self.model, self.var_layer, self.pred_layer = self._make_aleatoric(model)
self.ale_model, self.var_layer, self.pred_layer = self._make_aleatoric(ale_model)
def construct(self, x):
if self.task == 'classification':
pred = self.model(x)
pred = self.ale_model(x)
var = self.var_layer(pred)
else:
x = self.model(x)
x = self.ale_model(x)
pred = self.pred_layer(x)
var = self.var_layer(x)
return pred, var
def _make_aleatoric(self, model):
def _make_aleatoric(self, ale_model):
"""
In order to add the variance into the original loss, add a var layer after the original network.
"""
dense_layer = dense_name = None
for (name, layer) in model.name_cells().items():
for (name, layer) in ale_model.name_cells().items():
if isinstance(layer, Dense):
dense_layer = layer
dense_name = name
if dense_layer is None:
raise ValueError("The model has not Dense Layer, "
"it can not evaluate aleatoric uncertainty so far.")
setattr(model, dense_name, Flatten())
setattr(ale_model, dense_name, Flatten())
var_layer = Dense(dense_layer.in_channels, dense_layer.out_channels)
return model, var_layer, dense_layer
return ale_model, var_layer, dense_layer
class AleatoricLoss(Cell):
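The body of AleatoricLoss sits outside this hunk. For regression, the heteroscedastic loss from the cited paper (arXiv:1703.04977) has the form below; this is a sketch under that assumption, not the toolbox implementation:

import numpy as np

# Kendall & Gal (2017): the network predicts both the mean and s = log(sigma^2);
# the loss is mean(exp(-s) * (y - pred)**2 + s).
def heteroscedastic_regression_loss(pred, log_var, y):
    return np.mean(np.exp(-log_var) * (y - pred) ** 2 + log_var)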

View File

@@ -60,12 +60,10 @@ class Decoder(nn.Cell):
return z
class WithLossCell(nn.Cell):
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
class CVAEWithLossCell(nn.WithLossCell):
"""
Rewrite WithLossCell for CVAE
"""
def construct(self, data, label):
out = self._backbone(data, label)
return self._loss_fn(out, label)
@@ -100,7 +98,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
return mnist_ds
if __name__ == "__main__":
def test_svi_cvae():
# define the encoder and decoder
encoder = Encoder(num_classes=10)
decoder = Decoder()
@@ -113,11 +111,11 @@ if __name__ == "__main__":
# define the training dataset
ds_train = create_dataset(image_path, 128, 1)
# define the modified WithLossCell
net_with_loss = WithLossCell(cvae, net_loss)
net_with_loss = CVAEWithLossCell(cvae, net_loss)
# define the variational inference
vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
# run the vi to return the trained network.
cvae = vi.run(train_dataset=ds_train, epochs=10)
cvae = vi.run(train_dataset=ds_train, epochs=5)
# get the trained loss
trained_loss = vi.get_train_loss()
# test function: generate_sample
@@ -128,3 +126,6 @@ if __name__ == "__main__":
sample_x = Tensor(sample['image'], dtype=mstype.float32)
sample_y = Tensor(sample['label'], dtype=mstype.int32)
reconstructed_sample = cvae.reconstruct_sample(sample_x, sample_y)
print('The loss of the trained network is ', trained_loss)
print('The shape of the generated sample is ', generated_sample.shape)
print('The shape of the reconstructed sample is ', reconstructed_sample.shape)

View File

@@ -88,7 +88,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
return mnist_ds
if __name__ == "__main__":
def test_svi_vae():
# define the encoder and decoder
encoder = Encoder()
decoder = Decoder()
@@ -104,7 +104,7 @@ if __name__ == "__main__":
# define the variational inference
vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
# run the vi to return the trained network.
vae = vi.run(train_dataset=ds_train, epochs=10)
vae = vi.run(train_dataset=ds_train, epochs=5)
# get the trained loss
trained_loss = vi.get_train_loss()
# test function: generate_sample
@@ -113,3 +113,6 @@ if __name__ == "__main__":
for sample in ds_train.create_dict_iterator():
sample_x = Tensor(sample['image'], dtype=mstype.float32)
reconstructed_sample = vae.reconstruct_sample(sample_x)
print('The loss of the trained network is ', trained_loss)
print('The shape of the generated sample is ', generated_sample.shape)
print('The shape of the reconstructed sample is ', reconstructed_sample.shape)

View File

@@ -22,6 +22,7 @@ import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.nn.probability.dpn import VAE
from mindspore.nn.probability.infer import ELBO, SVI
@@ -93,17 +94,18 @@ class VaeGan(nn.Cell):
self.dense = nn.Dense(20, 400)
self.vae = VAE(self.E, self.G, 400, 20)
self.shape = P.Shape()
self.normal = C.normal
self.to_tensor = P.ScalarToArray()
def construct(self, x):
recon_x, x, mu, std, z, prior = self.vae(x)
z_p = prior('sample', self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0))
recon_x, x, mu, std = self.vae(x)
z_p = self.normal(self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
z_p = self.dense(z_p)
x_p = self.G(z_p)
ld_real = self.D(x)
ld_fake = self.D(recon_x)
ld_p = self.D(x_p)
return ld_real, ld_fake, ld_p, recon_x, x, mu, std, z, prior
return ld_real, ld_fake, ld_p, recon_x, x, mu, std
class VaeGanLoss(nn.Cell):
@@ -111,13 +113,13 @@ class VaeGanLoss(nn.Cell):
super(VaeGanLoss, self).__init__()
self.zeros = P.ZerosLike()
self.mse = nn.MSELoss(reduction='sum')
self.elbo = ELBO(latent_prior='Normal', output_dis='Normal')
self.elbo = ELBO(latent_prior='Normal', output_prior='Normal')
def construct(self, data, label):
ld_real, ld_fake, ld_p, recon_x, x, mean, std, z, prior = data
ld_real, ld_fake, ld_p, recon_x, x, mean, std = data
y_real = self.zeros(ld_real) + 1
y_fake = self.zeros(ld_fake)
elbo_data = (recon_x, x, mean, std, z, prior)
elbo_data = (recon_x, x, mean, std)
loss_D = self.mse(ld_real, y_real)
loss_GD = self.mse(ld_p, y_fake)
loss_G = self.mse(ld_fake, y_real)
@@ -154,11 +156,11 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
return mnist_ds
if __name__ == "__main__":
def test_vae_gan():
vae_gan = VaeGan()
net_loss = VaeGanLoss()
optimizer = nn.Adam(params=vae_gan.trainable_params(), learning_rate=0.001)
ds_train = create_dataset(image_path, 128, 1)
net_with_loss = nn.WithLossCell(vae_gan, net_loss)
vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
vae_gan = vi.run(train_dataset=ds_train, epochs=10)
vae_gan = vi.run(train_dataset=ds_train, epochs=5)

View File

@@ -119,10 +119,12 @@ if __name__ == '__main__':
param_dict = load_checkpoint('checkpoint_lenet.ckpt')
load_param_into_net(network, param_dict)
# get train and eval dataset
ds_train = create_dataset('workspace/mnist/train')
epi_ds_train = create_dataset('workspace/mnist/train')
ale_ds_train = create_dataset('workspace/mnist/train')
ds_eval = create_dataset('workspace/mnist/test')
evaluation = UncertaintyEvaluation(model=network,
train_dataset=ds_train,
epi_train_dataset=epi_ds_train,
ale_train_dataset=ale_ds_train,
task_type='classification',
num_classes=10,
epochs=1,