forked from mindspore-Ecosystem/mindspore
!4845 Fix bugs about uncertainty toolbox and vae
Merge pull request !4845 from zhangxinfeng3/master
This commit is contained in:
commit 29e21479a4
@@ -93,18 +93,21 @@ class ConditionalVAE(Cell):
         recon_x = self._decode(z_c)
         return recon_x, x, mu, std
 
-    def generate_sample(self, sample_y, generate_nums=None, shape=None):
+    def generate_sample(self, sample_y, generate_nums, shape):
         """
         Randomly sample from latent space to generate sample.
 
         Args:
             sample_y (Tensor): Define the label of sample, int tensor.
             generate_nums (int): The number of samples to generate.
-            shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)`.
+            shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)` or math:`(-1, C, H, W)`.
 
         Returns:
             Tensor, the generated sample.
         """
+        generate_nums = check_int_positive(generate_nums)
+        if not isinstance(shape, tuple) or len(shape) != 4 or shape[0] != generate_nums or shape[0] != -1:
+            raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
         sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
         sample_y = self.one_hot(sample_y)
         sample_c = self.concat((sample_z, sample_y))
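As an aside, the new guard chains the two leading-dimension tests with `or`, so any 4-tuple fails at least one of them; a standalone sketch of the check the error message seems to intend (plain Python, not part of this commit) would be:

def check_sample_shape(shape, generate_nums):
    # Accept (generate_nums, C, H, W) or (-1, C, H, W); reject everything else.
    if not isinstance(shape, tuple) or len(shape) != 4 \
            or (shape[0] != generate_nums and shape[0] != -1):
        raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')

check_sample_shape((64, 1, 32, 32), 64)   # passes
check_sample_shape((-1, 1, 32, 32), 64)   # passes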
@@ -88,11 +88,14 @@ class VAE(Cell):
 
         Args:
             generate_nums (int): The number of samples to generate.
-            shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)`.
+            shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)` or math:`(-1, C, H, W)`.
 
         Returns:
             Tensor, the generated sample.
         """
+        generate_nums = check_int_positive(generate_nums)
+        if not isinstance(shape, tuple) or len(shape) != 4 or shape[0] != generate_nums or shape[0] != -1:
+            raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
         sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
         sample = self._decode(sample_z)
         sample = self.reshape(sample, shape)
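For intuition on why the docstring now also allows `(-1, C, H, W)`: a leading -1 lets the reshape infer the batch dimension from the number of elements, as in this small NumPy sketch (illustrative only, not the MindSpore ops used above):

import numpy as np

flat = np.zeros((64, 1 * 32 * 32))     # stands in for the decoder output of 64 samples
imgs = flat.reshape((-1, 1, 32, 32))   # the -1 is inferred as 64
assert imgs.shape == (64, 1, 32, 32)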
@@ -13,18 +13,20 @@
 # limitations under the License.
 # ============================================================================
 """Toolbox for Uncertainty Evaluation."""
-import numpy as np
+from copy import deepcopy
+
+import numpy as np
 from mindspore._checkparam import check_int_positive, check_bool
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.train import Model
 from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 
 from ...cell import Cell
 from ...layer.basic import Dense, Flatten, Dropout
-from ...layer.conv import Conv2d
 from ...layer.container import SequentialCell
+from ...layer.conv import Conv2d
 from ...loss import SoftmaxCrossEntropyWithLogits, MSELoss
 from ...metrics import Accuracy, MSE
 from ...optim import Adam
@@ -36,8 +38,7 @@ class UncertaintyEvaluation:
 
     Args:
         model (Cell): The model for uncertainty evaluation.
-        epi_train_dataset (Dataset): A dataset iterator to train model for obtain epistemic uncertainty.
-        ale_train_dataset (Dataset): A dataset iterator to train model for obtain aleatoric uncertainty.
+        train_dataset (Dataset): A dataset iterator to train model.
         task_type (str): Option for the task types of model
             - regression: A regression model.
             - classification: A classification model.
@@ -45,22 +46,20 @@ class UncertaintyEvaluation:
             If the task type is classification, it must be set; if not classification, it need not to be set.
             Default: None.
         epochs (int): Total number of iterations on the data. Default: 1.
-        epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model.
-        ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model.
+        epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. Default: None.
+        ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. Default: None.
         save_model (bool): Save the uncertainty model or not, if True, the epi_uncer_model_path
             and ale_uncer_model_path should not be None. If False, give the path of
             the uncertainty model, it will load the model to evaluate, if not given
-            the path, it will not save or load the uncertainty model.
+            the path, it will not save or load the uncertainty model. Default: False.
 
     Examples:
         >>> network = LeNet()
        >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
         >>> load_param_into_net(network, param_dict)
-        >>> epi_ds_train = create_dataset('workspace/mnist/train')
-        >>> ale_ds_train = create_dataset('workspace/mnist/train')
+        >>> ds_train = create_dataset('workspace/mnist/train')
         >>> evaluation = UncertaintyEvaluation(model=network,
-        >>>                                    epi_train_dataset=epi_ds_train,
-        >>>                                    ale_train_dataset=ale_ds_train,
+        >>>                                    train_dataset=ds_train,
         >>>                                    task_type='classification',
         >>>                                    num_classes=10,
         >>>                                    epochs=1,
@@ -71,12 +70,12 @@ class UncertaintyEvaluation:
         >>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
     """
 
-    def __init__(self, model, epi_train_dataset, ale_train_dataset, task_type, num_classes=None, epochs=1,
+    def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
                  epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
         self.epi_model = model
         self.ale_model = model
-        self.epi_train_dataset = epi_train_dataset
-        self.ale_train_dataset = ale_train_dataset
+        self.epi_train_dataset = train_dataset
+        self.ale_train_dataset = deepcopy(train_dataset)
         self.task_type = task_type
         self.epochs = check_int_positive(epochs)
         self.epi_uncer_model_path = epi_uncer_model_path
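The switch to a single train_dataset plus deepcopy gives the epistemic and aleatoric models independent dataset objects, presumably so that training one does not exhaust or mutate the iterator used by the other. A minimal sketch of the idea with a plain Python stand-in for the dataset (not the MindSpore Dataset API):

from copy import deepcopy

class ToyDataset:
    """Stand-in for a stateful dataset object."""
    def __init__(self, samples):
        self.samples = list(samples)

train_dataset = ToyDataset(range(4))
epi_train_dataset = train_dataset             # used to train the epistemic model
ale_train_dataset = deepcopy(train_dataset)   # independent copy for the aleatoric model

assert ale_train_dataset is not train_dataset
assert ale_train_dataset.samples == train_dataset.samples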
@@ -93,6 +92,8 @@ class UncertaintyEvaluation:
             raise ValueError('The task should be regression or classification.')
         if task_type == 'classification':
             self.num_classes = check_int_positive(num_classes)
+        else:
+            self.num_classes = num_classes
         if save_model:
             if epi_uncer_model_path is None or ale_uncer_model_path is None:
                 raise ValueError("If save_model is True, the epi_uncer_model_path and "
@@ -119,12 +119,10 @@ if __name__ == '__main__':
     param_dict = load_checkpoint('checkpoint_lenet.ckpt')
     load_param_into_net(network, param_dict)
     # get train and eval dataset
-    epi_ds_train = create_dataset('workspace/mnist/train')
-    ale_ds_train = create_dataset('workspace/mnist/train')
+    ds_train = create_dataset('workspace/mnist/train')
     ds_eval = create_dataset('workspace/mnist/test')
     evaluation = UncertaintyEvaluation(model=network,
-                                       epi_train_dataset=epi_ds_train,
-                                       ale_train_dataset=ale_ds_train,
+                                       train_dataset=ds_train,
                                        task_type='classification',
                                        num_classes=10,
                                        epochs=1,
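Pulling the pieces together, a rough sketch of the new single-dataset interface as the example script would now use it (LeNet and create_dataset are the helpers from the surrounding script, assumed to be defined; the keyword defaults after epochs are taken from the updated Args list):

network = LeNet()
param_dict = load_checkpoint('checkpoint_lenet.ckpt')
load_param_into_net(network, param_dict)
# One training dataset is now enough; the toolbox deep-copies it internally.
ds_train = create_dataset('workspace/mnist/train')
ds_eval = create_dataset('workspace/mnist/test')
evaluation = UncertaintyEvaluation(model=network,
                                   train_dataset=ds_train,
                                   task_type='classification',
                                   num_classes=10,
                                   epochs=1,
                                   epi_uncer_model_path=None,
                                   ale_uncer_model_path=None,
                                   save_model=False)

Evaluation then proceeds through evaluation.eval_aleatoric_uncertainty(eval_data) as in the class docstring, with eval_data presumably drawn from ds_eval in the part of the script this diff does not show.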