!4845 Fix bugs about uncertainty toolbox and vae

Merge pull request !4845 from zhangxinfeng3/master
This commit is contained in:
mindspore-ci-bot 2020-08-21 09:06:41 +08:00 committed by Gitee
commit 29e21479a4
4 changed files with 26 additions and 21 deletions

View File

@ -93,18 +93,21 @@ class ConditionalVAE(Cell):
recon_x = self._decode(z_c)
return recon_x, x, mu, std
def generate_sample(self, sample_y, generate_nums=None, shape=None):
def generate_sample(self, sample_y, generate_nums, shape):
"""
Randomly sample from latent space to generate sample.
Args:
sample_y (Tensor): Define the label of sample, int tensor.
generate_nums (int): The number of samples to generate.
shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)`.
shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)` or math:`(-1, C, H, W)`.
Returns:
Tensor, the generated sample.
"""
generate_nums = check_int_positive(generate_nums)
if not isinstance(shape, tuple) or len(shape) != 4 or shape[0] != generate_nums or shape[0] != -1:
raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
sample_y = self.one_hot(sample_y)
sample_c = self.concat((sample_z, sample_y))

View File

@ -88,11 +88,14 @@ class VAE(Cell):
Args:
generate_nums (int): The number of samples to generate.
shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)`.
shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)` or math:`(-1, C, H, W)`.
Returns:
Tensor, the generated sample.
"""
generate_nums = check_int_positive(generate_nums)
if not isinstance(shape, tuple) or len(shape) != 4 or shape[0] != generate_nums or shape[0] != -1:
raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
sample = self._decode(sample_z)
sample = self.reshape(sample, shape)

View File

@ -13,18 +13,20 @@
# limitations under the License.
# ============================================================================
"""Toolbox for Uncertainty Evaluation."""
import numpy as np
from copy import deepcopy
import numpy as np
from mindspore._checkparam import check_int_positive, check_bool
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from ...cell import Cell
from ...layer.basic import Dense, Flatten, Dropout
from ...layer.conv import Conv2d
from ...layer.container import SequentialCell
from ...layer.conv import Conv2d
from ...loss import SoftmaxCrossEntropyWithLogits, MSELoss
from ...metrics import Accuracy, MSE
from ...optim import Adam
@ -36,8 +38,7 @@ class UncertaintyEvaluation:
Args:
model (Cell): The model for uncertainty evaluation.
epi_train_dataset (Dataset): A dataset iterator to train model for obtain epistemic uncertainty.
ale_train_dataset (Dataset): A dataset iterator to train model for obtain aleatoric uncertainty.
train_dataset (Dataset): A dataset iterator to train model.
task_type (str): Option for the task types of model
- regression: A regression model.
- classification: A classification model.
@ -45,22 +46,20 @@ class UncertaintyEvaluation:
If the task type is classification, it must be set; if not classification, it need not to be set.
Default: None.
epochs (int): Total number of iterations on the data. Default: 1.
epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model.
ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model.
epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. Default: None.
ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. Default: None.
save_model (bool): Save the uncertainty model or not, if True, the epi_uncer_model_path
and ale_uncer_model_path should not be None. If False, give the path of
the uncertainty model, it will load the model to evaluate, if not given
the path, it will not save or load the uncertainty model.
the path, it will not save or load the uncertainty model. Default: False.
Examples:
>>> network = LeNet()
>>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
>>> load_param_into_net(network, param_dict)
>>> epi_ds_train = create_dataset('workspace/mnist/train')
>>> ale_ds_train = create_dataset('workspace/mnist/train')
>>> ds_train = create_dataset('workspace/mnist/train')
>>> evaluation = UncertaintyEvaluation(model=network,
>>> epi_train_dataset=epi_ds_train,
>>> ale_train_dataset=ale_ds_train,
>>> train_dataset=ds_train,
>>> task_type='classification',
>>> num_classes=10,
>>> epochs=1,
@ -71,12 +70,12 @@ class UncertaintyEvaluation:
>>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
"""
def __init__(self, model, epi_train_dataset, ale_train_dataset, task_type, num_classes=None, epochs=1,
def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
self.epi_model = model
self.ale_model = model
self.epi_train_dataset = epi_train_dataset
self.ale_train_dataset = ale_train_dataset
self.epi_train_dataset = train_dataset
self.ale_train_dataset = deepcopy(train_dataset)
self.task_type = task_type
self.epochs = check_int_positive(epochs)
self.epi_uncer_model_path = epi_uncer_model_path
@ -93,6 +92,8 @@ class UncertaintyEvaluation:
raise ValueError('The task should be regression or classification.')
if task_type == 'classification':
self.num_classes = check_int_positive(num_classes)
else:
self.num_classes = num_classes
if save_model:
if epi_uncer_model_path is None or ale_uncer_model_path is None:
raise ValueError("If save_model is True, the epi_uncer_model_path and "

View File

@ -119,12 +119,10 @@ if __name__ == '__main__':
param_dict = load_checkpoint('checkpoint_lenet.ckpt')
load_param_into_net(network, param_dict)
# get train and eval dataset
epi_ds_train = create_dataset('workspace/mnist/train')
ale_ds_train = create_dataset('workspace/mnist/train')
ds_train = create_dataset('workspace/mnist/train')
ds_eval = create_dataset('workspace/mnist/test')
evaluation = UncertaintyEvaluation(model=network,
epi_train_dataset=epi_ds_train,
ale_train_dataset=ale_ds_train,
train_dataset=ds_train,
task_type='classification',
num_classes=10,
epochs=1,