forked from mindspore-Ecosystem/mindspore
!13509 Fix explainer API documents bug to make it executable
From: @lixiaohui33 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
385edf507d
|
@ -83,21 +83,30 @@ class ImageClassificationRunner:
|
|||
>>> from mindspore.explainer.benchmark import Faithfulness
|
||||
>>> from mindspore.nn import Softmax
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10
|
||||
>>> dataset = get_dataset('/path/to/Cifar10_dataset')
|
||||
>>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
|
||||
>>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10
|
||||
>>> param_dict = load_checkpoint("checkpoint.ckpt")
|
||||
>>> net = resnet50(len(labels))
|
||||
>>> activation_fn = Softmax()
|
||||
>>>
|
||||
>>> # The detail of AlexNet is shown in model_zoo.official.cv.alexnet.src.alexnet.py
|
||||
>>> net = AlexNet(10)
|
||||
>>> # Load the checkpoint
|
||||
>>> param_dict = load_checkpoint("/path/to/checkpoint")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>>
|
||||
>>> # Prepare the dataset for explaining and evaluation.
|
||||
>>> # The detail of create_dataset_cifar10 method is shown in model_zoo.official.cv.alexnet.src.dataset.py
|
||||
>>>
|
||||
>>> dataset = create_dataset_cifar10("/path/to/cifar/dataset", 1)
|
||||
>>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
|
||||
>>>
|
||||
>>> activation_fn = Softmax()
|
||||
>>> gbp = GuidedBackprop(net)
|
||||
>>> gradient = Gradient(net)
|
||||
>>> explainers = [gbp, gradient]
|
||||
>>> faithfulness = Faithfulness(len(labels), activation_fn, "NaiveFaithfulness")
|
||||
>>> benchmarkers = [faithfulness]
|
||||
>>>
|
||||
>>> runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn)
|
||||
>>> runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers)
|
||||
>>> runner.register_uncertainty()
|
||||
>>> runner.register_hierarchical_occlusion()
|
||||
>>> runner.run()
|
||||
"""
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -250,9 +250,9 @@ def summation(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tens
|
|||
|
||||
|
||||
def stack(inputs: List[Tensor], axis: int) -> Tensor:
|
||||
"""Packs a list of tensors in specified axis."""
|
||||
pack_op = op.Pack(axis)
|
||||
outputs = pack_op(inputs)
|
||||
"""Stacks a list of tensors in specified axis."""
|
||||
stack_op = op.Stack(axis)
|
||||
outputs = stack_op(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -104,16 +104,14 @@ def retrieve_layer_by_name(model: _Module, layer_name: str):
|
|||
Retrieve the layer in the model by the given layer_name.
|
||||
|
||||
Args:
|
||||
model (_Module): model which contains the target layer
|
||||
layer_name (str): name of target layer
|
||||
model (Cell): Model which contains the target layer.
|
||||
layer_name (str): Name of target layer.
|
||||
|
||||
Return:
|
||||
- target_layer (_Module)
|
||||
|
||||
Raise:
|
||||
ValueError: if module with given layer_name is not found in the model,
|
||||
raise ValueError.
|
||||
Returns:
|
||||
Cell, the target layer.
|
||||
|
||||
Raises:
|
||||
ValueError: If module with given layer_name is not found in the model.
|
||||
"""
|
||||
if not isinstance(layer_name, str):
|
||||
raise TypeError('layer_name should be type of str, but receive {}.'
|
||||
|
@ -146,13 +144,14 @@ def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''):
|
|||
be raised.
|
||||
|
||||
Args:
|
||||
model (_Module): the model to retrieve the target layer
|
||||
target_layer (Union[str, _Module]): target layer to retrieve. Can be
|
||||
either string (layer name) or the Cell object. If '' is provided,
|
||||
the input model will be returned.
|
||||
model (Cell): Model which contains the target layer.
|
||||
target_layer (str, Cell): Name of target layer or the target layer instance.
|
||||
|
||||
Return:
|
||||
target layer (_Module)
|
||||
Returns:
|
||||
Cell, the target layer.
|
||||
|
||||
Raises:
|
||||
ValueError: If module with given layer_name is not found in the model.
|
||||
"""
|
||||
if isinstance(target_layer, str):
|
||||
target_layer = retrieve_layer_by_name(model, target_layer)
|
||||
|
@ -174,9 +173,7 @@ class ForwardProbe:
|
|||
Probe to capture output of specific layer in a given model.
|
||||
|
||||
Args:
|
||||
target_layer (_Module): name of target layer or just provide the
|
||||
target layer.
|
||||
|
||||
target_layer (str, Cell): Name of target layer or the target layer instance.
|
||||
"""
|
||||
|
||||
def __init__(self, target_layer: _Module):
|
||||
|
@ -204,7 +201,7 @@ class ForwardProbe:
|
|||
|
||||
|
||||
def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
|
||||
"""Unify `mindspore.Tensor` and `np.ndarray` to `np.ndarray`. """
|
||||
"""Unify Tensor and numpy.array to numpy.array."""
|
||||
if isinstance(x, ms.Tensor):
|
||||
x = x.asnumpy()
|
||||
|
||||
|
@ -231,7 +228,7 @@ def calc_correlation(x: Union[ms.Tensor, np.ndarray],
|
|||
|
||||
|
||||
def calc_auc(x: _Array) -> _Array:
|
||||
"""Calculate the Aera under Curve."""
|
||||
"""Calculate the Area under Curve."""
|
||||
# take mean for multiple patches if the model is fully convolutional model
|
||||
if len(x.shape) == 4:
|
||||
x = np.mean(np.mean(x, axis=2), axis=3)
|
||||
|
@ -242,18 +239,11 @@ def calc_auc(x: _Array) -> _Array:
|
|||
|
||||
def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
|
||||
"""
|
||||
Generate rank order fo every pixel in an 2D array.
|
||||
Generate rank order for every pixel in an 2D array.
|
||||
|
||||
The rank order start from 0 to (num_pixel-1). If descending is True, the
|
||||
rank order will generate in a descending order, otherwise in ascending
|
||||
order.
|
||||
|
||||
Example:
|
||||
x = np.array([[4., 3., 1.], [5., 9., 1.]])
|
||||
rank_pixels(x, descending=True)
|
||||
>> np.array([[2, 3, 4], [1, 0, 5]])
|
||||
rank_pixels(x, descending=False)
|
||||
>> np.array([[3, 2, 0], [4, 5, 1]])
|
||||
"""
|
||||
if len(inputs.shape) < 2 or len(inputs.shape) > 3:
|
||||
raise ValueError('Only support 2D or 3D inputs currently.')
|
||||
|
@ -275,16 +265,15 @@ def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
|
|||
Resize the intermediate layer _attribution to the same size as inputs.
|
||||
|
||||
Args:
|
||||
inputs (ms.Tensor): the input tensor to be resized
|
||||
size (tupleint]): the targeted size resize to
|
||||
mode (str): the resize mode. Options: 'nearest_neighbor', 'bilinear'
|
||||
inputs (Tensor): The input tensor to be resized.
|
||||
size (tuple[int]): The targeted size resize to.
|
||||
mode (str): The resize mode. Options: 'nearest_neighbor', 'bilinear'.
|
||||
|
||||
Returns:
|
||||
outputs (ms.Tensor): the resized tensor.
|
||||
Tensor, the resized tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: the resize mode is not in ['nearest_neighbor',
|
||||
'bilinear'].
|
||||
ValueError: the resize mode is not in ['nearest_neighbor', 'bilinear'].
|
||||
"""
|
||||
h, w = size
|
||||
if mode == 'nearest_neighbor':
|
||||
|
@ -305,6 +294,6 @@ def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
|
|||
resized_np = np.transpose(array_lst, [0, 3, 1, 2])
|
||||
outputs = ms.Tensor(resized_np, inputs.dtype)
|
||||
else:
|
||||
raise ValueError('Unsupported resize mode {}'.format(mode))
|
||||
raise ValueError('Unsupported resize mode {}.'.format(mode))
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -45,19 +45,19 @@ class ClassSensitivity(LabelAgnosticMetric):
|
|||
numpy.ndarray, 1D array of shape :math:`(N,)`, result of class sensitivity evaluated on `explainer`.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.benchmark import ClassSensitivity
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
|
||||
>>> network = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(network, param_dict)
|
||||
>>>
|
||||
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
||||
>>> net = LeNet5(10, num_channel=3)
|
||||
>>> # prepare your explainer to be evaluated, e.g., Gradient.
|
||||
>>> gradient = Gradient(network)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> gradient = Gradient(net)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
|
||||
>>> class_sensitivity = ClassSensitivity()
|
||||
>>> res = class_sensitivity.evaluate(gradient, input_x)
|
||||
>>> print(res)
|
||||
"""
|
||||
self._check_evaluate_param(explainer, inputs)
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -419,16 +419,20 @@ class Faithfulness(LabelSensitiveMetric):
|
|||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> # init an explainer with a trained network, e.g., resnet50
|
||||
>>> gradient = Gradient(network)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>>
|
||||
>>>
|
||||
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
||||
>>> net = LeNet5(10, num_channel=3)
|
||||
>>> gradient = Gradient(net)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
|
||||
>>> targets = 5
|
||||
>>> # usage 1: input the explainer and the data to be explained,
|
||||
>>> # calculate the faithfulness with the specified metric
|
||||
>>> # faithfulness is a Faithfulness instance
|
||||
>>> res = faithfulness.evaluate(gradient, inputs, targets)
|
||||
>>> # usage 2: input the generated saliency map
|
||||
>>> saliency = gradient(inputs, targets)
|
||||
>>> res = faithfulness.evaluate(gradient, inputs, targets, saliency)
|
||||
>>> print(res)
|
||||
"""
|
||||
|
||||
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -61,7 +61,7 @@ class Localization(LabelSensitiveMetric):
|
|||
|
||||
Examples:
|
||||
>>> from mindspore.explainer.benchmark import Localization
|
||||
>>> num_labels = 100
|
||||
>>> num_labels = 10
|
||||
>>> localization = Localization(num_labels, "PointingGame")
|
||||
"""
|
||||
|
||||
|
@ -113,18 +113,22 @@ class Localization(LabelSensitiveMetric):
|
|||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> # init an explainer with a trained network, e.g., resnet50
|
||||
>>> gradient = Gradient(network)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> masks = np.zeros([1, 1, 224, 224])
|
||||
>>> masks[:, :, 65: 100, 65: 100] = 1
|
||||
>>>
|
||||
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
||||
>>> net = LeNet5(10, num_channel=3)
|
||||
>>> gradient = Gradient(net)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
|
||||
>>> masks = np.zeros([1, 1, 32, 32])
|
||||
>>> masks[:, :, 10: 20, 10: 20] = 1
|
||||
>>> targets = 5
|
||||
>>> # usage 1: input the explainer and the data to be explained,
|
||||
>>> # calculate the faithfulness with the specified metric
|
||||
>>> # localization is a Localization instance
|
||||
>>> res = localization.evaluate(gradient, inputs, targets, mask=masks)
|
||||
>>> print(res)
|
||||
>>> # usage 2: input the generated saliency map
|
||||
>>> saliency = gradient(inputs, targets)
|
||||
>>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks)
|
||||
>>> print(res)
|
||||
"""
|
||||
self._check_evaluate_param_with_mask(explainer, inputs, targets, saliency, mask)
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ class AttributionMetric:
|
|||
if self._explainer is None:
|
||||
self._explainer = explainer
|
||||
elif self._explainer is not explainer:
|
||||
logger.info('Provided explainer is not the same as previously evaluted one. Please reset the evaluated '
|
||||
logger.info('Provided explainer is not the same as previously evaluated one. Please reset the evaluated '
|
||||
'results. Previous explainer: %s, current explainer: %s', self._explainer, explainer)
|
||||
self._explainer = explainer
|
||||
|
||||
|
@ -107,7 +107,7 @@ class LabelAgnosticMetric(AttributionMetric):
|
|||
raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))
|
||||
|
||||
def get_results(self):
|
||||
"""Return the gloabl results."""
|
||||
"""Return the global results."""
|
||||
return self._global_results.copy()
|
||||
|
||||
def reset(self):
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -80,18 +80,16 @@ class Robustness(LabelSensitiveMetric):
|
|||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> from mindspore.explainer.benchmark import Robustness
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
|
||||
>>> network = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(network, param_dict)
|
||||
>>>
|
||||
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
||||
>>> net = LeNet5(10, num_channel=3)
|
||||
>>> # prepare your explainer to be evaluated, e.g., Gradient.
|
||||
>>> gradient = Gradient(network)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> gradient = Gradient(net)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
|
||||
>>> target_label = ms.Tensor([0], ms.int32)
|
||||
>>> # robustness is a Robustness instance
|
||||
>>> res = robustness.evaluate(gradient, input_x, target_label)
|
||||
>>> print(res)
|
||||
"""
|
||||
|
||||
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||
|
|
|
@ -24,15 +24,15 @@ def get_bp_weights(model, inputs, targets=None, weights=None):
|
|||
Compute the gradient of output w.r.t input.
|
||||
|
||||
Args:
|
||||
model (`ms.nn.Cell`): Differentiable black-box model.
|
||||
inputs (`ms.Tensor`): Input to calculate gradient and explanation.
|
||||
model (Cell): Differentiable black-box model.
|
||||
inputs (Tensor): Input to calculate gradient and explanation.
|
||||
targets (int, optional): Target label id specifying which category to compute gradient. Default: None.
|
||||
weights (`ms.Tensor`, optional): Custom weights for computing gradients. The shape of weights should match the
|
||||
model outputs. If None is provided, an one-hot weights with one in targets positions will be used instead.
|
||||
weights (Tensor, optional): Custom weights for computing gradients. The shape of weights should match the model
|
||||
outputs. If None is provided, an one-hot weights with one in targets positions will be used instead.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
saliency map (ms.Tensor): Gradient back-propagated to the input.
|
||||
Tensor, signal to be back-propagated to the input.
|
||||
"""
|
||||
inputs = unify_inputs(inputs)
|
||||
if targets is None and weights is None:
|
||||
|
|
|
@ -61,7 +61,7 @@ class GradCAM(IntermediateLayerAttribution):
|
|||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
layer (str, optional): The layer name to generate the explanation, usually chosen as the last convolutional
|
||||
layer for better practice. If it is '', the explantion will be generated at the input layer.
|
||||
layer for better practice. If it is '', the explanation will be generated at the input layer.
|
||||
Default: ''.
|
||||
|
||||
Inputs:
|
||||
|
@ -76,18 +76,17 @@ class GradCAM(IntermediateLayerAttribution):
|
|||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import GradCAM
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # load a trained network
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>>
|
||||
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
||||
>>> net = LeNet5(10, num_channel=3)
|
||||
>>> # specify a layer name to generate explanation, usually the layer can be set as the last conv layer.
|
||||
>>> layer_name = 'layer4'
|
||||
>>> layer_name = 'conv2'
|
||||
>>> # init GradCAM with a trained network and specify the layer to obtain attribution
|
||||
>>> gradcam = GradCAM(net, layer=layer_name)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gradcam(inputs, label)
|
||||
>>> print(saliency.shape)
|
||||
"""
|
||||
|
||||
def __init__(self, network, layer=""):
|
||||
|
|
|
@ -15,32 +15,14 @@
|
|||
"""Gradient explainer."""
|
||||
from copy import deepcopy
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.train._utils import check_value_type
|
||||
from mindspore.explainer._operators import reshape, sqrt, Tensor
|
||||
from mindspore.explainer._operators import Tensor
|
||||
from mindspore.explainer._utils import abs_max, unify_inputs, unify_targets
|
||||
|
||||
from .. import Attribution
|
||||
from .backprop_utils import get_bp_weights, GradNet
|
||||
|
||||
|
||||
def _get_hook(bntype, cache):
|
||||
"""Provide backward hook function for BatchNorm layer in eval mode."""
|
||||
var, gamma, eps = cache
|
||||
if bntype == "2d":
|
||||
var = reshape(var, (1, -1, 1, 1))
|
||||
gamma = reshape(gamma, (1, -1, 1, 1))
|
||||
elif bntype == "1d":
|
||||
var = reshape(var, (1, -1, 1))
|
||||
gamma = reshape(gamma, (1, -1, 1))
|
||||
|
||||
def reset_gradient(_, grad_input, grad_output):
|
||||
grad_output = grad_input[0] * gamma / sqrt(var + eps)
|
||||
return grad_output
|
||||
|
||||
return reset_gradient
|
||||
|
||||
|
||||
class Gradient(Attribution):
|
||||
r"""
|
||||
Provides Gradient explanation method.
|
||||
|
@ -72,15 +54,14 @@ class Gradient(Attribution):
|
|||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # init Gradient with a trained network
|
||||
>>> net = resnet50(10) # please refer to model_zoo
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>>
|
||||
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
||||
>>> net = LeNet5(10, num_channel=3)
|
||||
>>> gradient = Gradient(net)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gradient(inputs, label)
|
||||
>>> print(saliency.shape)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
|
@ -88,7 +69,6 @@ class Gradient(Attribution):
|
|||
self._backward_model = deepcopy(network)
|
||||
self._backward_model.set_train(False)
|
||||
self._backward_model.set_grad(False)
|
||||
self._hook_bn()
|
||||
self._grad_net = GradNet(self._backward_model)
|
||||
self._aggregation_fn = abs_max
|
||||
|
||||
|
@ -103,22 +83,19 @@ class Gradient(Attribution):
|
|||
saliency = self._aggregation_fn(gradient)
|
||||
return saliency
|
||||
|
||||
def _hook_bn(self):
|
||||
"""Hook BatchNorm layer for `self._backward_model.`"""
|
||||
for _, cell in self._backward_model.cells_and_names():
|
||||
if isinstance(cell, nn.BatchNorm2d):
|
||||
cache = (cell.moving_variance, cell.gamma, cell.eps)
|
||||
cell.register_backward_hook(_get_hook("2d", cache=cache))
|
||||
elif isinstance(cell, nn.BatchNorm1d):
|
||||
cache = (cell.moving_variance, cell.gamma, cell.eps)
|
||||
cell.register_backward_hook(_get_hook("1d", cache=cache))
|
||||
|
||||
@staticmethod
|
||||
def _verify_data(inputs, targets):
|
||||
"""Verify the validity of the parsed inputs."""
|
||||
"""
|
||||
Verify the validity of the parsed inputs.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The inputs to be explained.
|
||||
targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If it is a 1D tensor, its length should be the same as `inputs`.
|
||||
"""
|
||||
check_value_type('inputs', inputs, Tensor)
|
||||
if len(inputs.shape) != 4:
|
||||
raise ValueError('Argument inputs must be 4D Tensor')
|
||||
raise ValueError(f'Argument inputs must be 4D Tensor. But got {len(inputs.shape)}D Tensor.')
|
||||
check_value_type('targets', targets, (Tensor, int))
|
||||
if isinstance(targets, Tensor):
|
||||
if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)):
|
||||
|
|
|
@ -109,16 +109,14 @@ class Deconvolution(ModifiedReLU):
|
|||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Deconvolution
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # init Deconvolution with a trained network.
|
||||
>>> net = resnet50(10) # please refer to model_zoo
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
||||
>>> net = LeNet5(10, num_channel=3)
|
||||
>>> deconvolution = Deconvolution(net)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = deconvolution(inputs, label)
|
||||
>>> print(saliency.shape)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
|
@ -154,17 +152,15 @@ class GuidedBackprop(ModifiedReLU):
|
|||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> from mindspore.explainer.explanation import GuidedBackprop
|
||||
>>> # init GuidedBackprop with a trained network.
|
||||
>>> net = resnet50(10) # please refer to model_zoo
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
||||
>>> net = LeNet5(10, num_channel=3)
|
||||
>>> gbp = GuidedBackprop(net)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> # feed data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gbp(inputs, label)
|
||||
>>> print(saliency.shape)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
|
|
|
@ -130,7 +130,7 @@ class AblationWithSaliency(Ablation):
|
|||
Generate mask for perturbations based on given saliency ranks.
|
||||
|
||||
Args:
|
||||
saliency (np.ndarray): Perturbing masks will be generated based on the given saliency map. The shape of
|
||||
saliency (numpy.array): Perturbing masks will be generated based on the given saliency map. The shape of
|
||||
saliency is expected to be: [batch_size, optional(num_channels), *spatial_size]. If multi-channel
|
||||
saliency is provided, an averaged saliency will be taken to calculate pixel order in spatial dimension.
|
||||
num_channels (optional[int]): Number of channels of the input data. In order to match the shape of inputs,
|
||||
|
@ -139,7 +139,7 @@ class AblationWithSaliency(Ablation):
|
|||
no channel dimension. Default: None.
|
||||
|
||||
Return:
|
||||
mask (np.ndarray): boolen mask for generate perturbations.
|
||||
numpy.array, boolean masks for perturbation generation.
|
||||
"""
|
||||
|
||||
batch_size = saliency.shape[0]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -71,17 +71,15 @@ class Occlusion(PerturbationAttribution):
|
|||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Occlusion
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
|
||||
>>> network = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(network, param_dict)
|
||||
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
||||
>>> net = LeNet5(10, num_channel=3)
|
||||
>>> # initialize Occlusion explainer with the pretrained model and activation function
|
||||
>>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities
|
||||
>>> occlusion = Occlusion(network, activation_fn=activation_fn)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> occlusion = Occlusion(net, activation_fn=activation_fn)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
|
||||
>>> label = ms.Tensor([1], ms.int32)
|
||||
>>> saliency = occlusion(input_x, label)
|
||||
>>> print(saliency.shape)
|
||||
"""
|
||||
|
||||
def __init__(self, network, activation_fn, perturbation_per_eval=32):
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -62,23 +62,22 @@ class RISE(PerturbationAttribution):
|
|||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import RISE
|
||||
>>> from mindspore.nn import Sigmoid
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
|
||||
>>> network = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(network, param_dict)
|
||||
>>>
|
||||
>>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
|
||||
>>> net = LeNet5(10, num_channel=3)
|
||||
>>> # initialize RISE explainer with the pretrained model and activation function
|
||||
>>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities
|
||||
>>> rise = RISE(network, activation_fn=activation_fn)
|
||||
>>> rise = RISE(net, activation_fn=activation_fn)
|
||||
>>> # given an instance of RISE, saliency map can be generate
|
||||
>>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32)
|
||||
>>> inputs = ms.Tensor(np.random.rand(2, 3, 32, 32), ms.float32)
|
||||
>>> # when `targets` is an integer
|
||||
>>> targets = 5
|
||||
>>> saliency = rise(inputs, targets)
|
||||
>>> print(saliency.shape)
|
||||
>>> # `targets` can also be a 2D tensor
|
||||
>>> targets = ms.Tensor([[5], [1]], ms.int32)
|
||||
>>> saliency = rise(inputs, targets)
|
||||
>>> print(saliency.shape)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -88,7 +87,7 @@ class RISE(PerturbationAttribution):
|
|||
super(RISE, self).__init__(network, activation_fn, perturbation_per_eval)
|
||||
|
||||
self._num_masks = 6000 # number of masks to be sampled
|
||||
self._mask_probability = 0.2 # ratio of inputs to be masked
|
||||
self._mask_probability = 0.5 # ratio of inputs to be masked
|
||||
self._down_sample_size = 10 # the original size of binary masks
|
||||
self._resize_mode = 'bilinear' # mode choice to resize the down-sized binary masks to size of the inputs
|
||||
self._perturbation_mode = 'constant' # setting the perturbed pixels to a constant value
|
||||
|
@ -127,7 +126,9 @@ class RISE(PerturbationAttribution):
|
|||
self._num_classes = num_classes
|
||||
|
||||
# Due to the unsupported Op of slice assignment, we use numpy array here
|
||||
attr_np = np.zeros(shape=(batch_size, self._num_classes, height, width))
|
||||
targets = self._unify_targets(inputs, targets)
|
||||
|
||||
attr_np = np.zeros(shape=(batch_size, targets.shape[1], height, width))
|
||||
|
||||
cal_times = math.ceil(self._num_masks / self._perturbation_per_eval)
|
||||
|
||||
|
@ -143,24 +144,21 @@ class RISE(PerturbationAttribution):
|
|||
weights = self._activation_fn(self.network(masked_input))
|
||||
while len(weights.shape) > 2:
|
||||
weights = op.mean(weights, axis=2)
|
||||
weights = op.reshape(weights,
|
||||
(bs, self._num_classes, 1, 1))
|
||||
|
||||
attr_np[idx] += op.summation(weights * masks, axis=0).asnumpy()
|
||||
weights = np.expand_dims(np.expand_dims(weights.asnumpy()[:, targets[idx]], 2), 3)
|
||||
|
||||
attr_np[idx] += np.sum(weights * masks.asnumpy(), axis=0)
|
||||
|
||||
attr_np = attr_np / self._num_masks
|
||||
targets = self._unify_targets(inputs, targets)
|
||||
|
||||
attr_classes = [att_i[target] for att_i, target in zip(attr_np, targets)]
|
||||
|
||||
return op.Tensor(attr_classes, dtype=inputs.dtype)
|
||||
return op.Tensor(attr_np, dtype=inputs.dtype)
|
||||
|
||||
@staticmethod
|
||||
def _verify_data(inputs, targets):
|
||||
"""Verify the validity of the parsed inputs."""
|
||||
check_value_type('inputs', inputs, Tensor)
|
||||
if len(inputs.shape) != 4:
|
||||
raise ValueError('Argument inputs must be 4D Tensor')
|
||||
raise ValueError(f'Argument inputs must be 4D Tensor, but got {len(inputs.shape)}D Tensor.')
|
||||
check_value_type('targets', targets, (Tensor, int, tuple, list))
|
||||
if isinstance(targets, Tensor):
|
||||
if len(targets.shape) > 2:
|
||||
|
@ -168,7 +166,7 @@ class RISE(PerturbationAttribution):
|
|||
'But got {}D.'.format(len(targets.shape)))
|
||||
if targets.shape and len(targets) != len(inputs):
|
||||
raise ValueError(
|
||||
'If `targets` is a 2D, 1D Tensor, it should have the same length as inputs {}. But got {}'.format(
|
||||
'If `targets` is a 2D, 1D Tensor, it should have the same length as inputs {}. But got {}.'.format(
|
||||
len(inputs), len(targets)))
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue