forked from mindspore-Ecosystem/mindspore
update toolbox
This commit is contained in:
parent
c84ef2c88a
commit
ade9c17a31
|
@ -17,5 +17,6 @@ Uncertainty toolbox.
|
|||
"""
|
||||
|
||||
from .uncertainty_evaluation import UncertaintyEvaluation
|
||||
from .anomaly_detection import VAEAnomalyDetection
|
||||
|
||||
__all__ = ['UncertaintyEvaluation']
|
||||
__all__ = ['UncertaintyEvaluation', 'VAEAnomalyDetection']
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Toolbox for anomaly detection by using VAE."""
|
||||
import numpy as np
|
||||
|
||||
from ..dpn import VAE
|
||||
from ..infer import ELBO, SVI
|
||||
from ...optim import Adam
|
||||
from ...wrap.cell_wrapper import WithLossCell
|
||||
|
||||
|
||||
class VAEAnomalyDetection:
|
||||
r"""
|
||||
Toolbox for anomaly detection by using VAE.
|
||||
|
||||
Variational Auto-Encoder(VAE) can be used for Unsupervised Anomaly Detection. The anomaly score is the error
|
||||
between the X and the reconstruction. If the score is high, the X is mostly outlier.
|
||||
|
||||
Args:
|
||||
encoder(Cell): The Deep Neural Network (DNN) model defined as encoder.
|
||||
decoder(Cell): The DNN model defined as decoder.
|
||||
hidden_size(int): The size of encoder's output tensor.
|
||||
latent_size(int): The size of the latent space.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, encoder, decoder, hidden_size=400, latent_size=20):
|
||||
self.vae = VAE(encoder, decoder, hidden_size, latent_size)
|
||||
|
||||
def train(self, train_dataset, epochs=5):
|
||||
"""
|
||||
Train the VAE model.
|
||||
|
||||
Args:
|
||||
train_dataset (Dataset): A dataset iterator to train model.
|
||||
epochs (int): Total number of iterations on the data. Default: 5.
|
||||
|
||||
Returns:
|
||||
Cell, the trained model.
|
||||
"""
|
||||
net_loss = ELBO()
|
||||
optimizer = Adam(params=self.vae.trainable_params(), learning_rate=0.001)
|
||||
net_with_loss = WithLossCell(self.vae, net_loss)
|
||||
vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
|
||||
self.vae = vi.run(train_dataset, epochs)
|
||||
return self.vae
|
||||
|
||||
def predict_outlier_score(self, sample_x):
|
||||
"""
|
||||
Predict the outlier score.
|
||||
|
||||
Args:
|
||||
sample_x (Tensor): The sample to be predicted, the shape is (N, C, H, W).
|
||||
|
||||
Returns:
|
||||
numpy.dtype, the predicted outlier score of the sample.
|
||||
"""
|
||||
reconstructed_sample = self.vae.reconstruct_sample(sample_x)
|
||||
return self._calculate_euclidean_distance(sample_x.asnumpy(), reconstructed_sample.asnumpy())
|
||||
|
||||
def predict_outlier(self, sample_x, threshold=100.0):
|
||||
"""
|
||||
Predict whether the sample is an outlier.
|
||||
|
||||
Args:
|
||||
sample_x (Tensor): The sample to be predicted, the shape is (N, C, H, W).
|
||||
threshold (float): the threshold of the outlier. Default: 100.0.
|
||||
|
||||
Returns:
|
||||
Bool, whether the sample is an outlier.
|
||||
"""
|
||||
score = self.predict_outlier_score(sample_x)
|
||||
return score >= threshold
|
||||
|
||||
def _calculate_euclidean_distance(self, sample_x, reconstructed_sample):
|
||||
"""
|
||||
Calculate the euclidean distance of the sample_x and reconstructed_sample.
|
||||
"""
|
||||
return np.sqrt(np.sum(np.square(sample_x - reconstructed_sample)))
|
|
@ -47,7 +47,6 @@ class UncertaintyEvaluation:
|
|||
Default: None.
|
||||
epochs (int): Total number of iterations on the data. Default: 1.
|
||||
epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. Default: None.
|
||||
If the epi_uncer_model_path is 'Untrain', the epistemic model need not to be trained.
|
||||
ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. Default: None.
|
||||
save_model (bool): Whether to save the uncertainty model or not, if true, the epi_uncer_model_path
|
||||
and ale_uncer_model_path must not be None. If false, the model to evaluate will be loaded from
|
||||
|
@ -82,7 +81,7 @@ class UncertaintyEvaluation:
|
|||
self.epi_model = model
|
||||
self.ale_model = deepcopy(model)
|
||||
self.epi_train_dataset = train_dataset
|
||||
self.ale_train_dataset = deepcopy(train_dataset)
|
||||
self.ale_train_dataset = train_dataset
|
||||
self.task_type = task_type
|
||||
self.epochs = Validator.check_positive_int(epochs)
|
||||
self.epi_uncer_model_path = epi_uncer_model_path
|
||||
|
@ -112,7 +111,7 @@ class UncertaintyEvaluation:
|
|||
"""
|
||||
if self.epi_uncer_model is None:
|
||||
self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model)
|
||||
if self.epi_uncer_model.drop_count == 0 and self.epi_uncer_model_path != 'Untrain':
|
||||
if self.epi_uncer_model.drop_count == 0 and self.epi_train_dataset is not None:
|
||||
if self.task_type == 'classification':
|
||||
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
net_opt = Adam(self.epi_uncer_model.trainable_params())
|
||||
|
@ -156,6 +155,8 @@ class UncertaintyEvaluation:
|
|||
"""
|
||||
Get the model which can obtain the aleatoric uncertainty.
|
||||
"""
|
||||
if self.ale_train_dataset is None:
|
||||
raise ValueError('The train dataset should not be None when evaluating aleatoric uncertainty.')
|
||||
if self.ale_uncer_model is None:
|
||||
self.ale_uncer_model = AleatoricUncertaintyModel(self.ale_model, self.num_classes, self.task_type)
|
||||
net_loss = AleatoricLoss(self.task_type)
|
||||
|
@ -239,17 +240,17 @@ class EpistemicUncertaintyModel(Cell):
|
|||
The dropout rate is set to 0.5 by default.
|
||||
"""
|
||||
for (name, layer) in epi_model.name_cells().items():
|
||||
if isinstance(layer, Dropout):
|
||||
self.drop_count += 1
|
||||
return epi_model
|
||||
for (name, layer) in epi_model.name_cells().items():
|
||||
if isinstance(layer, (Conv2d, Dense)):
|
||||
if isinstance(layer, (Conv2d, Dense, Dropout)):
|
||||
if isinstance(layer, Dropout):
|
||||
self.drop_count += 1
|
||||
return epi_model
|
||||
uncertainty_layer = layer
|
||||
uncertainty_name = name
|
||||
drop = Dropout(keep_prob=dropout_rate)
|
||||
bnn_drop = SequentialCell([uncertainty_layer, drop])
|
||||
setattr(epi_model, uncertainty_name, bnn_drop)
|
||||
return epi_model
|
||||
self._make_epistemic(layer)
|
||||
raise ValueError("The model has not Dense Layer or Convolution Layer, "
|
||||
"it can not evaluate epistemic uncertainty so far.")
|
||||
|
||||
|
|
|
@ -36,21 +36,21 @@ class TransformToBNN:
|
|||
|
||||
Examples:
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
|
||||
>>> self.bn = nn.BatchNorm2d(64)
|
||||
>>> self.relu = nn.ReLU()
|
||||
>>> self.flatten = nn.Flatten()
|
||||
>>> self.fc = nn.Dense(64*224*224, 12) # padding=0
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> x = self.conv(x)
|
||||
>>> x = self.bn(x)
|
||||
>>> x = self.relu(x)
|
||||
>>> x = self.flatten(x)
|
||||
>>> out = self.fc(x)
|
||||
>>> return out
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
|
||||
... self.bn = nn.BatchNorm2d(64)
|
||||
... self.relu = nn.ReLU()
|
||||
... self.flatten = nn.Flatten()
|
||||
... self.fc = nn.Dense(64*224*224, 12) # padding=0
|
||||
...
|
||||
... def construct(self, x):
|
||||
... x = self.conv(x)
|
||||
... x = self.bn(x)
|
||||
... x = self.relu(x)
|
||||
... x = self.flatten(x)
|
||||
... out = self.fc(x)
|
||||
... return out
|
||||
>>>
|
||||
>>> net = Net()
|
||||
>>> criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
|
|
Loading…
Reference in New Issue