forked from OSSInnovation/mindspore
!4669 add toolbox and dpn
Merge pull request !4669 from zhangxinfeng3/master
commit 2aab14242c
@@ -52,8 +52,12 @@ class ConditionalVAE(Cell):
         super(ConditionalVAE, self).__init__()
         self.encoder = encoder
         self.decoder = decoder
+        if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
+            raise TypeError('The encoder and decoder should be Cell type.')
         self.hidden_size = check_int_positive(hidden_size)
         self.latent_size = check_int_positive(latent_size)
+        if hidden_size < latent_size:
+            raise ValueError('The latent_size should be less than or equal to the hidden_size.')
         self.num_classes = check_int_positive(num_classes)
         self.normal = C.normal
         self.exp = P.Exp()
@@ -78,6 +82,9 @@ class ConditionalVAE(Cell):
         return recon_x

     def construct(self, x, y):
+        """
+        The inputs are x and y, so the WithLossCell method needs to be rewritten when using the CVAE interface.
+        """
         mu, log_var = self._encode(x, y)
         std = self.exp(0.5 * log_var)
         z = self.normal(self.shape(mu), mu, std, seed=0)
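These sampling lines are the reparameterization trick: instead of drawing z directly from N(mu, std^2), the cell computes std = exp(0.5 * log_var) and z = mu + std * eps with eps ~ N(0, 1), so gradients can flow through mu and log_var. A minimal NumPy sketch of the same computation (the helper name is illustrative, not MindSpore API):

    import numpy as np

    def reparameterize(mu, log_var, seed=0):
        # std = exp(0.5 * log_var), matching self.exp(0.5 * log_var) above
        std = np.exp(0.5 * log_var)
        eps = np.random.default_rng(seed).standard_normal(mu.shape)
        # what C.normal(shape, mu, std, seed) draws, written explicitly
        return mu + std * eps

    mu = np.zeros((2, 4), dtype=np.float32)
    log_var = np.zeros((2, 4), dtype=np.float32)
    z = reparameterize(mu, log_var)  # shape (2, 4); here z ~ N(0, 1)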
@@ -49,8 +49,12 @@ class VAE(Cell):
         super(VAE, self).__init__()
         self.encoder = encoder
         self.decoder = decoder
+        if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
+            raise TypeError('The encoder and decoder should be Cell type.')
         self.hidden_size = check_int_positive(hidden_size)
         self.latent_size = check_int_positive(latent_size)
+        if hidden_size < latent_size:
+            raise ValueError('The latent_size should be less than or equal to the hidden_size.')
         self.normal = C.normal
         self.exp = P.Exp()
         self.reshape = P.Reshape()
@@ -15,7 +15,10 @@
 """Stochastic Variational Inference (SVI)."""
+import mindspore.common.dtype as mstype
+from mindspore.common.tensor import Tensor
+from mindspore._checkparam import check_int_positive
 from ....cell import Cell
 from ....wrap.cell_wrapper import TrainOneStepCell
 from .elbo import ELBO


 class SVI:
@@ -35,7 +38,12 @@ class SVI:

     def __init__(self, net_with_loss, optimizer):
         self.net_with_loss = net_with_loss
+        self.loss_fn = getattr(net_with_loss, '_loss_fn')
+        if not isinstance(self.loss_fn, ELBO):
+            raise TypeError('The loss function for variational inference should be ELBO.')
         self.optimizer = optimizer
+        if not isinstance(optimizer, Cell):
+            raise TypeError('The optimizer should be Cell type.')
         self._loss = 0.0

     def run(self, train_dataset, epochs=10):
@@ -49,6 +57,7 @@ class SVI:
         Outputs:
             Cell, the trained probability network.
         """
+        epochs = check_int_positive(epochs)
         train_net = TrainOneStepCell(self.net_with_loss, self.optimizer)
         train_net.set_train()
         for _ in range(1, epochs+1):
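As the tests later in this change exercise, SVI only needs a loss cell whose _loss_fn is an ELBO plus a Cell-type optimizer; a usage sketch (assumes net_with_loss, optimizer, and ds_train are built as in those tests):

    from mindspore.nn.probability.infer import ELBO, SVI

    # net_with_loss wraps a probabilistic network with an ELBO loss;
    # optimizer is a Cell-type optimizer such as nn.Adam.
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    trained_net = vi.run(train_dataset=ds_train, epochs=5)  # epochs now validated as a positive int
    trained_loss = vi.get_train_loss()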
@@ -15,7 +15,7 @@
 """Toolbox for Uncertainty Evaluation."""
 import numpy as np

-from mindspore._checkparam import check_int_positive
+from mindspore._checkparam import check_int_positive, check_bool
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.train import Model
@@ -36,7 +36,8 @@ class UncertaintyEvaluation:

     Args:
         model (Cell): The model for uncertainty evaluation.
-        train_dataset (Dataset): A dataset iterator.
+        epi_train_dataset (Dataset): A dataset iterator for training the model to obtain epistemic uncertainty.
+        ale_train_dataset (Dataset): A dataset iterator for training the model to obtain aleatoric uncertainty.
         task_type (str): Option for the task type of the model.
             - regression: A regression model.
             - classification: A classification model.
@@ -55,9 +56,11 @@ class UncertaintyEvaluation:
         >>> network = LeNet()
         >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
         >>> load_param_into_net(network, param_dict)
-        >>> ds_train = create_dataset('workspace/mnist/train')
+        >>> epi_ds_train = create_dataset('workspace/mnist/train')
+        >>> ale_ds_train = create_dataset('workspace/mnist/train')
         >>> evaluation = UncertaintyEvaluation(model=network,
-        >>>                                    train_dataset=ds_train,
+        >>>                                    epi_train_dataset=epi_ds_train,
+        >>>                                    ale_train_dataset=ale_ds_train,
         >>>                                    task_type='classification',
         >>>                                    num_classes=10,
         >>>                                    epochs=1,
@@ -68,28 +71,30 @@ class UncertaintyEvaluation:
         >>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
     """

-    def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
+    def __init__(self, model, epi_train_dataset, ale_train_dataset, task_type, num_classes=None, epochs=1,
                  epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
-        self.model = model
-        self.train_dataset = train_dataset
+        self.epi_model = model
+        self.ale_model = model
+        self.epi_train_dataset = epi_train_dataset
+        self.ale_train_dataset = ale_train_dataset
         self.task_type = task_type
-        self.num_classes = check_int_positive(num_classes)
-        self.epochs = epochs
+        self.epochs = check_int_positive(epochs)
         self.epi_uncer_model_path = epi_uncer_model_path
         self.ale_uncer_model_path = ale_uncer_model_path
-        self.save_model = save_model
+        self.save_model = check_bool(save_model)
         self.epi_uncer_model = None
         self.ale_uncer_model = None
         self.concat = P.Concat(axis=0)
         self.sum = P.ReduceSum()
         self.pow = P.Pow()
-        if self.task_type not in ('regression', 'classification'):
+        if not isinstance(model, Cell):
+            raise TypeError('The model should be Cell type.')
+        if task_type not in ('regression', 'classification'):
             raise ValueError('The task should be regression or classification.')
-        if self.task_type == 'classification':
-            if self.num_classes is None:
-                raise ValueError("Classification task needs to input labels.")
-        if self.save_model:
-            if self.epi_uncer_model_path is None or self.ale_uncer_model_path is None:
+        if task_type == 'classification':
+            self.num_classes = check_int_positive(num_classes)
+        if save_model:
+            if epi_uncer_model_path is None or ale_uncer_model_path is None:
                 raise ValueError("If save_model is True, the epi_uncer_model_path and "
                                  "ale_uncer_model_path should not be None.")
@@ -102,7 +107,7 @@ class UncertaintyEvaluation:
         Get the model which can obtain the epistemic uncertainty.
         """
         if self.epi_uncer_model is None:
-            self.epi_uncer_model = EpistemicUncertaintyModel(self.model)
+            self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model)
         if self.epi_uncer_model.drop_count == 0:
             if self.task_type == 'classification':
                 net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@@ -117,9 +122,9 @@ class UncertaintyEvaluation:
                 ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
                                              directory=self.epi_uncer_model_path,
                                              config=config_ck)
-                model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
+                model.train(self.epochs, self.epi_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
             elif self.epi_uncer_model_path is None:
-                model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
+                model.train(self.epochs, self.epi_train_dataset, callbacks=[LossMonitor()])
             else:
                 uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
                 load_param_into_net(self.epi_uncer_model, uncer_param_dict)
@@ -148,7 +153,7 @@ class UncertaintyEvaluation:
         Get the model which can obtain the aleatoric uncertainty.
         """
         if self.ale_uncer_model is None:
-            self.ale_uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
+            self.ale_uncer_model = AleatoricUncertaintyModel(self.ale_model, self.num_classes, self.task_type)
             net_loss = AleatoricLoss(self.task_type)
             net_opt = Adam(self.ale_uncer_model.trainable_params())
             if self.task_type == 'classification':
@@ -160,9 +165,9 @@ class UncertaintyEvaluation:
                 ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
                                              directory=self.ale_uncer_model_path,
                                              config=config_ck)
-                model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
+                model.train(self.epochs, self.ale_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
             elif self.ale_uncer_model_path is None:
-                model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
+                model.train(self.epochs, self.ale_train_dataset, callbacks=[LossMonitor()])
             else:
                 uncer_param_dict = load_checkpoint(self.ale_uncer_model_path)
                 load_param_into_net(self.ale_uncer_model, uncer_param_dict)
@@ -216,31 +221,31 @@ class EpistemicUncertaintyModel(Cell):
     <https://arxiv.org/abs/1506.02142>`.
     """

-    def __init__(self, model):
+    def __init__(self, epi_model):
         super(EpistemicUncertaintyModel, self).__init__()
         self.drop_count = 0
-        self.model = self._make_epistemic(model)
+        self.epi_model = self._make_epistemic(epi_model)

     def construct(self, x):
-        x = self.model(x)
+        x = self.epi_model(x)
         return x

-    def _make_epistemic(self, model, dropout_rate=0.5):
+    def _make_epistemic(self, epi_model, dropout_rate=0.5):
         """
         The dropout rate is set to 0.5 by default.
         """
-        for (name, layer) in model.name_cells().items():
+        for (name, layer) in epi_model.name_cells().items():
             if isinstance(layer, Dropout):
                 self.drop_count += 1
-                return model
-        for (name, layer) in model.name_cells().items():
+                return epi_model
+        for (name, layer) in epi_model.name_cells().items():
             if isinstance(layer, (Conv2d, Dense)):
                 uncertainty_layer = layer
                 uncertainty_name = name
                 drop = Dropout(keep_prob=dropout_rate)
                 bnn_drop = SequentialCell([uncertainty_layer, drop])
-                setattr(model, uncertainty_name, bnn_drop)
-                return model
+                setattr(epi_model, uncertainty_name, bnn_drop)
+                return epi_model
         raise ValueError("The model has no Dense layer or Convolution layer, "
                          "so it cannot evaluate epistemic uncertainty.")
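_make_epistemic keeps any Dropout the network already has (counted in drop_count); otherwise it wraps the first Conv2d or Dense it finds in a SequentialCell followed by Dropout. Keeping dropout stochastic at evaluation time turns repeated forward passes into Monte Carlo samples whose spread estimates epistemic uncertainty, per the linked paper. A NumPy sketch of that sampling step (illustrative only, not the toolbox source):

    import numpy as np

    def mc_dropout_uncertainty(forward, x, n_samples=20):
        # forward: a model call whose dropout is still stochastic
        preds = np.stack([forward(x) for _ in range(n_samples)])
        return preds.mean(axis=0), preds.var(axis=0)  # predictive mean, epistemic variance

    rng = np.random.default_rng(0)
    w = rng.standard_normal((8, 3))

    def noisy_linear(x, keep_prob=0.5):
        mask = rng.random(x.shape) < keep_prob  # Bernoulli dropout mask
        return ((x * mask) / keep_prob) @ w     # inverted dropout, then a linear layer

    mean, epi_var = mc_dropout_uncertainty(noisy_linear, rng.standard_normal((4, 8)))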
@@ -254,40 +259,40 @@ class AleatoricUncertaintyModel(Cell):
     <https://arxiv.org/abs/1703.04977>`.
     """

-    def __init__(self, model, labels, task):
+    def __init__(self, ale_model, num_classes, task):
         super(AleatoricUncertaintyModel, self).__init__()
         self.task = task
         if task == 'classification':
-            self.model = model
-            self.var_layer = Dense(labels, labels)
+            self.ale_model = ale_model
+            self.var_layer = Dense(num_classes, num_classes)
         else:
-            self.model, self.var_layer, self.pred_layer = self._make_aleatoric(model)
+            self.ale_model, self.var_layer, self.pred_layer = self._make_aleatoric(ale_model)

     def construct(self, x):
         if self.task == 'classification':
-            pred = self.model(x)
+            pred = self.ale_model(x)
             var = self.var_layer(pred)
         else:
-            x = self.model(x)
+            x = self.ale_model(x)
             pred = self.pred_layer(x)
             var = self.var_layer(x)
         return pred, var

-    def _make_aleatoric(self, model):
+    def _make_aleatoric(self, ale_model):
         """
         In order to add the variance into the original loss, add a var layer after the original network.
         """
         dense_layer = dense_name = None
-        for (name, layer) in model.name_cells().items():
+        for (name, layer) in ale_model.name_cells().items():
             if isinstance(layer, Dense):
                 dense_layer = layer
                 dense_name = name
         if dense_layer is None:
             raise ValueError("The model has no Dense layer, "
                              "so it cannot evaluate aleatoric uncertainty.")
-        setattr(model, dense_name, Flatten())
+        setattr(ale_model, dense_name, Flatten())
         var_layer = Dense(dense_layer.in_channels, dense_layer.out_channels)
-        return model, var_layer, dense_layer
+        return ale_model, var_layer, dense_layer


 class AleatoricLoss(Cell):
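For regression, the paper this class cites (arXiv:1703.04977) trains the extra var head with a heteroscedastic loss of the form 0.5 * exp(-s) * (y - pred)**2 + 0.5 * s, where s is the predicted log variance; exp(-s) down-weights noisy targets while 0.5 * s penalizes claiming unlimited noise. A minimal sketch of that standard formulation (not necessarily the exact AleatoricLoss source):

    import numpy as np

    def heteroscedastic_regression_loss(pred, log_var, target):
        # 0.5 * exp(-s) * (y - f(x))**2 + 0.5 * s, with s = log sigma^2
        return np.mean(0.5 * np.exp(-log_var) * (target - pred) ** 2 + 0.5 * log_var)

    pred = np.array([1.0, 2.0])
    log_var = np.array([0.0, 0.0])
    target = np.array([1.5, 2.0])
    loss = heteroscedastic_regression_loss(pred, log_var, target)  # 0.0625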
@@ -60,12 +60,10 @@ class Decoder(nn.Cell):
         return z


-class WithLossCell(nn.Cell):
-    def __init__(self, backbone, loss_fn):
-        super(WithLossCell, self).__init__(auto_prefix=False)
-        self._backbone = backbone
-        self._loss_fn = loss_fn
-
+class CVAEWithLossCell(nn.WithLossCell):
+    """
+    Rewrite WithLossCell for CVAE, whose backbone takes both data and label.
+    """
     def construct(self, data, label):
         out = self._backbone(data, label)
         return self._loss_fn(out, label)
@@ -100,7 +98,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
     return mnist_ds


-if __name__ == "__main__":
+def test_svi_cvae():
     # define the encoder and decoder
     encoder = Encoder(num_classes=10)
     decoder = Decoder()
@@ -113,11 +111,11 @@ if __name__ == "__main__":
     # define the training dataset
     ds_train = create_dataset(image_path, 128, 1)
     # define the modified WithLossCell
-    net_with_loss = WithLossCell(cvae, net_loss)
+    net_with_loss = CVAEWithLossCell(cvae, net_loss)
     # define the variational inference
     vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
     # run the vi to return the trained network.
-    cvae = vi.run(train_dataset=ds_train, epochs=10)
+    cvae = vi.run(train_dataset=ds_train, epochs=5)
     # get the trained loss
     trained_loss = vi.get_train_loss()
     # test function: generate_sample
@@ -128,3 +126,6 @@ if __name__ == "__main__":
         sample_x = Tensor(sample['image'], dtype=mstype.float32)
         sample_y = Tensor(sample['label'], dtype=mstype.int32)
         reconstructed_sample = cvae.reconstruct_sample(sample_x, sample_y)
+    print('The loss of the trained network is ', trained_loss)
+    print('The shape of the generated sample is ', generated_sample.shape)
+    print('The shape of the reconstructed sample is ', reconstructed_sample.shape)
@@ -88,7 +88,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
     return mnist_ds


-if __name__ == "__main__":
+def test_svi_vae():
     # define the encoder and decoder
     encoder = Encoder()
     decoder = Decoder()
@@ -104,7 +104,7 @@ if __name__ == "__main__":
     # define the variational inference
     vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
     # run the vi to return the trained network.
-    vae = vi.run(train_dataset=ds_train, epochs=10)
+    vae = vi.run(train_dataset=ds_train, epochs=5)
     # get the trained loss
     trained_loss = vi.get_train_loss()
     # test function: generate_sample
@@ -113,3 +113,6 @@ if __name__ == "__main__":
     for sample in ds_train.create_dict_iterator():
         sample_x = Tensor(sample['image'], dtype=mstype.float32)
         reconstructed_sample = vae.reconstruct_sample(sample_x)
+    print('The loss of the trained network is ', trained_loss)
+    print('The shape of the generated sample is ', generated_sample.shape)
+    print('The shape of the reconstructed sample is ', reconstructed_sample.shape)
@@ -22,6 +22,7 @@ import mindspore.dataset.transforms.vision.c_transforms as CV
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
 from mindspore.nn.probability.dpn import VAE
 from mindspore.nn.probability.infer import ELBO, SVI

@@ -93,17 +94,18 @@ class VaeGan(nn.Cell):
         self.dense = nn.Dense(20, 400)
         self.vae = VAE(self.E, self.G, 400, 20)
         self.shape = P.Shape()
+        self.normal = C.normal
         self.to_tensor = P.ScalarToArray()

     def construct(self, x):
-        recon_x, x, mu, std, z, prior = self.vae(x)
-        z_p = prior('sample', self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0))
+        recon_x, x, mu, std = self.vae(x)
+        z_p = self.normal(self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
         z_p = self.dense(z_p)
         x_p = self.G(z_p)
         ld_real = self.D(x)
         ld_fake = self.D(recon_x)
         ld_p = self.D(x_p)
-        return ld_real, ld_fake, ld_p, recon_x, x, mu, std, z, prior
+        return ld_real, ld_fake, ld_p, recon_x, x, mu, std


 class VaeGanLoss(nn.Cell):
@@ -111,13 +113,13 @@ class VaeGanLoss(nn.Cell):
         super(VaeGanLoss, self).__init__()
         self.zeros = P.ZerosLike()
         self.mse = nn.MSELoss(reduction='sum')
-        self.elbo = ELBO(latent_prior='Normal', output_dis='Normal')
+        self.elbo = ELBO(latent_prior='Normal', output_prior='Normal')

     def construct(self, data, label):
-        ld_real, ld_fake, ld_p, recon_x, x, mean, std, z, prior = data
+        ld_real, ld_fake, ld_p, recon_x, x, mean, std = data
         y_real = self.zeros(ld_real) + 1
         y_fake = self.zeros(ld_fake)
-        elbo_data = (recon_x, x, mean, std, z, prior)
+        elbo_data = (recon_x, x, mean, std)
         loss_D = self.mse(ld_real, y_real)
         loss_GD = self.mse(ld_p, y_fake)
         loss_G = self.mse(ld_fake, y_real)
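With z and prior dropped from the VAE outputs, the ELBO is fed only (recon_x, x, mean, std), which suffices because the KL term between the Gaussian posterior N(mean, std^2) and the standard normal prior has a closed form. For reference, that closed form (a sketch, not the ELBO source):

    import numpy as np

    def gaussian_kl(mu, std):
        # KL( N(mu, std^2) || N(0, 1) ), summed over latent dimensions
        return 0.5 * np.sum(std ** 2 + mu ** 2 - 1.0 - 2.0 * np.log(std), axis=-1)

    mu = np.array([[0.0, 0.0]])
    std = np.array([[1.0, 1.0]])
    print(gaussian_kl(mu, std))  # [0.] -- posterior equals the prior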
@@ -154,11 +156,11 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
     return mnist_ds


-if __name__ == "__main__":
+def test_vae_gan():
     vae_gan = VaeGan()
     net_loss = VaeGanLoss()
     optimizer = nn.Adam(params=vae_gan.trainable_params(), learning_rate=0.001)
     ds_train = create_dataset(image_path, 128, 1)
     net_with_loss = nn.WithLossCell(vae_gan, net_loss)
     vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
-    vae_gan = vi.run(train_dataset=ds_train, epochs=10)
+    vae_gan = vi.run(train_dataset=ds_train, epochs=5)
@@ -119,10 +119,12 @@ if __name__ == '__main__':
     param_dict = load_checkpoint('checkpoint_lenet.ckpt')
     load_param_into_net(network, param_dict)
     # get train and eval dataset
-    ds_train = create_dataset('workspace/mnist/train')
+    epi_ds_train = create_dataset('workspace/mnist/train')
+    ale_ds_train = create_dataset('workspace/mnist/train')
     ds_eval = create_dataset('workspace/mnist/test')
     evaluation = UncertaintyEvaluation(model=network,
-                                       train_dataset=ds_train,
+                                       epi_train_dataset=epi_ds_train,
+                                       ale_train_dataset=ale_ds_train,
                                        task_type='classification',
                                        num_classes=10,
                                        epochs=1,