!8876 add feature/explainer/ClassSensitivity, Occlusion, Robustness
From: @yuhanshi Reviewed-by: Signed-off-by:
This commit is contained in:
commit
452cb0dd4e
|
@ -13,19 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Packaged operations based on MindSpore."""
|
||||
from typing import List, Tuple, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
import mindspore.ops.operations as op
|
||||
from mindspore import nn
|
||||
|
||||
_Axis = Union[int, Tuple[int, ...], List[int]]
|
||||
_Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]]
|
||||
_Number = Union[int, float, np.int, np.float]
|
||||
_Shape = Union[int, Tuple[int, ...]]
|
||||
Tensor = mindspore.Tensor
|
||||
|
||||
__all__ = [
|
||||
'absolute',
|
||||
|
@ -41,6 +28,7 @@ __all__ = [
|
|||
'mean',
|
||||
'mul',
|
||||
'sort',
|
||||
'sqrt',
|
||||
'squeeze',
|
||||
'tile',
|
||||
'reshape',
|
||||
|
@ -51,6 +39,20 @@ __all__ = [
|
|||
'summation'
|
||||
]
|
||||
|
||||
from typing import List, Tuple, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import nn
|
||||
import mindspore.ops.operations as op
|
||||
|
||||
_Axis = Union[int, Tuple[int, ...], List[int]]
|
||||
_Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]]
|
||||
_Number = Union[int, float, np.int, np.float]
|
||||
_Shape = Union[int, Tuple[int, ...]]
|
||||
Tensor = mindspore.Tensor
|
||||
|
||||
|
||||
def absolute(inputs: Tensor) -> Tensor:
|
||||
"""Get the absolute value of a tensor value."""
|
||||
|
|
|
@ -33,11 +33,10 @@ from mindspore.train._utils import check_value_type
|
|||
from mindspore.train.summary._summary_adapter import _convert_image_format
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from mindspore.train.summary_pb2 import Explain
|
||||
|
||||
from .benchmark import Localization
|
||||
from .benchmark._attribution.metric import AttributionMetric
|
||||
from .explanation import RISE
|
||||
from .explanation._attribution._attribution import Attribution
|
||||
from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric
|
||||
from .explanation._attribution.attribution import Attribution
|
||||
|
||||
# datafile directory names
|
||||
_DATAFILE_DIRNAME_PREFIX = "_explain_"
|
||||
|
@ -293,7 +292,8 @@ class ExplainRunner:
|
|||
benchmark.benchmark_method = bench.__class__.__name__
|
||||
|
||||
benchmark.total_score = bench.performance
|
||||
benchmark.label_score.extend(bench.class_performances)
|
||||
if isinstance(bench, LabelSensitiveMetric):
|
||||
benchmark.label_score.extend(bench.class_performances)
|
||||
|
||||
print(spacer.format("Finish running and writing explanation and benchmark data for {}. "
|
||||
"Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start)))
|
||||
|
@ -603,7 +603,6 @@ class ExplainRunner:
|
|||
Args:
|
||||
next_element (Tuple): Data of one step
|
||||
explainer (`_Attribution`): An Attribution object to generate saliency maps.
|
||||
imageid_labels (dict): A dict that maps the image_id and its union labels.
|
||||
"""
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
for idx, inp in enumerate(inputs):
|
||||
|
@ -615,10 +614,22 @@ class ExplainRunner:
|
|||
if label in labels[idx]:
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
|
||||
saliency=saliency)
|
||||
if np.any(res == np.nan):
|
||||
res = np.zeros_like(res)
|
||||
benchmarker.aggregate(res, label)
|
||||
else:
|
||||
elif isinstance(benchmarker, LabelSensitiveMetric):
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
|
||||
if np.any(res == np.nan):
|
||||
res = np.zeros_like(res)
|
||||
benchmarker.aggregate(res, label)
|
||||
elif isinstance(benchmarker, LabelAgnosticMetric):
|
||||
res = benchmarker.evaluate(explainer, inp)
|
||||
if np.any(res == np.nan):
|
||||
res = np.zeros_like(res)
|
||||
benchmarker.aggregate(res)
|
||||
else:
|
||||
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
|
||||
'receive {}'.format(type(benchmarker)))
|
||||
|
||||
def _save_original_image(self, sample_id: int, image):
|
||||
"""Save an image to summary directory."""
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
__all__ = [
|
||||
'ForwardProbe',
|
||||
'abs_max',
|
||||
'calc_auc',
|
||||
'calc_correlation',
|
||||
'format_tensor_to_ndarray',
|
||||
|
@ -29,7 +30,6 @@ __all__ = [
|
|||
]
|
||||
|
||||
from typing import Tuple, Union
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -43,6 +43,21 @@ _Module = nn.Cell
|
|||
_Tensor = ms.Tensor
|
||||
|
||||
|
||||
def abs_max(gradients):
|
||||
"""
|
||||
Transform gradients to saliency through abs then take max along channels.
|
||||
|
||||
Args:
|
||||
gradients (_Tensor): Gradients which will be transformed to saliency map.
|
||||
|
||||
Returns:
|
||||
_Tensor, saliency map integrated from gradients.
|
||||
"""
|
||||
gradients = op.Abs()(gradients)
|
||||
saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1)
|
||||
return saliency
|
||||
|
||||
|
||||
def generate_one_hot(indices, depth):
|
||||
r"""
|
||||
Simple wrap of OneHot operation, the on_value an off_value are fixed to 1.0
|
||||
|
@ -96,7 +111,7 @@ def retrieve_layer_by_name(model: _Module, layer_name: str):
|
|||
- target_layer (_Module)
|
||||
|
||||
Raise:
|
||||
ValueError: is module with given layer_name is not found in the model,
|
||||
ValueError: if module with given layer_name is not found in the model,
|
||||
raise ValueError.
|
||||
|
||||
"""
|
||||
|
@ -201,23 +216,28 @@ def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
|
|||
|
||||
def calc_correlation(x: Union[ms.Tensor, np.ndarray],
|
||||
y: Union[ms.Tensor, np.ndarray]) -> float:
|
||||
"""Calculate Pearson correlation coefficient between two arrays. """
|
||||
"""Calculate Pearson correlation coefficient between two vectors."""
|
||||
x = format_tensor_to_ndarray(x)
|
||||
y = format_tensor_to_ndarray(y)
|
||||
faithfulness = -np.corrcoef(x, y)[0, 1]
|
||||
if math.isnan(faithfulness):
|
||||
|
||||
if len(x.shape) > 1 or len(y.shape) > 1:
|
||||
raise ValueError('"calc_correlation" only support 1-dim vectors currently, but get shape {} and {}.'
|
||||
.format(len(x.shape), len(y.shape)))
|
||||
|
||||
if np.all(x == 0) or np.all(y == 0):
|
||||
return np.float(0)
|
||||
faithfulness = -np.corrcoef(x, y)[0, 1]
|
||||
return faithfulness
|
||||
|
||||
|
||||
def calc_auc(x: _Array) -> float:
|
||||
def calc_auc(x: _Array) -> _Array:
|
||||
"""Calculate the Aera under Curve."""
|
||||
# take mean for multiple patches if the model is fully convolutional model
|
||||
if len(x.shape) == 4:
|
||||
x = np.mean(np.mean(x, axis=2), axis=3)
|
||||
|
||||
auc = (x.sum() - x[0] - x[-1]) / len(x)
|
||||
return float(auc)
|
||||
return auc
|
||||
|
||||
|
||||
def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
|
||||
|
@ -235,13 +255,17 @@ def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
|
|||
rank_pixels(x, descending=False)
|
||||
>> np.array([[3, 2, 0], [4, 5, 1]])
|
||||
"""
|
||||
if len(inputs.shape) != 2:
|
||||
raise ValueError('Only support 2D array currently')
|
||||
flatten_saliency = inputs.reshape(-1)
|
||||
if len(inputs.shape) < 2 or len(inputs.shape) > 3:
|
||||
raise ValueError('Only support 2D or 3D inputs currently.')
|
||||
|
||||
batch_size = inputs.shape[0]
|
||||
flatten_saliency = inputs.reshape(batch_size, -1)
|
||||
factor = -1 if descending else 1
|
||||
sorted_arg = np.argsort(factor * flatten_saliency, axis=0)
|
||||
sorted_arg = np.argsort(factor * flatten_saliency, axis=1)
|
||||
flatten_rank = np.zeros_like(sorted_arg)
|
||||
flatten_rank[sorted_arg] = np.arange(0, sorted_arg.shape[0])
|
||||
arange = np.arange(flatten_saliency.shape[1])
|
||||
for i in range(batch_size):
|
||||
flatten_rank[i][sorted_arg[i]] = arange
|
||||
rank_map = flatten_rank.reshape(inputs.shape)
|
||||
return rank_map
|
||||
|
||||
|
|
|
@ -14,10 +14,14 @@
|
|||
# ============================================================================
|
||||
"""Predefined XAI metrics."""
|
||||
|
||||
from ._attribution.class_sensitivity import ClassSensitivity
|
||||
from ._attribution.faithfulness import Faithfulness
|
||||
from ._attribution.localization import Localization
|
||||
from ._attribution.robustness import Robustness
|
||||
|
||||
__all__ = [
|
||||
"ClassSensitivity",
|
||||
"Faithfulness",
|
||||
"Localization"
|
||||
"Localization",
|
||||
"Robustness"
|
||||
]
|
||||
|
|
|
@ -13,11 +13,3 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Predefined XAI metrics"""
|
||||
|
||||
from .faithfulness import Faithfulness
|
||||
from .localization import Localization
|
||||
|
||||
__all__ = [
|
||||
"Faithfulness",
|
||||
"Localization"
|
||||
]
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Class Sensitivity."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from .metric import LabelAgnosticMetric
|
||||
from ... import _operators as ops
|
||||
from ...explanation._attribution.attribution import Attribution
|
||||
from ..._utils import calc_correlation
|
||||
|
||||
|
||||
class ClassSensitivity(LabelAgnosticMetric):
|
||||
r"""
|
||||
Class sensitivity metric used to evaluate attribution-based explanations.
|
||||
|
||||
Reasonable atrribution-based explainers are expected to generate distinct saliency maps for different labels,
|
||||
especially for labels of highest confidence and low confidence. Class sensitivity evaluates the explainer through
|
||||
computing the correlation between saliency maps of highest-confidence and lowest-confidence labels. Explainer with
|
||||
better class sensitivity will receive lower correlation score. To make the evaluation results intuitive, the
|
||||
returned score will take negative on correlation and normalize.
|
||||
|
||||
"""
|
||||
|
||||
def evaluate(self, explainer: Attribution, inputs: Tensor) -> np.ndarray:
|
||||
"""
|
||||
Evaluate class sensitivity on a single data sample.
|
||||
|
||||
Args:
|
||||
explainer (Attribution): The explainer to be evaluated, see `mindspore.explainer.explanation`.
|
||||
inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, 1D array of shape :math:`(N,)`, result of class sensitivity evaluated on `explainer`.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> gradient = Gradient()
|
||||
>>> x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> class_sensitivity = ClassSensitivity()
|
||||
>>> res = class_sensitivity.evaluate(gradient, x)
|
||||
"""
|
||||
self._check_evaluate_param(explainer, inputs)
|
||||
|
||||
outputs = explainer.model(inputs)
|
||||
|
||||
max_confidence_label = ops.argmax(outputs)
|
||||
min_confidence_label = ops.argmin(outputs)
|
||||
|
||||
max_confidence_saliency = explainer(inputs, max_confidence_label).asnumpy()
|
||||
min_confidence_saliency = explainer(inputs, min_confidence_label).asnumpy()
|
||||
|
||||
correlations = []
|
||||
for i in range(inputs.shape[0]):
|
||||
correlation = calc_correlation(max_confidence_saliency[i].reshape(-1),
|
||||
min_confidence_saliency[i].reshape(-1))
|
||||
normalized_correlation = (-correlation + 1) / 2
|
||||
correlations.append(normalized_correlation)
|
||||
return np.array(correlations, np.float)
|
|
@ -12,21 +12,19 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Faithfulness"""
|
||||
import math
|
||||
from typing import Callable, Optional, Union, Tuple
|
||||
"""Faithfulness."""
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
|
||||
from mindspore import log
|
||||
import mindspore as ms
|
||||
from mindspore.train._utils import check_value_type
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as op
|
||||
from .metric import AttributionMetric
|
||||
from ..._utils import calc_correlation, calc_auc, format_tensor_to_ndarray, rank_pixels
|
||||
from ...explanation._attribution._attribution import Attribution as _Attribution
|
||||
from .metric import LabelSensitiveMetric
|
||||
from ..._utils import calc_auc, format_tensor_to_ndarray
|
||||
from ...explanation._attribution import Attribution as _Attribution
|
||||
from ...explanation._attribution._perturbation.replacement import Constant, GaussianBlur
|
||||
from ...explanation._attribution._perturbation.ablation import AblationWithSaliency
|
||||
|
||||
_Array = np.ndarray
|
||||
_Explainer = Union[_Attribution, Callable]
|
||||
|
@ -36,189 +34,19 @@ _Module = nn.Cell
|
|||
|
||||
def _calc_feature_importance(saliency: _Array, masks: _Array) -> _Array:
|
||||
"""Calculate feature important w.r.t given masks."""
|
||||
feature_importance = []
|
||||
num_perturbations = masks.shape[0]
|
||||
for i in range(num_perturbations):
|
||||
patch_feature_importance = saliency[masks[i]].sum() / masks[i].sum()
|
||||
feature_importance.append(patch_feature_importance)
|
||||
feature_importance = np.array(feature_importance, dtype=np.float32)
|
||||
if saliency.shape[1] < masks.shape[2]:
|
||||
saliency = np.repeat(saliency, repeats=masks.shape[2], axis=1)
|
||||
|
||||
batch_size = masks.shape[0]
|
||||
num_perturbations = masks.shape[1]
|
||||
saliency = np.repeat(saliency, repeats=num_perturbations, axis=0)
|
||||
saliency = saliency.reshape([batch_size, num_perturbations, -1])
|
||||
masks = masks.reshape([batch_size, num_perturbations, -1])
|
||||
feature_importance = saliency * masks
|
||||
feature_importance = feature_importance.sum(-1) / masks.sum(-1)
|
||||
return feature_importance
|
||||
|
||||
|
||||
class _BaseReplacement:
|
||||
"""
|
||||
Base class of generator for generating different replacement for perturbations.
|
||||
|
||||
Args:
|
||||
kwargs: Optional args for generating replacement. Derived class need to
|
||||
add necessary arg names and default value to '_necessary_args'.
|
||||
If the argument has no default value, the value should be set to
|
||||
'EMPTY' to mark the required args. Initializing an object will
|
||||
check the given kwargs w.r.t '_necessary_args'.
|
||||
|
||||
Raise:
|
||||
ValueError: Raise when provided kwargs not contain necessary arg names with 'EMPTY' mark.
|
||||
"""
|
||||
_necessary_args = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._replace_args = self._necessary_args.copy()
|
||||
for key, value in self._replace_args.items():
|
||||
if key in kwargs.keys():
|
||||
self._replace_args[key] = kwargs[key]
|
||||
elif key not in kwargs.keys() and value == 'EMPTY':
|
||||
raise ValueError(f"Missing keyword arg {key} for {self.__class__.__name__}.")
|
||||
|
||||
__call__: Callable
|
||||
"""
|
||||
Generate replacement for perturbations. Derived class should overwrite this
|
||||
function to generate different replacement for perturbing.
|
||||
|
||||
Args:
|
||||
inputs (_Array): Array to be perturb.
|
||||
|
||||
Returns:
|
||||
- replacement (_Array): Array to provide alternative pixels for every
|
||||
position in the given
|
||||
inputs. The returned array should have same shape as inputs.
|
||||
"""
|
||||
|
||||
|
||||
class Constant(_BaseReplacement):
|
||||
""" Generator to provide constant-value replacement for perturbations """
|
||||
_necessary_args = {'base_value': 'EMPTY'}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
replacement = np.ones_like(inputs, dtype=np.float32)
|
||||
replacement *= self._replace_args['base_value']
|
||||
return replacement
|
||||
|
||||
|
||||
class GaussianBlur(_BaseReplacement):
|
||||
""" Generator to provided gaussian blurred inputs for perturbation. """
|
||||
_necessary_args = {'sigma': 0.7}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
sigma = self._replace_args['sigma']
|
||||
replacement = gaussian_filter(inputs, sigma=sigma)
|
||||
return replacement
|
||||
|
||||
|
||||
class Perturb:
|
||||
"""
|
||||
Perturbation generator to generate perturbations for a given array.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_mode (str): specify perturbing mode, through deleting or
|
||||
inserting pixels. Current support: ['Deletion', 'Insertion'].
|
||||
is_accumulate (bool): whether to accumulate the former perturbations to
|
||||
the later perturbations.
|
||||
perturb_pixel_per_step (int, optional): number of pixel to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculate by:
|
||||
num_image_pixel * perturb_percent / num_perturb_steps.
|
||||
Default: None
|
||||
num_perturbations (int, optional): number of perturbations. If
|
||||
num_perturbations if None, it will be calculated by:
|
||||
num_image_pixel * perturb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_mode: str,
|
||||
is_accumulate: bool,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None):
|
||||
self._perturb_percent = perturb_percent
|
||||
self._perturb_mode = perturb_mode
|
||||
self._pixel_per_step = perturb_pixel_per_step
|
||||
self._num_perturbations = num_perturbations
|
||||
self._is_accumulate = is_accumulate
|
||||
|
||||
@staticmethod
|
||||
def _assign(x: _Array, y: _Array, masks: _Array):
|
||||
"""Assign values to perturb pixels on perturbations."""
|
||||
check_value_type("masks dtype", masks.dtype, type(np.dtype(bool)))
|
||||
for i in range(x.shape[0]):
|
||||
x[i][:, masks[i]] = y[:, masks[i]]
|
||||
|
||||
def _generate_mask(self, saliency_rank: _Array) -> _Array:
|
||||
"""Generate mask for perturbations based on given saliency ranks."""
|
||||
if len(saliency_rank.shape) != 2:
|
||||
raise ValueError(f'The param "saliency_rank" should be 2-dim, but receive {len(saliency_rank.shape)}.')
|
||||
|
||||
num_pixels = saliency_rank.shape[0] * saliency_rank.shape[1]
|
||||
if self._pixel_per_step:
|
||||
pixel_per_step = self._pixel_per_step
|
||||
num_perturbations = math.floor(
|
||||
num_pixels * self._perturb_percent / self._pixel_per_step)
|
||||
elif self._num_perturbations:
|
||||
pixel_per_step = math.floor(
|
||||
num_pixels * self._perturb_percent / self._num_perturbations)
|
||||
num_perturbations = self._num_perturbations
|
||||
else:
|
||||
raise ValueError("Must provide either pixel_per_step or num_perturbations.")
|
||||
|
||||
masks = np.zeros(
|
||||
(num_perturbations, saliency_rank.shape[0], saliency_rank.shape[1]),
|
||||
dtype=np.bool)
|
||||
low_bound = 0
|
||||
up_bound = low_bound + pixel_per_step
|
||||
factor = 0 if self._is_accumulate else 1
|
||||
|
||||
for i in range(num_perturbations):
|
||||
masks[i, ((saliency_rank >= low_bound)
|
||||
& (saliency_rank < up_bound))] = True
|
||||
low_bound = up_bound * factor
|
||||
up_bound += pixel_per_step
|
||||
|
||||
if len(masks.shape) == 3:
|
||||
return masks
|
||||
raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect 3-dim.')
|
||||
|
||||
def __call__(self,
|
||||
inputs: _Array,
|
||||
saliency: _Array,
|
||||
reference: _Array,
|
||||
return_mask: bool = False,
|
||||
) -> Union[_Array, Tuple[_Array, ...]]:
|
||||
"""
|
||||
Generate perturbations of given array.
|
||||
|
||||
Args:
|
||||
inputs (_Array): input array to perturb
|
||||
saliency (_Array): saliency map
|
||||
return_mask (bool): whether return the mask for generating
|
||||
the perturbation. The mask can be used to calculate
|
||||
average feature importance of pixels perturbed at each step.
|
||||
|
||||
Return:
|
||||
perturbations (_Array)
|
||||
masks (_Array): return when return_mask is set to True.
|
||||
"""
|
||||
if not np.array_equal(inputs.shape, reference.shape):
|
||||
raise ValueError('reference must have the same shape as inputs.')
|
||||
|
||||
saliency_rank = rank_pixels(saliency, descending=True)
|
||||
masks = self._generate_mask(saliency_rank)
|
||||
num_perturbations = masks.shape[0]
|
||||
|
||||
if self._perturb_mode == 'Insertion':
|
||||
inputs, reference = reference, inputs
|
||||
|
||||
perturbations = np.tile(
|
||||
inputs, (num_perturbations, *[1] * len(inputs.shape)))
|
||||
|
||||
Perturb._assign(perturbations, reference, masks)
|
||||
|
||||
if return_mask:
|
||||
return perturbations, masks
|
||||
return perturbations
|
||||
|
||||
|
||||
class _FaithfulnessHelper:
|
||||
"""Base class for faithfulness calculator."""
|
||||
_support = [Constant, GaussianBlur]
|
||||
|
@ -240,27 +68,15 @@ class _FaithfulnessHelper:
|
|||
raise ValueError(
|
||||
'The param "perturb_method" should be one of {}.'.format([x.__name__ for x in self._support]))
|
||||
|
||||
self._perturb = Perturb(perturb_percent=perturb_percent,
|
||||
perturb_mode=perturb_mode,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=is_accumulate)
|
||||
self._ablation = AblationWithSaliency(perturb_mode=perturb_mode,
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=is_accumulate)
|
||||
|
||||
calc_faithfulness: Callable
|
||||
"""
|
||||
Method used to calculate faithfulness for given inputs, target label,
|
||||
saliency. Derive class should implement this method.
|
||||
|
||||
Args:
|
||||
inputs (_Array): sample to calculate faithfulness score
|
||||
model (_Module): model to explanation
|
||||
targets (_Label): label to explanation on.
|
||||
saliency (_Array): Saliency map of given inputs and targets from the
|
||||
explainer.
|
||||
|
||||
Return:
|
||||
- faithfulness (float): faithfulness score
|
||||
"""
|
||||
def calc_faithfulness(self, inputs, model, targets, saliency):
|
||||
"""Calc faithfulness."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NaiveFaithfulness(_FaithfulnessHelper):
|
||||
|
@ -304,14 +120,13 @@ class NaiveFaithfulness(_FaithfulnessHelper):
|
|||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
super(NaiveFaithfulness, self).__init__(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_mode='Deletion',
|
||||
perturb_method=perturb_method,
|
||||
is_accumulate=is_accumulate,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
**kwargs)
|
||||
super().__init__(perturb_percent=perturb_percent,
|
||||
perturb_mode='Deletion',
|
||||
perturb_method=perturb_method,
|
||||
is_accumulate=is_accumulate,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
**kwargs)
|
||||
|
||||
def calc_faithfulness(self,
|
||||
inputs: _Array,
|
||||
|
@ -336,16 +151,21 @@ class NaiveFaithfulness(_FaithfulnessHelper):
|
|||
log.warning("The saliency map is zero everywhere. The correlation will be set to zero.")
|
||||
correlation = 0
|
||||
return np.array([correlation], np.float)
|
||||
|
||||
batch_size = inputs.shape[0]
|
||||
reference = self._get_reference(inputs)
|
||||
perturbations, masks = self._perturb(
|
||||
inputs, saliency, reference, return_mask=True)
|
||||
masks = self._ablation.generate_mask(saliency, inputs.shape[1])
|
||||
perturbations = self._ablation(inputs, reference, masks)
|
||||
feature_importance = _calc_feature_importance(saliency, masks)
|
||||
|
||||
perturbations = perturbations.reshape(-1, *perturbations.shape[2:])
|
||||
perturbations = ms.Tensor(perturbations, dtype=ms.float32)
|
||||
predictions = model(perturbations).asnumpy()[:, targets]
|
||||
predictions = model(perturbations)[:, targets].asnumpy()
|
||||
predictions = predictions.reshape(*feature_importance.shape)
|
||||
|
||||
faithfulness = calc_correlation(feature_importance, predictions)
|
||||
return np.array([faithfulness], np.float)
|
||||
faithfulness = -np.corrcoef(feature_importance, predictions)
|
||||
faithfulness = np.diag(faithfulness[:batch_size, batch_size:])
|
||||
return faithfulness
|
||||
|
||||
|
||||
class DeletionAUC(_FaithfulnessHelper):
|
||||
|
@ -385,20 +205,19 @@ class DeletionAUC(_FaithfulnessHelper):
|
|||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
super(DeletionAUC, self).__init__(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_mode='Deletion',
|
||||
perturb_method=perturb_method,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=True,
|
||||
**kwargs)
|
||||
super().__init__(perturb_percent=perturb_percent,
|
||||
perturb_mode='Deletion',
|
||||
perturb_method=perturb_method,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=True,
|
||||
**kwargs)
|
||||
|
||||
def calc_faithfulness(self,
|
||||
inputs: _Array,
|
||||
model: _Module,
|
||||
targets: _Label,
|
||||
saliency: _Array) -> np.ndarray:
|
||||
saliency: _Array) -> _Array:
|
||||
"""
|
||||
Calculate faithfulness through deletion AUC.
|
||||
|
||||
|
@ -414,14 +233,17 @@ class DeletionAUC(_FaithfulnessHelper):
|
|||
|
||||
"""
|
||||
reference = self._get_reference(inputs)
|
||||
perturbations = self._perturb(inputs, saliency, reference)
|
||||
masks = self._ablation.generate_mask(saliency, inputs.shape[1])
|
||||
perturbations = self._ablation(inputs, reference, masks)
|
||||
perturbations = perturbations.reshape(-1, *perturbations.shape[2:])
|
||||
perturbations = ms.Tensor(perturbations, dtype=ms.float32)
|
||||
predictions = model(perturbations).asnumpy()[:, targets]
|
||||
input_tensor = op.ExpandDims()(ms.Tensor(inputs, ms.float32), 0)
|
||||
predictions = predictions.reshape((inputs.shape[0], -1))
|
||||
input_tensor = ms.Tensor(inputs, ms.float32)
|
||||
original_output = model(input_tensor).asnumpy()[:, targets]
|
||||
|
||||
auc = calc_auc(original_output - predictions)
|
||||
return np.array([1 - auc])
|
||||
auc = calc_auc(original_output.squeeze() - predictions.squeeze())
|
||||
return np.array([1 - auc], np.float)
|
||||
|
||||
|
||||
class InsertionAUC(_FaithfulnessHelper):
|
||||
|
@ -462,20 +284,19 @@ class InsertionAUC(_FaithfulnessHelper):
|
|||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
super(InsertionAUC, self).__init__(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_mode='Insertion',
|
||||
perturb_method=perturb_method,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=True,
|
||||
**kwargs)
|
||||
super().__init__(perturb_percent=perturb_percent,
|
||||
perturb_mode='Insertion',
|
||||
perturb_method=perturb_method,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=True,
|
||||
**kwargs)
|
||||
|
||||
def calc_faithfulness(self,
|
||||
inputs: _Array,
|
||||
model: _Module,
|
||||
targets: _Label,
|
||||
saliency: _Array) -> np.ndarray:
|
||||
saliency: _Array) -> _Array:
|
||||
"""
|
||||
Calculate faithfulness through insertion AUC.
|
||||
|
||||
|
@ -491,17 +312,21 @@ class InsertionAUC(_FaithfulnessHelper):
|
|||
|
||||
"""
|
||||
reference = self._get_reference(inputs)
|
||||
perturbations = self._perturb(inputs, saliency, reference)
|
||||
masks = self._ablation.generate_mask(saliency, inputs.shape[1])
|
||||
perturbations = self._ablation(inputs, reference, masks)
|
||||
perturbations = perturbations.reshape(-1, *perturbations.shape[2:])
|
||||
perturbations = ms.Tensor(perturbations, dtype=ms.float32)
|
||||
predictions = model(perturbations).asnumpy()[:, targets]
|
||||
base_tensor = op.ExpandDims()(ms.Tensor(reference, ms.float32), 0)
|
||||
predictions = predictions.reshape((inputs.shape[0], -1))
|
||||
|
||||
base_tensor = ms.Tensor(reference, ms.float32)
|
||||
base_outputs = model(base_tensor).asnumpy()[:, targets]
|
||||
|
||||
auc = calc_auc(predictions - base_outputs)
|
||||
return np.array([auc])
|
||||
auc = calc_auc(predictions.squeeze() - base_outputs.squeeze())
|
||||
return np.array([auc], np.float)
|
||||
|
||||
|
||||
class Faithfulness(AttributionMetric):
|
||||
class Faithfulness(LabelSensitiveMetric):
|
||||
"""
|
||||
Provides evaluation on faithfulness on XAI explanations.
|
||||
|
||||
|
@ -604,10 +429,6 @@ class Faithfulness(AttributionMetric):
|
|||
inputs = format_tensor_to_ndarray(inputs)
|
||||
saliency = format_tensor_to_ndarray(saliency)
|
||||
|
||||
inputs = inputs.squeeze(axis=0)
|
||||
saliency = saliency.squeeze()
|
||||
if len(saliency.shape) != 2:
|
||||
raise ValueError('Squeezed saliency map is expected to 2D, but receive {}.'.format(len(saliency.shape)))
|
||||
model = nn.SequentialCell([explainer.model, self._activation_fn])
|
||||
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=model,
|
||||
targets=targets, saliency=saliency)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import numpy as np
|
||||
|
||||
from mindspore.train._utils import check_value_type
|
||||
from .metric import AttributionMetric
|
||||
from .metric import LabelSensitiveMetric
|
||||
from ..._operators import maximum, reshape, Tensor
|
||||
from ..._utils import format_tensor_to_ndarray
|
||||
|
||||
|
@ -37,7 +37,7 @@ def _mask_out_saliency(saliency, threshold):
|
|||
return mask_out
|
||||
|
||||
|
||||
class Localization(AttributionMetric):
|
||||
class Localization(LabelSensitiveMetric):
|
||||
r"""
|
||||
Provides evaluation on the localization capability of XAI methods.
|
||||
|
||||
|
|
|
@ -13,12 +13,20 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Base class for XAI metrics."""
|
||||
|
||||
import copy
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import log as logger
|
||||
from mindspore.train._utils import check_value_type
|
||||
from ..._operators import Tensor
|
||||
from ..._utils import format_tensor_to_ndarray
|
||||
from ...explanation._attribution._attribution import Attribution
|
||||
from ...explanation._attribution.attribution import Attribution
|
||||
|
||||
_Explainer = Attribution
|
||||
|
||||
|
||||
def verify_argument(inputs, arg_name):
|
||||
|
@ -46,8 +54,77 @@ def verify_targets(targets, num_labels):
|
|||
class AttributionMetric:
|
||||
"""Super class of XAI metric class used in classification scenarios."""
|
||||
|
||||
def __init__(self, num_labels=None):
|
||||
self._verify_params(num_labels)
|
||||
def __init__(self):
|
||||
self._explainer = None
|
||||
|
||||
evaluate: Callable
|
||||
"""
|
||||
This method evaluates the explainer on the given attribution and returns the evaluation results.
|
||||
Derived class should implement this method according to specific algorithms of the metric.
|
||||
"""
|
||||
|
||||
def _record_explainer(self, explainer: _Explainer):
|
||||
"""Record the explainer in current evaluation."""
|
||||
if self._explainer is None:
|
||||
self._explainer = explainer
|
||||
elif self._explainer is not explainer:
|
||||
logger.info('Provided explainer is not the same as previously evaluted one. Please reset the evaluated '
|
||||
'results. Previous explainer: %s, current explainer: %s', self._explainer, explainer)
|
||||
self._explainer = explainer
|
||||
|
||||
|
||||
class LabelAgnosticMetric(AttributionMetric):
|
||||
"""Super class add functions for label-agnostic metric."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._global_results = []
|
||||
|
||||
@property
|
||||
def performance(self) -> float:
|
||||
"""
|
||||
Return the average evaluation result.
|
||||
|
||||
Return:
|
||||
float, averaged result. If no result is aggregate in the global_results, 0.0 will be returned.
|
||||
"""
|
||||
if not self._global_results:
|
||||
return 0.0
|
||||
results_sum = sum(self._global_results)
|
||||
count = len(self._global_results)
|
||||
return results_sum / count
|
||||
|
||||
def aggregate(self, result):
|
||||
"""Aggregate single evaluation result to global results."""
|
||||
if isinstance(result, float):
|
||||
self._global_results.append(result)
|
||||
elif isinstance(result, (ms.Tensor, np.ndarray)):
|
||||
result = format_tensor_to_ndarray(result)
|
||||
self._global_results.append(float(result))
|
||||
else:
|
||||
raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))
|
||||
|
||||
def get_results(self):
|
||||
"""Return the gloabl results."""
|
||||
return self._global_results.copy()
|
||||
|
||||
def reset(self):
|
||||
"""Reset global results."""
|
||||
self._global_results.clear()
|
||||
|
||||
def _check_evaluate_param(self, explainer, inputs):
|
||||
"""Check the evaluate parameters."""
|
||||
check_value_type('explainer', explainer, Attribution)
|
||||
self._record_explainer(explainer)
|
||||
verify_argument(inputs, 'inputs')
|
||||
|
||||
|
||||
class LabelSensitiveMetric(AttributionMetric):
|
||||
"""Super class add functions for label-sensitive metrics."""
|
||||
|
||||
def __init__(self, num_labels: int):
|
||||
super().__init__()
|
||||
LabelSensitiveMetric._verify_params(num_labels)
|
||||
self._num_labels = num_labels
|
||||
self._global_results = {i: [] for i in range(num_labels)}
|
||||
|
||||
|
@ -57,10 +134,6 @@ class AttributionMetric:
|
|||
if num_labels < 1:
|
||||
raise ValueError("Argument num_labels must be parsed with a integer > 0.")
|
||||
|
||||
def evaluate(self, explainer, inputs, targets, saliency=None):
|
||||
"""This function evaluates on a single sample and return the result."""
|
||||
raise NotImplementedError
|
||||
|
||||
def aggregate(self, result, targets):
|
||||
"""Aggregates single result to global_results."""
|
||||
if isinstance(result, float):
|
||||
|
@ -120,11 +193,12 @@ class AttributionMetric:
|
|||
|
||||
def get_results(self):
|
||||
"""Global result of the metric can be return"""
|
||||
return self._global_results
|
||||
return copy.deepcopy(self._global_results)
|
||||
|
||||
def _check_evaluate_param(self, explainer, inputs, targets, saliency):
|
||||
"""Check the evaluate parameters."""
|
||||
check_value_type('explainer', explainer, Attribution)
|
||||
self._record_explainer(explainer)
|
||||
verify_argument(inputs, 'inputs')
|
||||
output = explainer.model(inputs)
|
||||
check_value_type("output of explainer model", output, Tensor)
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Robustness."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import log
|
||||
from .metric import LabelSensitiveMetric
|
||||
from ...explanation._attribution import Attribution
|
||||
from ...explanation._attribution._perturbation.replacement import RandomPerturb
|
||||
|
||||
_Array = np.ndarray
|
||||
_Label = Union[ms.Tensor, int]
|
||||
|
||||
|
||||
class Robustness(LabelSensitiveMetric):
|
||||
"""
|
||||
Robustness perturbs the inputs by adding random noise and choose the maximum sensitivity as evaluation score from
|
||||
the perturbations.
|
||||
|
||||
Args:
|
||||
num_labels (int): Number of classes in the dataset.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer.benchmark import Robustness
|
||||
>>> num_labels = 100
|
||||
>>> robustness = Robustness(num_labels)
|
||||
"""
|
||||
|
||||
def __init__(self, num_labels: int, activation_fn=nn.Softmax()):
|
||||
super().__init__(num_labels)
|
||||
|
||||
self._perturb = RandomPerturb()
|
||||
self._num_perturbations = 100 # number of perturbations used in evaluation
|
||||
self._threshold = 0.1 # threshold to generate perturbation
|
||||
self._activation_fn = activation_fn
|
||||
|
||||
def evaluate(self,
|
||||
explainer: Attribution,
|
||||
inputs: Tensor,
|
||||
targets: _Label,
|
||||
saliency: Optional[Tensor] = None
|
||||
) -> _Array:
|
||||
"""
|
||||
Evaluate robustness on single sample.
|
||||
|
||||
Note:
|
||||
Currently only single sample (:math:`N=1`) at each call is supported.
|
||||
|
||||
Args:
|
||||
explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
|
||||
inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If `targets` is a 1D tensor, its length should be the same as `inputs`.
|
||||
saliency (Tensor, optional): The saliency map to be evaluated, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
If it is None, the parsed `explainer` will generate the saliency map with `inputs` and `targets` and
|
||||
continue the evaluation. Default: None.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, 1D array of shape :math:`(N,)`, result of localization evaluated on `explainer`.
|
||||
|
||||
Raises:
|
||||
ValueError: If batch_size is larger than 1.
|
||||
|
||||
Examples:
|
||||
>>> # init an explainer, the network should contain the output activation function.
|
||||
>>> from mindspore.explainer.explanation import Gradient
|
||||
>>> from mindspore.explainer.benchmark import Robustness
|
||||
>>> gradient = Gradient(network)
|
||||
>>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> target_label = 5
|
||||
>>> robustness = Robustness(num_labels=10)
|
||||
>>> res = robustness.evaluate(gradient, input_x, target_label)
|
||||
"""
|
||||
|
||||
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||
if inputs.shape[0] > 1:
|
||||
raise ValueError('Robustness only support a sample each time, but receive {}'.format(inputs.shape[0]))
|
||||
|
||||
inputs_np = inputs.asnumpy()
|
||||
if isinstance(targets, int):
|
||||
targets = ms.Tensor(targets, ms.int32)
|
||||
if saliency is None:
|
||||
saliency = explainer(inputs, targets)
|
||||
saliency_np = saliency.asnumpy()
|
||||
norm = np.sqrt(np.sum(np.square(saliency_np), axis=tuple(range(1, len(saliency_np.shape)))))
|
||||
if norm == 0:
|
||||
log.warning('Get saliency norm equals 0, robustness return NaN for zero-norm saliency currently.')
|
||||
return np.array([np.nan])
|
||||
|
||||
perturbations = []
|
||||
for sample in inputs_np:
|
||||
sample = np.expand_dims(sample, axis=0)
|
||||
perturbations_per_input = []
|
||||
for _ in range(self._num_perturbations):
|
||||
perturbation = self._perturb(sample)
|
||||
perturbations_per_input.append(perturbation)
|
||||
perturbations_per_input = np.vstack(perturbations_per_input)
|
||||
perturbations.append(perturbations_per_input)
|
||||
perturbations = np.stack(perturbations, axis=0)
|
||||
|
||||
perturbations = np.reshape(perturbations, (-1,) + inputs_np.shape[1:])
|
||||
perturbations = ms.Tensor(perturbations, ms.float32)
|
||||
|
||||
repeated_targets = np.repeat(targets.asnumpy(), repeats=self._num_perturbations, axis=0)
|
||||
repeated_targets = ms.Tensor(repeated_targets, ms.int32)
|
||||
saliency_of_perturbations = explainer(perturbations, repeated_targets)
|
||||
perturbations_saliency = saliency_of_perturbations.asnumpy()
|
||||
|
||||
repeated_saliency = np.repeat(saliency_np, repeats=self._num_perturbations, axis=0)
|
||||
|
||||
sensitivities = np.sum((repeated_saliency - perturbations_saliency) ** 2,
|
||||
axis=tuple(range(1, len(repeated_saliency.shape))))
|
||||
|
||||
max_sensitivity = np.max(sensitivities.reshape((norm.shape[0], -1)), axis=1) / norm
|
||||
robustness_res = 1 / np.exp(max_sensitivity)
|
||||
return robustness_res
|
|
@ -14,9 +14,10 @@
|
|||
# ============================================================================
|
||||
"""Predefined Attribution explainers."""
|
||||
|
||||
from ._attribution._backprop.gradcam import GradCAM
|
||||
from ._attribution._backprop.gradient import Gradient
|
||||
from ._attribution._backprop.gradcam import GradCAM
|
||||
from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop
|
||||
from ._attribution._perturbation.occlusion import Occlusion
|
||||
from ._attribution._perturbation.rise import RISE
|
||||
|
||||
__all__ = [
|
||||
|
@ -24,5 +25,6 @@ __all__ = [
|
|||
'Deconvolution',
|
||||
'GuidedBackprop',
|
||||
'GradCAM',
|
||||
'RISE'
|
||||
'Occlusion',
|
||||
'RISE',
|
||||
]
|
||||
|
|
|
@ -13,15 +13,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Predefined Attribution explainers."""
|
||||
from ._backprop.gradcam import GradCAM
|
||||
from ._backprop.gradient import Gradient
|
||||
from ._backprop.modified_relu import Deconvolution, GuidedBackprop
|
||||
from ._perturbation.rise import RISE
|
||||
|
||||
from .attribution import Attribution
|
||||
|
||||
__all__ = [
|
||||
'Gradient',
|
||||
'Deconvolution',
|
||||
'GuidedBackprop',
|
||||
'GradCAM',
|
||||
'RISE'
|
||||
'Attribution'
|
||||
]
|
||||
|
|
|
@ -13,12 +13,3 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Backprop-base _attribution explainer."""
|
||||
|
||||
from .gradient import Gradient
|
||||
from .gradcam import GradCAM
|
||||
from .modified_relu import Deconvolution, GuidedBackprop
|
||||
|
||||
__all__ = ['Gradient',
|
||||
'GradCAM',
|
||||
'Deconvolution',
|
||||
'GuidedBackprop']
|
||||
|
|
|
@ -22,7 +22,6 @@ from .intermediate_layer import IntermediateLayerAttribution
|
|||
from ...._utils import ForwardProbe, retrieve_layer, unify_inputs, unify_targets
|
||||
|
||||
|
||||
|
||||
def _gradcam_aggregation(attributions):
|
||||
"""
|
||||
Aggregate the gradient and activation to get the final _attribution.
|
||||
|
@ -76,10 +75,7 @@ class GradCAM(IntermediateLayerAttribution):
|
|||
>>> gradcam = GradCAM(net, layer=layer_name)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
network,
|
||||
layer=""):
|
||||
def __init__(self, network, layer=""):
|
||||
super(GradCAM, self).__init__(network, layer)
|
||||
|
||||
self._saliency_cell = retrieve_layer(self._backward_model, target_layer=layer)
|
||||
|
|
|
@ -16,12 +16,11 @@
|
|||
from copy import deepcopy
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as op
|
||||
from mindspore.train._utils import check_value_type
|
||||
from ...._operators import reshape, sqrt, Tensor
|
||||
from .._attribution import Attribution
|
||||
from ..attribution import Attribution
|
||||
from .backprop_utils import compute_gradients
|
||||
from ...._utils import unify_inputs, unify_targets
|
||||
from ...._utils import abs_max, unify_inputs, unify_targets
|
||||
|
||||
|
||||
def _get_hook(bntype, cache):
|
||||
|
@ -41,16 +40,6 @@ def _get_hook(bntype, cache):
|
|||
return reset_gradient
|
||||
|
||||
|
||||
def _abs_max(gradients):
|
||||
"""
|
||||
Transform gradients to saliency through abs then take max along
|
||||
channels.
|
||||
"""
|
||||
gradients = op.Abs()(gradients)
|
||||
saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1)
|
||||
return saliency
|
||||
|
||||
|
||||
class Gradient(Attribution):
|
||||
r"""
|
||||
Provides Gradient explanation method.
|
||||
|
@ -85,8 +74,7 @@ class Gradient(Attribution):
|
|||
self._backward_model.set_grad(False)
|
||||
self._hook_bn()
|
||||
self._grad_op = compute_gradients
|
||||
self._aggregation_fn = _abs_max
|
||||
|
||||
self._aggregation_fn = abs_max
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
"""
|
||||
|
|
|
@ -13,7 +13,3 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" Perturbation-based _attribution explainer. """
|
||||
|
||||
from .rise import RISE
|
||||
|
||||
__all__ = ['RISE']
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Modules to ablate images."""
|
||||
|
||||
__all__ = [
|
||||
'Ablation',
|
||||
'AblationWithSaliency',
|
||||
]
|
||||
|
||||
import math
|
||||
from functools import reduce
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .replacement import Constant
|
||||
from ...._utils import rank_pixels
|
||||
|
||||
|
||||
class Ablation:
|
||||
"""Base class to ablate image based on given replacement."""
|
||||
|
||||
def __init__(self, perturb_mode: str):
|
||||
self._perturb_mode = perturb_mode
|
||||
|
||||
def __call__(self,
|
||||
inputs: np.array,
|
||||
reference: Union[np.array, float],
|
||||
masks: np.array
|
||||
) -> np.array:
|
||||
|
||||
"""
|
||||
Generate perturbations of given array.
|
||||
|
||||
Args:
|
||||
inputs (np.ndarray): Input array to perturb. The first dim of inputs is assumed to be the batch size, i.e.,
|
||||
number of samples.
|
||||
reference (np.ndarray or float): Array of values to replace the elements in the original inputs. The shape
|
||||
of reference must math the inputs. If scalar is provided, the perturbed elements will be assigned the
|
||||
given value..
|
||||
masks (np.ndarray): Several boolean array to mark the perturbed positions. True marks the pixels to be
|
||||
perturbed, otherwise the pixels will be kept. The shape of masks is assumed to be
|
||||
[batch_size, num_perturbations, inputs_shape[1:]].
|
||||
|
||||
Return:
|
||||
perturbations (np.ndarray)
|
||||
"""
|
||||
if isinstance(reference, float):
|
||||
reference = Constant(base_value=reference)(inputs)
|
||||
|
||||
if not np.array_equal(inputs.shape, reference.shape):
|
||||
raise ValueError('reference must have the same shape as inputs.')
|
||||
|
||||
num_perturbations = masks.shape[1]
|
||||
|
||||
if self._perturb_mode == 'Insertion':
|
||||
inputs, reference = reference, inputs
|
||||
|
||||
perturbations = np.repeat(inputs[:, None, :], num_perturbations, 1)
|
||||
reference = np.repeat(reference[:, None, :], num_perturbations, 1)
|
||||
Ablation._assign(perturbations, reference, masks)
|
||||
|
||||
return perturbations
|
||||
|
||||
@staticmethod
|
||||
def _assign(original_array: np.ndarray, replacement: np.ndarray, masks: np.ndarray):
|
||||
"""Assign values to perturb pixels on perturbations."""
|
||||
if masks.dtype != bool:
|
||||
raise TypeError('The param "masks" should be an array of bool, but receive {}'.format(masks.dtype))
|
||||
|
||||
if not np.array_equal(original_array.shape, masks.shape):
|
||||
raise ValueError('masks must have the shape {} same as [batch_size, num_perturbations, inputs.shape[1:],'
|
||||
'but receive {}.'.format(original_array.shape, masks.shape))
|
||||
|
||||
original_array[masks] = replacement[masks]
|
||||
|
||||
|
||||
class AblationWithSaliency(Ablation):
|
||||
"""
|
||||
Perturbation generator to generate perturbations for a given array.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_mode (str): specify perturbing mode, through deleting or
|
||||
inserting pixels. Current support: ['Deletion', 'Insertion'].
|
||||
is_accumulate (bool): whether to accumulate the former perturbations to
|
||||
the later perturbations.
|
||||
perturb_pixel_per_step (int, optional): number of pixel to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculate by:
|
||||
num_image_pixel * perturb_percent / num_perturb_steps.
|
||||
Default: None
|
||||
num_perturbations (int, optional): number of perturbations. If
|
||||
num_perturbations if None, it will be calculated by:
|
||||
num_image_pixel * perturb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_mode: str,
|
||||
perturb_percent: float = 1.0,
|
||||
is_accumulate: bool = False,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None):
|
||||
super().__init__(perturb_mode)
|
||||
self._perturb_percent = perturb_percent
|
||||
self._perturb_mode = perturb_mode
|
||||
self._pixel_per_step = perturb_pixel_per_step
|
||||
self._num_perturbations = num_perturbations
|
||||
self._is_accumulate = is_accumulate
|
||||
|
||||
def generate_mask(self,
|
||||
saliency: np.ndarray,
|
||||
num_channels: Optional[int] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate mask for perturbations based on given saliency ranks.
|
||||
|
||||
Args:
|
||||
saliency (np.ndarray): Perturbing masks will be generated based on the given saliency map. The shape of
|
||||
saliency is expected to be: [batch_size, optional(num_channels), *spatial_size]. If multi-channel
|
||||
saliency is provided, an averaged saliency will be taken to calculate pixel order in spatial dimension.
|
||||
num_channels (optional[int]): Number of channels of the input data. In order to match the shape of inputs,
|
||||
num_channels should be provided when input data have channels dimension, even if num_channel. If None is
|
||||
provided, the inputs is assumed to be no-channel data, and the generated mask will have no channel
|
||||
dimension. Default: None.
|
||||
|
||||
Return:
|
||||
mask (np.ndarray): boolen mask for generate perturbations.
|
||||
"""
|
||||
|
||||
batch_size = saliency.shape[0]
|
||||
expected_num_dim = len(saliency.shape) + 1
|
||||
has_channel = num_channels is not None
|
||||
num_channels = 1 if num_channels is None else num_channels
|
||||
|
||||
if has_channel:
|
||||
saliency = saliency.mean(axis=1)
|
||||
saliency_rank = rank_pixels(saliency, descending=True)
|
||||
|
||||
num_pixels = reduce(lambda x, y: x * y, saliency.shape[1:])
|
||||
|
||||
if self._pixel_per_step:
|
||||
pixel_per_step = self._pixel_per_step
|
||||
num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step)
|
||||
elif self._num_perturbations:
|
||||
pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations)
|
||||
num_perturbations = self._num_perturbations
|
||||
else:
|
||||
raise ValueError("Must provide either pixel_per_step or num_perturbations.")
|
||||
|
||||
masks = np.zeros((batch_size, num_perturbations, num_channels, saliency_rank.shape[1], saliency_rank.shape[2]),
|
||||
dtype=np.bool)
|
||||
|
||||
factor = 0 if self._is_accumulate else 1
|
||||
|
||||
for i in range(batch_size):
|
||||
low_bound = 0
|
||||
up_bound = low_bound + pixel_per_step
|
||||
for j in range(num_perturbations):
|
||||
masks[i, j, :, ((saliency_rank[i] >= low_bound) & (saliency_rank[i] < up_bound))] = True
|
||||
low_bound = up_bound + factor
|
||||
up_bound += pixel_per_step
|
||||
|
||||
masks = masks if has_channel else np.squeeze(masks, axis=2)
|
||||
|
||||
if len(masks.shape) == expected_num_dim:
|
||||
return masks
|
||||
raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect {expected_num_dim}-dim.')
|
|
@ -0,0 +1,166 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Occlusion explainer."""
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.lib.stride_tricks import as_strided
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
from .ablation import Ablation
|
||||
from .perturbation import PerturbationAttribution
|
||||
from .replacement import Constant
|
||||
from ...._utils import abs_max
|
||||
|
||||
_Array = np.ndarray
|
||||
_Label = Union[int, Tensor]
|
||||
|
||||
|
||||
def _generate_patches(array, window_size, stride):
|
||||
"""View as windows."""
|
||||
if not isinstance(array, np.ndarray):
|
||||
raise TypeError("`array` must be a numpy ndarray")
|
||||
|
||||
arr_shape = np.array(array.shape)
|
||||
window_size = np.array(window_size, dtype=arr_shape.dtype)
|
||||
|
||||
slices = tuple(slice(None, None, st) for st in stride)
|
||||
window_strides = np.array(array.strides)
|
||||
|
||||
indexing_strides = array[slices].strides
|
||||
win_indices_shape = (((np.array(array.shape) - np.array(window_size)) // np.array(stride)) + 1)
|
||||
|
||||
new_shape = tuple(list(win_indices_shape) + list(window_size))
|
||||
strides = tuple(list(indexing_strides) + list(window_strides))
|
||||
|
||||
patches = as_strided(array, shape=new_shape, strides=strides)
|
||||
return patches
|
||||
|
||||
|
||||
class Occlusion(PerturbationAttribution):
|
||||
r"""
|
||||
Occlusion uses a sliding window to replace the pixels with a reference value (e.g. constant value), and computes
|
||||
the output difference w.r.t the original output. The output difference caused by perturbed pixels are assigned as
|
||||
feature importance to those pixels. For pixels involved in multiple sliding windows, the feature importance is the
|
||||
averaged differences from multiple sliding windows.
|
||||
|
||||
For more details, please refer to the original paper via: `<https://arxiv.org/abs/1311.2901>`_.
|
||||
|
||||
Args:
|
||||
network (Cell): Specify the black-box model to be explained.
|
||||
|
||||
Inputs:
|
||||
inputs (Tensor): The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
|
||||
targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
|
||||
If it is a 1D tensor, its length should be the same as `inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`.
|
||||
|
||||
Example:
|
||||
>>> from mindspore.explainer.explanation import Occlusion
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> occlusion = Occlusion(net)
|
||||
>>> x = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
||||
>>> label = 1
|
||||
>>> saliency = occlusion(x, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network: Cell, activation_fn: Cell = nn.Softmax()):
|
||||
super().__init__(network, activation_fn)
|
||||
|
||||
self._ablation = Ablation(perturb_mode='Deletion')
|
||||
self._aggregation_fn = abs_max
|
||||
self._get_replacement = Constant(base_value=0.0)
|
||||
self._num_sample_per_dim = 32 # specify the number of perturbations each dimension.
|
||||
self._num_per_eval = 32 # number of perturbations each evaluation step.
|
||||
|
||||
def __call__(self, inputs: Tensor, targets: _Label) -> Tensor:
|
||||
"""Call function for 'Occlusion'."""
|
||||
self._verify_data(inputs, targets)
|
||||
|
||||
inputs = inputs.asnumpy()
|
||||
targets = targets.asnumpy() if isinstance(targets, Tensor) else np.array([targets] * inputs.shape[0], np.int)
|
||||
|
||||
# If spatial size of input data is smaller than self._num_sample_per_dim, window_size and strides will set to
|
||||
# `(C, 3, 3)` and `(C, 1, 1)` separately.
|
||||
window_size = tuple(
|
||||
[inputs.shape[1]]
|
||||
+ [x % self._num_sample_per_dim if x > self._num_sample_per_dim else 3 for x in inputs.shape[2:]])
|
||||
strides = tuple(
|
||||
[inputs.shape[1]]
|
||||
+ [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 1 for x in inputs.shape[2:]])
|
||||
|
||||
model = nn.SequentialCell([self._model, self._activation_fn])
|
||||
|
||||
original_outputs = model(Tensor(inputs, ms.float32)).asnumpy()[np.arange(len(targets)), targets]
|
||||
|
||||
total_attribution = np.zeros_like(inputs)
|
||||
weights = np.ones_like(inputs)
|
||||
masks = Occlusion._generate_masks(inputs, window_size, strides)
|
||||
num_perturbations = masks.shape[1]
|
||||
original_outputs_repeat = np.repeat(original_outputs, repeats=num_perturbations, axis=0)
|
||||
|
||||
reference = self._get_replacement(inputs)
|
||||
occluded_inputs = self._ablation(inputs, reference, masks)
|
||||
targets_repeat = np.repeat(targets, repeats=num_perturbations, axis=0)
|
||||
|
||||
occluded_inputs = occluded_inputs.reshape((-1, *inputs.shape[1:]))
|
||||
if occluded_inputs.shape[0] > self._num_per_eval:
|
||||
cal_time = math.ceil(occluded_inputs.shape[0] / self._num_per_eval)
|
||||
occluded_outputs = []
|
||||
for i in range(cal_time):
|
||||
occluded_input = occluded_inputs[i*self._num_per_eval
|
||||
:min((i+1) * self._num_per_eval, occluded_inputs.shape[0])]
|
||||
target = targets_repeat[i*self._num_per_eval
|
||||
:min((i+1) * self._num_per_eval, occluded_inputs.shape[0])]
|
||||
occluded_output = model(Tensor(occluded_input)).asnumpy()[np.arange(target.shape[0]), target]
|
||||
occluded_outputs.append(occluded_output)
|
||||
occluded_outputs = np.concatenate(occluded_outputs)
|
||||
else:
|
||||
occluded_outputs = model(Tensor(occluded_inputs)).asnumpy()[np.arange(len(targets_repeat)), targets_repeat]
|
||||
outputs_diff = original_outputs_repeat - occluded_outputs
|
||||
outputs_diff = outputs_diff.reshape(inputs.shape[0], -1)
|
||||
|
||||
total_attribution += (
|
||||
outputs_diff.reshape(outputs_diff.shape + (1,) * (len(masks.shape) - 2)) * masks).sum(axis=1).clip(1e-6)
|
||||
weights += masks.sum(axis=1)
|
||||
|
||||
attribution = self._aggregation_fn(ms.Tensor(total_attribution / weights))
|
||||
return attribution
|
||||
|
||||
@staticmethod
|
||||
def _generate_masks(inputs: Tensor, window_size: Tuple[int, ...], strides: Tuple[int, ...]) -> _Array:
|
||||
"""Generate masks to perturb contiguous regions."""
|
||||
total_dim = np.prod(inputs.shape[1:]).item()
|
||||
template = np.arange(total_dim).reshape(inputs.shape[1:])
|
||||
indices = _generate_patches(template, window_size, strides)
|
||||
num_perturbations = indices.reshape((-1,) + window_size).shape[0]
|
||||
indices = indices.reshape(num_perturbations, -1)
|
||||
|
||||
mask = np.zeros((num_perturbations, total_dim), dtype=np.bool)
|
||||
for i in range(num_perturbations):
|
||||
mask[i, indices[i]] = True
|
||||
mask = mask.reshape((num_perturbations,) + inputs.shape[1:])
|
||||
|
||||
masks = np.tile(mask, reps=(inputs.shape[0],) + (1,) * len(mask.shape))
|
||||
return masks
|
|
@ -18,7 +18,7 @@
|
|||
from mindspore.train._utils import check_value_type
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from .._attribution import Attribution
|
||||
from ..attribution import Attribution
|
||||
from ...._operators import softmax
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Modules to generate perturbations."""
|
||||
|
||||
import numpy as np
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
|
||||
_Array = np.ndarray
|
||||
|
||||
__all__ = [
|
||||
'BaseReplacement',
|
||||
'Constant',
|
||||
'GaussianBlur',
|
||||
'RandomPerturb',
|
||||
]
|
||||
|
||||
|
||||
class BaseReplacement:
|
||||
"""
|
||||
Base class of generator for generating different replacement for perturbations.
|
||||
|
||||
Args:
|
||||
kwargs: Optional args for generating replacement. Derived class need to
|
||||
add necessary arg names and default value to '_necessary_args'.
|
||||
If the argument has no default value, the value should be set to
|
||||
'EMPTY' to mark the required args. Initializing an object will
|
||||
check the given kwargs w.r.t '_necessary_args'.
|
||||
|
||||
Raise:
|
||||
ValueError: Raise when provided kwargs not contain necessary arg names with 'EMPTY' mark.
|
||||
"""
|
||||
_necessary_args = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._replace_args = self._necessary_args.copy()
|
||||
for key, value in self._replace_args.items():
|
||||
if key in kwargs.keys():
|
||||
self._replace_args[key] = kwargs[key]
|
||||
elif key not in kwargs.keys() and value == 'EMPTY':
|
||||
raise ValueError(f"Missing keyword arg {key} for {self.__class__.__name__}.")
|
||||
|
||||
def __call__(self, inputs):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Constant(BaseReplacement):
|
||||
"""Generator to provide constant-value replacement for perturbations."""
|
||||
_necessary_args = {'base_value': 'EMPTY'}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
replacement = np.ones_like(inputs, dtype=np.float32)
|
||||
replacement *= self._replace_args['base_value']
|
||||
return replacement
|
||||
|
||||
|
||||
class GaussianBlur(BaseReplacement):
|
||||
"""Generator to provided gaussian blurred inputs for perturbation"""
|
||||
_necessary_args = {'sigma': 0.7}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
sigma = self._replace_args['sigma']
|
||||
replacement = gaussian_filter(inputs, sigma=sigma)
|
||||
return replacement
|
||||
|
||||
|
||||
class RandomPerturb(BaseReplacement):
|
||||
"""Generator to provide replacement by randomly adding noise."""
|
||||
_necessary_args = {'radius': 0.2}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
radius = self._replace_args['radius']
|
||||
outputs = inputs + (2 * np.random.rand(*inputs.shape) - 1) * radius
|
||||
return outputs
|
|
@ -64,6 +64,9 @@ class RISE(PerturbationAttribution):
|
|||
activation_fn=nn.Softmax(),
|
||||
perturbation_per_eval=32):
|
||||
super(RISE, self).__init__(network, activation_fn)
|
||||
check_value_type('perturbation_per-eval', perturbation_per_eval, int)
|
||||
if perturbation_per_eval <= 0:
|
||||
raise ValueError('perturbation_per_eval should be postive integer.')
|
||||
self._perturbation_per_eval = perturbation_per_eval
|
||||
|
||||
self._num_masks = 6000 # number of masks to be sampled
|
||||
|
@ -156,12 +159,11 @@ class RISE(PerturbationAttribution):
|
|||
targets = self._unify_targets(inputs, targets)
|
||||
attr_classes = []
|
||||
for idx, target in enumerate(targets):
|
||||
dtype = inputs.dtype
|
||||
attr_np_idx = attr_np[idx]
|
||||
attr_idx = attr_np_idx[target]
|
||||
attr_classes.append(attr_idx)
|
||||
|
||||
return op.Tensor(attr_classes, dtype=dtype)
|
||||
return op.Tensor(attr_classes, dtype=inputs.dtype)
|
||||
|
||||
@staticmethod
|
||||
def _verify_data(inputs, targets):
|
||||
|
@ -183,7 +185,7 @@ class RISE(PerturbationAttribution):
|
|||
def _unify_targets(inputs, targets):
|
||||
"""To unify targets to be 2D numpy.ndarray."""
|
||||
if isinstance(targets, int):
|
||||
return np.array([[targets] for i in inputs]).astype(np.int)
|
||||
return np.array([[targets] for _ in inputs]).astype(np.int)
|
||||
if isinstance(targets, Tensor):
|
||||
if not targets.shape:
|
||||
return np.array([[targets.asnumpy()] for _ in inputs]).astype(np.int)
|
||||
|
|
|
@ -16,8 +16,10 @@
|
|||
|
||||
from typing import Callable
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore.train._utils import check_value_type
|
||||
from mindspore.nn import Cell
|
||||
|
||||
|
||||
class Attribution:
|
||||
"""
|
||||
|
@ -26,15 +28,20 @@ class Attribution:
|
|||
The explainers which explanation through attributing the relevance scores should inherit this class.
|
||||
|
||||
Args:
|
||||
network (Cell): The black-box model to explain.
|
||||
network (nn.Cell): The black-box model to explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
check_value_type("network", network, Cell)
|
||||
check_value_type("network", network, nn.Cell)
|
||||
self._model = network
|
||||
self._model.set_train(False)
|
||||
self._model.set_grad(False)
|
||||
|
||||
@staticmethod
|
||||
def _verify_model(model):
|
||||
"""Verify the input `network` for __init__ function."""
|
||||
if not isinstance(model, nn.Cell):
|
||||
raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.")
|
||||
|
||||
__call__: Callable
|
||||
"""
|
||||
|
@ -51,4 +58,17 @@ class Attribution:
|
|||
|
||||
@property
|
||||
def model(self):
|
||||
"""Return the model."""
|
||||
return self._model
|
||||
|
||||
@staticmethod
|
||||
def _verify_data(inputs, targets):
|
||||
"""Verify the validity of the parsed inputs."""
|
||||
check_value_type('inputs', inputs, ms.Tensor)
|
||||
if len(inputs.shape) != 4:
|
||||
raise ValueError('Argument inputs must be 4D Tensor')
|
||||
check_value_type('targets', targets, (ms.Tensor, int))
|
||||
if isinstance(targets, ms.Tensor):
|
||||
if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)):
|
||||
raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
|
||||
'it should have the same length as inputs.')
|
Loading…
Reference in New Issue