forked from mindspore-Ecosystem/mindspore
refactor explain core code
This commit is contained in:
parent
6ff92a5b2e
commit
2e4b686408
|
@ -64,7 +64,10 @@ class ImageClassificationRunner:
|
|||
should provide [images], [images, labels] or [images, labels, bboxes] as columns. The label list must
|
||||
share the exact same length and order of the network outputs.
|
||||
network (Cell): The network(with logit outputs) to be explained.
|
||||
activation_fn (Cell): The activation function for converting network's output to probabilities.
|
||||
activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
|
||||
single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks,
|
||||
`nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long as
|
||||
when combining this function with network, the final output is the probability of the input.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer import ImageClassificationRunner
|
||||
|
@ -302,6 +305,8 @@ class ImageClassificationRunner:
|
|||
ds.config.set_seed(self._DATASET_SEED)
|
||||
for idx, next_element in enumerate(self._dataset):
|
||||
now = time()
|
||||
self._spaced_print("Start running {}-th explanation data for {}......".format(
|
||||
idx, exp.__class__.__name__), end='')
|
||||
self._run_exp_step(next_element, exp, sample_id_labels, summary)
|
||||
self._spaced_print("Finish writing {}-th explanation data for {}. Time elapsed: "
|
||||
"{:.3f} s".format(idx, exp.__class__.__name__, time() - now), end='')
|
||||
|
@ -320,12 +325,17 @@ class ImageClassificationRunner:
|
|||
ds.config.set_seed(self._DATASET_SEED)
|
||||
for idx, next_element in enumerate(self._dataset):
|
||||
now = time()
|
||||
self._spaced_print("Start running {}-th explanation data for {}......".format(
|
||||
idx, exp.__class__.__name__), end='')
|
||||
saliency_dict_lst = self._run_exp_step(next_element, exp, sample_id_labels, summary)
|
||||
self._spaced_print(
|
||||
"Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
|
||||
idx, exp.__class__.__name__, time() - now), end='')
|
||||
for bench in self._benchmarkers:
|
||||
now = time()
|
||||
self._spaced_print(
|
||||
"Start running {}-th batch {} data for {}......".format(
|
||||
idx, bench.__class__.__name__, exp.__class__.__name__), end='')
|
||||
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
|
||||
self._spaced_print(
|
||||
"Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
|
||||
|
@ -496,7 +506,7 @@ class ImageClassificationRunner:
|
|||
if explainer.__class__ in explainer_classes:
|
||||
raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! "
|
||||
"Please make sure all explainers' class is distinct.")
|
||||
if explainer.model != self._network:
|
||||
if explainer.network is not self._network:
|
||||
raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different "
|
||||
"instance from network of runner. Please make sure they are the same "
|
||||
"instance.")
|
||||
|
@ -717,4 +727,5 @@ class ImageClassificationRunner:
|
|||
@classmethod
|
||||
def _spaced_print(cls, message, *args, **kwargs):
|
||||
"""Spaced message printing."""
|
||||
print(cls._SPACER.format(message), *args, **kwargs)
|
||||
# workaround to print logs starting new line in case line width mismatch.
|
||||
print(cls._SPACER.format(message))
|
||||
|
|
|
@ -226,7 +226,7 @@ def calc_correlation(x: Union[ms.Tensor, np.ndarray],
|
|||
|
||||
if np.all(x == 0) or np.all(y == 0):
|
||||
return np.float(0)
|
||||
faithfulness = -np.corrcoef(x, y)[0, 1]
|
||||
faithfulness = np.corrcoef(x, y)[0, 1]
|
||||
return faithfulness
|
||||
|
||||
|
||||
|
|
|
@ -55,12 +55,12 @@ class ClassSensitivity(LabelAgnosticMetric):
|
|||
>>> # prepare your explainer to be evaluated, e.g., Gradient.
|
||||
>>> gradient = Gradient(network)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> class_sensitivity = ClassSensitivity()
|
||||
>>> # class_sensitivity is a ClassSensitivity instance
|
||||
>>> res = class_sensitivity.evaluate(gradient, input_x)
|
||||
"""
|
||||
self._check_evaluate_param(explainer, inputs)
|
||||
|
||||
outputs = explainer.model(inputs)
|
||||
outputs = explainer.network(inputs)
|
||||
|
||||
max_confidence_label = ops.argmax(outputs)
|
||||
min_confidence_label = ops.argmin(outputs)
|
||||
|
|
|
@ -18,9 +18,9 @@ from typing import Callable, Optional, Union
|
|||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import log
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import log, nn
|
||||
from mindspore.train._utils import check_value_type
|
||||
from .metric import LabelSensitiveMetric
|
||||
from ..._utils import calc_auc, format_tensor_to_ndarray
|
||||
from ...explanation._attribution import Attribution as _Attribution
|
||||
|
@ -358,22 +358,26 @@ class Faithfulness(LabelSensitiveMetric):
|
|||
|
||||
Args:
|
||||
num_labels (int): Number of labels.
|
||||
activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
|
||||
single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks,
|
||||
`nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long as
|
||||
when combining this function with network, the final output is the probability of the input.
|
||||
metric (str, optional): The specific metric to quantify faithfulness.
|
||||
Options: "DeletionAUC", "InsertionAUC", "NaiveFaithfulness".
|
||||
Default: 'NaiveFaithfulness'.
|
||||
activation_fn (Cell, optional): The activation function that transforms the network output to a probability.
|
||||
Default: nn.Softmax().
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import nn
|
||||
>>> from mindspore.explainer.benchmark import Faithfulness
|
||||
>>> # init a `Faithfulness` object
|
||||
>>> num_labels = 10
|
||||
>>> metric = "InsertionAUC"
|
||||
>>> faithfulness = Faithfulness(num_labels, metric)
|
||||
>>> activation_fn = nn.Softmax()
|
||||
>>> faithfulness = Faithfulness(num_labels, activation_fn, metric)
|
||||
"""
|
||||
_methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC]
|
||||
|
||||
def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness", activation_fn=nn.Softmax()):
|
||||
def __init__(self, num_labels, activation_fn, metric="NaiveFaithfulness"):
|
||||
super(Faithfulness, self).__init__(num_labels)
|
||||
|
||||
perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument
|
||||
|
@ -382,7 +386,9 @@ class Faithfulness(LabelSensitiveMetric):
|
|||
num_perturb_steps = 100 # separate the perturbation progress in to 100 steps.
|
||||
base_value = 0.0 # the pixel value set for the perturbed pixels
|
||||
|
||||
check_value_type("activation_fn", activation_fn, nn.Cell)
|
||||
self._activation_fn = activation_fn
|
||||
|
||||
self._verify_metrics(metric)
|
||||
for method in self._methods:
|
||||
if metric == method.__name__:
|
||||
|
@ -437,8 +443,8 @@ class Faithfulness(LabelSensitiveMetric):
|
|||
inputs = format_tensor_to_ndarray(inputs)
|
||||
saliency = format_tensor_to_ndarray(saliency)
|
||||
|
||||
model = nn.SequentialCell([explainer.model, self._activation_fn])
|
||||
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=model,
|
||||
full_network = nn.SequentialCell([explainer.network, self._activation_fn])
|
||||
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=full_network,
|
||||
targets=targets, saliency=saliency)
|
||||
return (1 + faithfulness) / 2
|
||||
|
||||
|
|
|
@ -204,9 +204,9 @@ class LabelSensitiveMetric(AttributionMetric):
|
|||
check_value_type('explainer', explainer, Attribution)
|
||||
self._record_explainer(explainer)
|
||||
verify_argument(inputs, 'inputs')
|
||||
output = explainer.model(inputs)
|
||||
output = explainer.network(inputs)
|
||||
check_value_type("output of explainer model", output, Tensor)
|
||||
output_dim = explainer.model(inputs).shape[1]
|
||||
output_dim = explainer.network(inputs).shape[1]
|
||||
if output_dim != self._num_labels:
|
||||
raise ValueError("The output dimension of of black-box model in explainer does not match the dimension "
|
||||
"of num_labels set in the __init__, please check explainer and num_labels again.")
|
||||
|
|
|
@ -18,6 +18,7 @@ import numpy as np
|
|||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore.train._utils import check_value_type
|
||||
from mindspore import log
|
||||
from .metric import LabelSensitiveMetric
|
||||
from ...explanation._attribution._perturbation.replacement import RandomPerturb
|
||||
|
@ -30,17 +31,24 @@ class Robustness(LabelSensitiveMetric):
|
|||
|
||||
Args:
|
||||
num_labels (int): Number of classes in the dataset.
|
||||
activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
|
||||
single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks,
|
||||
`nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long as
|
||||
when combining this function with network, the final output is the probability of the input.
|
||||
|
||||
|
||||
Examples:
|
||||
>>> # Initialize a Robustness benchmarker passing num_labels of the dataset.
|
||||
>>> from mindspore import nn
|
||||
>>> from mindspore.explainer.benchmark import Robustness
|
||||
>>> num_labels = 100
|
||||
>>> robustness = Robustness(num_labels)
|
||||
>>> # Initialize a Robustness benchmarker passing num_labels of the dataset.
|
||||
>>> num_labels = 10
|
||||
>>> activation_fn = nn.Softmax()
|
||||
>>> robustness = Robustness(num_labels, activation_fn)
|
||||
"""
|
||||
|
||||
def __init__(self, num_labels, activation_fn=nn.Softmax()):
|
||||
def __init__(self, num_labels, activation_fn):
|
||||
super().__init__(num_labels)
|
||||
|
||||
check_value_type("activation_fn", activation_fn, nn.Cell)
|
||||
self._perturb = RandomPerturb()
|
||||
self._num_perturbations = 10 # number of perturbations used in evaluation
|
||||
self._threshold = 0.1 # threshold to generate perturbation
|
||||
|
@ -69,6 +77,8 @@ class Robustness(LabelSensitiveMetric):
|
|||
ValueError: If batch_size is larger than 1.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> from mindspore.explainer.benchmark import Robustness
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
@ -80,7 +90,7 @@ class Robustness(LabelSensitiveMetric):
|
|||
>>> gradient = Gradient(network)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> target_label = ms.Tensor([0], ms.int32)
|
||||
>>> robustness = Robustness(num_labels=10)
|
||||
>>> # robustness is a Robustness instance
|
||||
>>> res = robustness.evaluate(gradient, input_x, target_label)
|
||||
"""
|
||||
|
||||
|
@ -100,13 +110,13 @@ class Robustness(LabelSensitiveMetric):
|
|||
log.warning('Get saliency norm equals 0, robustness return NaN for zero-norm saliency currently.')
|
||||
norm[norm == 0] = np.nan
|
||||
|
||||
model = nn.SequentialCell([explainer.model, self._activation_fn])
|
||||
original_outputs = model(inputs).asnumpy()
|
||||
full_network = nn.SequentialCell([explainer.network, self._activation_fn])
|
||||
original_outputs = full_network(inputs).asnumpy()
|
||||
sensitivities = []
|
||||
for _ in range(self._num_perturbations):
|
||||
perturbations = []
|
||||
for j, sample in enumerate(inputs_np):
|
||||
perturbation_on_single_sample = self._perturb_with_threshold(model,
|
||||
perturbation_on_single_sample = self._perturb_with_threshold(full_network,
|
||||
np.expand_dims(sample, axis=0),
|
||||
original_outputs[j])
|
||||
perturbations.append(perturbation_on_single_sample)
|
||||
|
@ -120,7 +130,7 @@ class Robustness(LabelSensitiveMetric):
|
|||
robustness_res = 1 / np.exp(max_sensitivity)
|
||||
return robustness_res
|
||||
|
||||
def _perturb_with_threshold(self, model: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray:
|
||||
def _perturb_with_threshold(self, network: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Generate the perturbation until the L2-distance between original_output and perturbation_output is lower than
|
||||
the given self._threshold or until the attempt reaches the max_attempt_time.
|
||||
|
@ -130,7 +140,7 @@ class Robustness(LabelSensitiveMetric):
|
|||
perturbation = None
|
||||
for _ in range(max_attempt_time):
|
||||
perturbation = self._perturb(sample)
|
||||
perturbation_output = self._activation_fn(model(ms.Tensor(sample, ms.float32))).asnumpy()
|
||||
perturbation_output = self._activation_fn(network(ms.Tensor(sample, ms.float32))).asnumpy()
|
||||
perturb_error = np.linalg.norm(original_output - perturbation_output)
|
||||
if perturb_error <= self._threshold:
|
||||
return perturbation
|
||||
|
|
|
@ -39,7 +39,7 @@ def compute_gradients(model, inputs, targets=None, weights=None):
|
|||
raise ValueError('Must provide one of targets or weights')
|
||||
if weights is None:
|
||||
targets = unify_targets(targets)
|
||||
output = model(*inputs).asnumpy()
|
||||
output = model(*inputs)
|
||||
num_categories = output.shape[-1]
|
||||
weights = generate_one_hot(targets, num_categories)
|
||||
|
||||
|
|
|
@ -64,16 +64,30 @@ class GradCAM(IntermediateLayerAttribution):
|
|||
layer for better practice. If it is '', the explanation will be generated at the input layer.
|
||||
Default: ''.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
- **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If it is a 1D tensor, its length should be the same as `inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import GradCAM
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> network = resnet50(10) # please refer to model_zoo
|
||||
>>> # load a trained network
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # specify a layer name to generate explanation, usually the layer can be set as the last conv layer.
|
||||
>>> layer_name = 'layer4'
|
||||
>>> # init GradCAM with a trained network and specify the layer to obtain attribution
|
||||
>>> gradcam = GradCAM(net, layer=layer_name)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gradcam(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network, layer=""):
|
||||
|
@ -100,25 +114,7 @@ class GradCAM(IntermediateLayerAttribution):
|
|||
self._intermediate_grad = grad_input
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
"""
|
||||
Call function for `GradCAM`.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If it is a 1D tensor, its length should be the same as `inputs`.
|
||||
|
||||
Returns:
|
||||
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> import numpy as np
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> label = 5
|
||||
>>> # gradcam is a GradCAM object, parse data and the target label to be explained and get the attribution
|
||||
>>> saliency = gradcam(inputs, label)
|
||||
"""
|
||||
"""Call function for `GradCAM`."""
|
||||
self._verify_data(inputs, targets)
|
||||
self._hook_cell()
|
||||
|
||||
|
|
|
@ -59,7 +59,17 @@ class Gradient(Attribution):
|
|||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
- **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If it is a 1D tensor, its length should be the same as `inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # init Gradient with a trained network
|
||||
|
@ -67,6 +77,9 @@ class Gradient(Attribution):
|
|||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> gradient = Gradient(net)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gradient(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
|
@ -79,25 +92,7 @@ class Gradient(Attribution):
|
|||
self._aggregation_fn = abs_max
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
"""
|
||||
Call function for `Gradient`.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If it is a 1D tensor, its length should be the same as `inputs`.
|
||||
|
||||
Returns:
|
||||
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> import numpy as np
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> label = 5
|
||||
>>> # gradient is a Gradient object, parse data and the target label to be explained and get the attribution
|
||||
>>> saliency = gradient(inputs, label)
|
||||
"""
|
||||
"""Call function for `Gradient`."""
|
||||
self._verify_data(inputs, targets)
|
||||
inputs = unify_inputs(inputs)
|
||||
targets = unify_targets(targets)
|
||||
|
|
|
@ -96,15 +96,23 @@ class Deconvolution(ModifiedReLU):
|
|||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
- **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If it is a 1D tensor, its length should be the same as `inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Deconvolution
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # init Deconvolution with a trained network.
|
||||
>>> net = resnet50(10) # please refer to model_zoo
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # init Deconvolution with a trained network.
|
||||
>>> deconvolution = Deconvolution(net)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
|
@ -134,15 +142,23 @@ class GuidedBackprop(ModifiedReLU):
|
|||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
- **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If it is a 1D tensor, its length should be the same as `inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> from mindspore.explainer.explanation import GuidedBackprop
|
||||
>>> # init GuidedBackprop with a trained network.
|
||||
>>> net = resnet50(10) # please refer to model_zoo
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # init GuidedBackprop with a trained network.
|
||||
>>> gbp = GuidedBackprop(net)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
|
|
|
@ -47,7 +47,7 @@ def _generate_patches(array, window_size, stride):
|
|||
|
||||
|
||||
class Occlusion(PerturbationAttribution):
|
||||
r"""
|
||||
"""
|
||||
Occlusion uses a sliding window to replace the pixels with a reference value (e.g. constant value), and computes
|
||||
the output difference w.r.t the original output. The output difference caused by perturbed pixels are assigned as
|
||||
feature importance to those pixels. For pixels involved in multiple sliding windows, the feature importance is the
|
||||
|
@ -56,7 +56,14 @@ class Occlusion(PerturbationAttribution):
|
|||
For more details, please refer to the original paper via: `<https://arxiv.org/abs/1311.2901>`_.
|
||||
|
||||
Args:
|
||||
network (Cell): Specify the black-box model to be explained.
|
||||
network (Cell): The black-box model to be explained.
|
||||
activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
|
||||
single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks,
|
||||
`nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long as
|
||||
when combining this function with network, the final output is the probability of the input.
|
||||
perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the
|
||||
perturbed samples. Within the memory capacity, usually the larger this number is, the faster the
|
||||
explanation is obtained. Default: 32.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
|
@ -67,27 +74,29 @@ class Occlusion(PerturbationAttribution):
|
|||
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
|
||||
Example:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Occlusion
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
|
||||
>>> network = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(network, param_dict)
|
||||
>>> # initialize Occlusion explainer and pass the pretrained model
|
||||
>>> occlusion = Occlusion(network)
|
||||
>>> # initialize Occlusion explainer with the pretrained model and activation function
|
||||
>>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities
|
||||
>>> occlusion = Occlusion(network, activation_fn=activation_fn)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> label = ms.Tensor([1], ms.int32)
|
||||
>>> saliency = occlusion(input_x, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network, activation_fn=nn.Softmax()):
|
||||
super().__init__(network, activation_fn)
|
||||
def __init__(self, network, activation_fn, perturbation_per_eval=32):
|
||||
super().__init__(network, activation_fn, perturbation_per_eval)
|
||||
|
||||
self._ablation = Ablation(perturb_mode='Deletion')
|
||||
self._aggregation_fn = abs_max
|
||||
self._get_replacement = Constant(base_value=0.0)
|
||||
self._num_sample_per_dim = 32 # specify the number of perturbations each dimension.
|
||||
self._num_per_eval = 2 # number of perturbations generate for each sample per evaluation step.
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
"""Call function for 'Occlusion'."""
|
||||
|
@ -99,9 +108,9 @@ class Occlusion(PerturbationAttribution):
|
|||
batch_size = inputs_np.shape[0]
|
||||
window_size, strides = self._get_window_size_and_strides(inputs_np)
|
||||
|
||||
model = nn.SequentialCell([self._model, self._activation_fn])
|
||||
full_network = nn.SequentialCell([self._network, self._activation_fn])
|
||||
|
||||
original_outputs = model(ms.Tensor(inputs, ms.float32)).asnumpy()[np.arange(batch_size), targets_np]
|
||||
original_outputs = full_network(ms.Tensor(inputs, ms.float32)).asnumpy()[np.arange(batch_size), targets_np]
|
||||
|
||||
total_attribution = np.zeros_like(inputs_np)
|
||||
weights = np.ones_like(inputs_np)
|
||||
|
@ -111,13 +120,13 @@ class Occlusion(PerturbationAttribution):
|
|||
|
||||
count = 0
|
||||
while count < num_perturbations:
|
||||
ith_masks = masks[:, count:min(count+self._num_per_eval, num_perturbations)]
|
||||
ith_masks = masks[:, count:min(count+self._perturbation_per_eval, num_perturbations)]
|
||||
actual_num_eval = ith_masks.shape[1]
|
||||
num_samples = batch_size * actual_num_eval
|
||||
occluded_inputs = self._ablation(inputs_np, reference, ith_masks)
|
||||
occluded_inputs = occluded_inputs.reshape((-1, *inputs_np.shape[1:]))
|
||||
targets_repeat = np.repeat(targets_np, repeats=actual_num_eval, axis=0)
|
||||
occluded_outputs = model(
|
||||
occluded_outputs = full_network(
|
||||
ms.Tensor(occluded_inputs, ms.float32)).asnumpy()[np.arange(num_samples), targets_repeat]
|
||||
original_outputs_repeat = np.repeat(original_outputs, repeats=actual_num_eval, axis=0)
|
||||
outputs_diff = original_outputs_repeat - occluded_outputs
|
||||
|
|
|
@ -19,7 +19,6 @@ from mindspore.train._utils import check_value_type
|
|||
from mindspore.nn import Cell
|
||||
|
||||
from ..attribution import Attribution
|
||||
from ...._operators import softmax
|
||||
|
||||
|
||||
class PerturbationAttribution(Attribution):
|
||||
|
@ -31,8 +30,13 @@ class PerturbationAttribution(Attribution):
|
|||
|
||||
def __init__(self,
|
||||
network,
|
||||
activation_fn=softmax(),
|
||||
activation_fn,
|
||||
perturbation_per_eval,
|
||||
):
|
||||
super(PerturbationAttribution, self).__init__(network)
|
||||
check_value_type("activation_fn", activation_fn, Cell)
|
||||
self._activation_fn = activation_fn
|
||||
check_value_type('perturbation_per_eval', perturbation_per_eval, int)
|
||||
if perturbation_per_eval <= 0:
|
||||
raise ValueError('Argument perturbation_per_eval should be a positive integer.')
|
||||
self._perturbation_per_eval = perturbation_per_eval
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
# ============================================================================
|
||||
"""RISE."""
|
||||
import math
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore.ops.operations import Concat
|
||||
from mindspore import Tensor
|
||||
from mindspore import nn
|
||||
from mindspore.train._utils import check_value_type
|
||||
|
||||
from .perturbation import PerturbationAttribution
|
||||
|
@ -36,41 +37,57 @@ class RISE(PerturbationAttribution):
|
|||
node of interest:
|
||||
|
||||
.. math::
|
||||
E_{RISE}(I, f)_c = \sum_{i}f_c(I\odot M_i) M_i
|
||||
attribution = \sum_{i}f_c(I\odot M_i) M_i
|
||||
|
||||
For more details, please refer to the original paper via: `RISE <https://arxiv.org/abs/1806.07421>`_.
|
||||
|
||||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
activation_fn (Cell, optional): The activation layer that transforms logits to prediction probabilities. For
|
||||
activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
|
||||
single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification tasks,
|
||||
`nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long as
|
||||
when combining this function with network, the final output is the probability of the input.
|
||||
Default: `nn.Softmax`.
|
||||
perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the
|
||||
perturbed samples. Default: 32.
|
||||
perturbed samples. Within the memory capacity, usually the larger this number is, the faster the
|
||||
explanation is obtained. Default: 32.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
- **targets** (Tensor, int) - The labels of interest to be explained. When `targets` is an integer,
|
||||
all of the inputs will generate attribution maps w.r.t. this integer. When `targets` is a tensor, it
|
||||
should be of shape :math:`(N, l)` (l being the number of labels for each sample) or :math:`(N,)` or :math:`()`.
|
||||
|
||||
Outputs:
|
||||
Tensor, a 4D tensor of shape :math:`(N, ?, H, W)`.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import RISE
|
||||
>>> from mindspore.nn import Sigmoid
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # init RISE with a trained network
|
||||
>>> net = resnet50(10) # please refer to model_zoo
|
||||
>>> # prepare your network and load the trained checkpoint file, e.g., resnet50.
|
||||
>>> network = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # init RISE with specified activation function
|
||||
>>> rise = RISE(net, activation_fn=Sigmoid())
|
||||
"""
|
||||
>>> load_param_into_net(network, param_dict)
|
||||
>>> # initialize RISE explainer with the pretrained model and activation function
|
||||
>>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities
|
||||
>>> rise = RISE(network, activation_fn=activation_fn)
|
||||
>>> # given an instance of RISE, saliency map can be generated
|
||||
>>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32)
|
||||
>>> # when `targets` is an integer
|
||||
>>> targets = 5
|
||||
>>> saliency = rise(inputs, targets)
|
||||
>>> # `targets` can also be a 2D tensor
|
||||
>>> targets = ms.Tensor([[5], [1]])
|
||||
>>> saliency = rise(inputs, targets)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
network,
|
||||
activation_fn=nn.Softmax(),
|
||||
activation_fn,
|
||||
perturbation_per_eval=32):
|
||||
super(RISE, self).__init__(network, activation_fn)
|
||||
check_value_type('perturbation_per-eval', perturbation_per_eval, int)
|
||||
if perturbation_per_eval <= 0:
|
||||
raise ValueError('perturbation_per_eval should be postive integer.')
|
||||
self._perturbation_per_eval = perturbation_per_eval
|
||||
super(RISE, self).__init__(network, activation_fn, perturbation_per_eval)
|
||||
|
||||
self._num_masks = 6000 # number of masks to be sampled
|
||||
self._mask_probability = 0.2 # ratio of inputs to be masked
|
||||
|
@ -93,47 +110,26 @@ class RISE(PerturbationAttribution):
|
|||
self._resize_mode)
|
||||
|
||||
# Pack operator not available for GPU, thus transfer to numpy first
|
||||
upsample_np = upsample.asnumpy()
|
||||
masks_lst = []
|
||||
for sample in upsample_np:
|
||||
shift_x = np.random.randint(0, mask_size[0] + 1)
|
||||
shift_y = np.random.randint(0, mask_size[1] + 1)
|
||||
for sample in upsample:
|
||||
shift_x = random.randint(0, mask_size[0])
|
||||
shift_y = random.randint(0, mask_size[1])
|
||||
masks_lst.append(sample[:, shift_x: shift_x + height, shift_y:shift_y + width])
|
||||
masks = op.Tensor(np.array(masks_lst), data.dtype)
|
||||
|
||||
concat = Concat()
|
||||
masks = concat(tuple(masks_lst))
|
||||
masks = op.reshape(masks, (batch_size, -1, height, width))
|
||||
return masks
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
"""
|
||||
Generates attribution maps for inputs.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): Input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
targets (int, Tensor): The labels of interest to be explained. When `targets` is an integer,
|
||||
all of the inputs will generates attribution map w.r.t this integer. When `targets` is a tensor, it
|
||||
should be of shape :math:`(N, ?)` or :math:`(N,)` :math:`()`.
|
||||
|
||||
Returns:
|
||||
Tensor, a 4D tensor of shape :math:`(N, ?, H, W)` or :math:`(N, 1, H, W)`.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> import numpy as np
|
||||
>>> # given an instance of RISE, saliency map can be generate
|
||||
>>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32)
|
||||
>>> # when `targets` is an integer
|
||||
>>> targets = 5
|
||||
>>> saliency = rise(inputs, targets)
|
||||
>>> # `targets` can also be a tensor
|
||||
>>> targets = ms.Tensor([[5], [1]])
|
||||
>>> saliency = rise(inputs, targets)
|
||||
"""
|
||||
"""Generates attribution maps for inputs."""
|
||||
self._verify_data(inputs, targets)
|
||||
height, width = inputs.shape[2], inputs.shape[3]
|
||||
|
||||
batch_size = inputs.shape[0]
|
||||
|
||||
if self._num_classes is None:
|
||||
logits = self.model(inputs)
|
||||
logits = self.network(inputs)
|
||||
num_classes = logits.shape[1]
|
||||
self._num_classes = num_classes
|
||||
|
||||
|
@ -151,7 +147,7 @@ class RISE(PerturbationAttribution):
|
|||
masks = self._generate_masks(data, bs)
|
||||
|
||||
masked_input = masks * data + (1 - masks) * bg_data
|
||||
weights = self._activation_fn(self.model(masked_input))
|
||||
weights = self._activation_fn(self.network(masked_input))
|
||||
while len(weights.shape) > 2:
|
||||
weights = op.mean(weights, axis=2)
|
||||
weights = op.reshape(weights,
|
||||
|
|
|
@ -28,19 +28,19 @@ class Attribution:
|
|||
The explainers which explanation through attributing the relevance scores should inherit this class.
|
||||
|
||||
Args:
|
||||
network (nn.Cell): The black-box model to explanation.
|
||||
network (nn.Cell): The black-box model to be explained.
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
check_value_type("network", network, nn.Cell)
|
||||
self._model = network
|
||||
self._model.set_train(False)
|
||||
self._model.set_grad(False)
|
||||
self._network = network
|
||||
self._network.set_train(False)
|
||||
self._network.set_grad(False)
|
||||
|
||||
@staticmethod
|
||||
def _verify_model(model):
|
||||
def _verify_network(network):
|
||||
"""Verify the input `network` for __init__ function."""
|
||||
if not isinstance(model, nn.Cell):
|
||||
if not isinstance(network, nn.Cell):
|
||||
raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.")
|
||||
|
||||
__call__: Callable
|
||||
|
@ -57,9 +57,9 @@ class Attribution:
|
|||
"""
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
def network(self):
|
||||
"""Return the network."""
|
||||
return self._model
|
||||
return self._network
|
||||
|
||||
@staticmethod
|
||||
def _verify_data(inputs, targets):
|
||||
|
|
Loading…
Reference in New Issue