!8844 Add rise explanation method into explainer

From: @lixiaohui33
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2020-11-21 16:44:50 +08:00 committed by Gitee
commit 8675e7fade
10 changed files with 344 additions and 77 deletions

View File

@ -18,9 +18,8 @@ from typing import List, Tuple, Union, Callable
import numpy as np
import mindspore
from mindspore import nn
import mindspore.ops.operations as op
from mindspore import nn
_Axis = Union[int, Tuple[int, ...], List[int]]
_Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]]
@ -235,7 +234,7 @@ def randint(low: int, high: int, shape: _Shape, dtype: mindspore.dtype = mindspo
return outputs
def softmax(axis: int) -> Callable:
def softmax(axis: int = -1) -> Callable:
"""Softmax activation function."""
func = nn.Softmax(axis=axis)
return func
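For reference, a minimal doctest-style sketch of how this factory may be used (the logits tensor here is illustrative, not part of the change):
>>> import numpy as np
>>> import mindspore as ms
>>> act = softmax()  # now defaults to the last axis
>>> logits = ms.Tensor(np.random.rand(2, 10), ms.float32)
>>> probs = act(logits)  # shape (2, 10); each row sums to 1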

View File

@ -20,20 +20,23 @@ from time import time
from typing import Tuple, List, Optional
import numpy as np
from scipy.stats import beta
from PIL import Image
from scipy.stats import beta
from mindspore.train._utils import check_value_type
from mindspore.train.summary_pb2 import Explain
import mindspore as ms
import mindspore.dataset as ds
from mindspore import log
from mindspore.nn import Softmax, Cell
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
from mindspore.ops.operations import ExpandDims
from mindspore.train._utils import check_value_type
from mindspore.train.summary._summary_adapter import _convert_image_format
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
from mindspore.train.summary_pb2 import Explain
from .benchmark import Localization
from .benchmark._attribution.metric import AttributionMetric
from .explanation import RISE
from .explanation._attribution._attribution import Attribution
# datafile directory names
@ -43,8 +46,8 @@ _HEATMAP_DIRNAME = "heatmap"
# max. number of samples per directory
_SAMPLE_PER_DIR = 1000
_EXPAND_DIMS = ExpandDims()
_SEED = 58 # seed to fix the iteration order of the dataset
def _normalize(img_np):
@ -57,7 +60,7 @@ def _normalize(img_np):
def _np_to_image(img_np, mode):
"""Convert numpy array to PIL image."""
return Image.fromarray(np.uint8(img_np*255), mode=mode)
return Image.fromarray(np.uint8(img_np * 255), mode=mode)
def _calc_prob_interval(volume, probs, prob_vars):
@ -89,7 +92,7 @@ def _calc_prob_interval(volume, probs, prob_vars):
def _get_id_dirname(sample_id: int):
"""Get the name of parent directory of the image id."""
return str(int(sample_id/_SAMPLE_PER_DIR)*_SAMPLE_PER_DIR)
return str(int(sample_id / _SAMPLE_PER_DIR) * _SAMPLE_PER_DIR)
def _extract_timestamp(filename: str):
@ -107,6 +110,9 @@ class ExplainRunner:
After generating results with the explanation methods and the evaluation methods, the results will be written into
a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight.
Update in 2020.11: Adjusted the storage structure and format of the data. Summary files generated by earlier
versions are deprecated and are not supported by the current version of MindInsight.
Args:
summary_dir (str, optional): The directory path to save the summary files which store the generated results.
Default: "./"
@ -131,7 +137,8 @@ class ExplainRunner:
dataset: Tuple,
explainers: List,
benchmarkers: Optional[List] = None,
uncertainty: Optional[UncertaintyEvaluation] = None):
uncertainty: Optional[UncertaintyEvaluation] = None,
activation_fn: Optional[Cell] = Softmax()):
"""
Generates results and writes them into the summary files under `summary_dir` specified during the object
initialization.
@ -149,8 +156,12 @@ class ExplainRunner:
Default: None
uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
uncertainty of samples.
activation_fn (Cell, optional): The activation layer that transforms the output of the network to a
label probability distribution :math:`P(y|x)`. Default: Softmax().
Examples:
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
>>> from mindspore.nn import Sigmoid
>>> # obtain dataset object
>>> dataset = get_dataset()
>>> classes = ["cat", "dog", ...]
@ -158,13 +169,11 @@ class ExplainRunner:
>>> param_dict = load_checkpoint("checkpoint.ckpt")
>>> net = resnet50(len(classes))
>>> load_param_into_net(net, param_dict)
>>> # bind net with its output activation
>>> model = nn.SequentialCell([net, nn.Sigmoid()])
>>> gbp = GuidedBackprop(model)
>>> gradient = Gradient(model)
>>> gbp = GuidedBackprop(net)
>>> gradient = Gradient(net)
>>> runner = ExplainRunner("./")
>>> explainers = [gbp, gradient]
>>> runner.run((dataset, classes), explainers)
>>> runner.run((dataset, classes), explainers, activation_fn=Sigmoid())
"""
check_value_type("dataset", dataset, tuple)
@ -181,16 +190,17 @@ class ExplainRunner:
for exp in explainers:
if not isinstance(exp, Attribution):
raise TypeError("Argument explainers should be a list of objects of classes in "
raise TypeError("Argument `explainers` should be a list of objects of classes in "
"`mindspore.explainer.explanation`.")
if benchmarkers is not None:
check_value_type("benchmarkers", benchmarkers, list)
for bench in benchmarkers:
if not isinstance(bench, AttributionMetric):
raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation"
raise TypeError("Argument `benchmarkers` should be a list of objects of classes in explanation"
"`mindspore.explainer.benchmark`.")
check_value_type("activation_fn", activation_fn, Cell)
self._model = explainers[0].model
self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn])
next_element = dataset.create_tuple_iterator().get_next()
inputs, _, _ = self._unpack_next_element(next_element)
prop_test = self._model(inputs)
@ -211,9 +221,10 @@ class ExplainRunner:
self._uncertainty = None
with SummaryRecord(self._summary_dir) as summary:
spacer = '{:120}\r'
print("Start running and writing......")
begin = time()
print("Start writing metadata.")
print("Start writing metadata......")
self._summary_timestamp = _extract_timestamp(summary.event_file_name)
if self._summary_timestamp is None:
@ -234,42 +245,47 @@ class ExplainRunner:
print("Finish writing metadata.")
now = time()
print("Start running and writing inference data......")
print("Start running and writing inference data.....")
imageid_labels = self._run_inference(dataset, summary)
print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now))
print(spacer.format("Finish running and writing inference data. "
"Time elapsed: {:.3f} s".format(time() - now)))
if benchmarkers is None or not benchmarkers:
for exp in explainers:
start = time()
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
self._count = 0
ds.config.set_seed(58)
ds.config.set_seed(_SEED)
for idx, next_element in enumerate(dataset):
now = time()
self._run_exp_step(next_element, exp, imageid_labels, summary)
print("Finish writing {}-th explanation data. Time elapsed: {}".format(
idx, time() - now))
print("Finish running and writing explanation data for {}. Time elapsed: {}".format(
exp.__class__.__name__, time() - start))
print(spacer.format("Finish writing {}-th explanation data for {}. Time elapsed: "
"{:.3f} s".format(idx, time() - now, exp.__class__.__name__)), end='')
print(spacer.format(
"Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
exp.__class__.__name__, time() - start)))
else:
for exp in explainers:
explain = Explain()
for bench in benchmarkers:
bench.reset()
print(f"Start running and writing explanation and benchmark data for {exp.__class__.__name__}.")
print(f"Start running and writing explanation and "
f"benchmark data for {exp.__class__.__name__}......")
self._count = 0
start = time()
ds.config.set_seed(58)
ds.config.set_seed(_SEED)
for idx, next_element in enumerate(dataset):
now = time()
saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format(
idx, time() - now))
print(spacer.format(
"Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
idx, exp.__class__.__name__, time() - now)), end='')
for bench in benchmarkers:
now = time()
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format(
idx, bench.__class__.__name__, time() - now))
print(spacer.format(
"Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now)), end='')
for bench in benchmarkers:
benchmark = explain.benchmark.add()
@ -279,11 +295,11 @@ class ExplainRunner:
benchmark.total_score = bench.performance
benchmark.label_score.extend(bench.class_performances)
print("Finish running and writing explanation and benchmark data for {}. "
"Time elapsed: {}s".format(exp.__class__.__name__, time() - start))
print(spacer.format("Finish running and writing explanation and benchmark data for {}. "
"Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start)))
summary.add_value('explainer', 'benchmark', explain)
summary.record(1)
print("Finish running and writing. Total time elapsed: {}s".format(time() - begin))
print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))
@staticmethod
def _verify_data_form(dataset, benchmarkers):
@ -446,8 +462,9 @@ class ExplainRunner:
Returns:
imageid_labels (dict): a dict that maps each image_id to the union of its ground truth and predicted labels.
"""
spacer = '{:120}\r'
imageid_labels = {}
ds.config.set_seed(58)
ds.config.set_seed(_SEED)
self._count = 0
for j, next_element in enumerate(dataset):
now = time()
@ -516,7 +533,9 @@ class ExplainRunner:
summary.record(1)
self._count += 1
print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now))
print(spacer.format("Finish running and writing {}-th batch inference data."
" Time elapsed: {:.3f} s".format(j, time() - now)),
end='')
return imageid_labels
def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
@ -543,18 +562,22 @@ class ExplainRunner:
batch_unions = self._make_label_batch(unions)
saliency_dict_lst = []
batch_saliency_full = []
for i in range(len(batch_unions[0])):
batch_saliency = explainer(inputs, batch_unions[:, i])
batch_saliency_full.append(batch_saliency)
if isinstance(explainer, RISE):
batch_saliency_full = explainer(inputs, batch_unions)
else:
batch_saliency_full = []
for i in range(len(batch_unions[0])):
batch_saliency = explainer(inputs, batch_unions[:, i])
batch_saliency_full.append(batch_saliency)
concat = ms.ops.operations.Concat(1)
batch_saliency_full = concat(tuple(batch_saliency_full))
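# In both branches batch_saliency_full ends up shaped (N, L, H, W): batch first,
# then the per-sample label union, hence the new [idx:idx + 1, k:k + 1] indexing below.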
for idx, union in enumerate(unions):
saliency_dict = {}
explain = Explain()
explain.sample_id = self._count
for k, lab in enumerate(union):
saliency = batch_saliency_full[k][idx:idx + 1]
saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
saliency_dict[lab] = saliency
saliency_np = _normalize(saliency.asnumpy().squeeze())
@ -600,7 +623,7 @@ class ExplainRunner:
def _save_original_image(self, sample_id: int, image):
"""Save an image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
_ORIGINAL_IMAGE_DIRNAME,
id_dirname)
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
@ -613,7 +636,7 @@ class ExplainRunner:
def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
"""Save heatmap image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
_HEATMAP_DIRNAME,
explain_method,
id_dirname)
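Putting this file's changes together, a run with the new RISE explainer and a custom activation might look like the following sketch (`net`, `dataset` and `classes` are assumed to be prepared as in the docstring example above):
>>> from mindspore.explainer import ExplainRunner
>>> from mindspore.explainer.explanation import RISE
>>> from mindspore.nn import Sigmoid
>>> rise = RISE(net, activation_fn=Sigmoid())  # multi-label style activation
>>> runner = ExplainRunner("./")
>>> runner.run((dataset, classes), [rise], activation_fn=Sigmoid())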

View File

@ -21,6 +21,7 @@ from scipy.ndimage.filters import gaussian_filter
from mindspore import log
import mindspore as ms
from mindspore.train._utils import check_value_type
import mindspore.nn as nn
import mindspore.ops.operations as op
from .metric import AttributionMetric
@ -140,9 +141,7 @@ class Perturb:
@staticmethod
def _assign(x: _Array, y: _Array, masks: _Array):
"""Assign values to perturb pixels on perturbations."""
if masks.dtype != bool:
raise TypeError('The param "masks" should be an array of bool, but receive {}'
.format(masks.dtype))
check_value_type("masks dtype", masks.dtype, type(np.dtype(bool)))
for i in range(x.shape[0]):
x[i][:, masks[i]] = y[:, masks[i]]
@ -336,8 +335,7 @@ class NaiveFaithfulness(_FaithfulnessHelper):
if not np.count_nonzero(saliency):
log.warning("The saliency map is zero everywhere. The correlation will be set to zero.")
correlation = 0
normalized_faithfulness = (correlation + 1) / 2
return np.array([normalized_faithfulness], np.float)
return np.array([correlation], np.float)
reference = self._get_reference(inputs)
perturbations, masks = self._perturb(
inputs, saliency, reference, return_mask=True)
@ -347,8 +345,7 @@ class NaiveFaithfulness(_FaithfulnessHelper):
predictions = model(perturbations).asnumpy()[:, targets]
faithfulness = calc_correlation(feature_importance, predictions)
normalized_faithfulness = (faithfulness + 1) / 2
return np.array([normalized_faithfulness], np.float)
return np.array([faithfulness], np.float)
class DeletionAUC(_FaithfulnessHelper):
@ -533,6 +530,8 @@ class Faithfulness(AttributionMetric):
metric (str, optional): The specific metric to quantify faithfulness.
Options: "DeletionAUC", "InsertionAUC", "NaiveFaithfulness".
Default: 'NaiveFaithfulness'.
activation_fn (Cell, optional): The activation function that transforms the network output to a probability.
Default: nn.Softmax().
Examples:
>>> from mindspore.explainer.benchmark import Faithfulness
@ -543,7 +542,7 @@ class Faithfulness(AttributionMetric):
"""
_methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC]
def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness"):
def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness", activation_fn=nn.Softmax()):
super(Faithfulness, self).__init__(num_labels)
perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument
@ -552,6 +551,7 @@ class Faithfulness(AttributionMetric):
num_perturb_steps = 100 # separate the perturbation process into 100 steps.
base_value = 0.0 # the pixel value set for the perturbed pixels
self._activation_fn = activation_fn
self._verify_metrics(metric)
for method in self._methods:
if metric == method.__name__:
@ -568,9 +568,7 @@ class Faithfulness(AttributionMetric):
Evaluate faithfulness on a single data sample.
Note:
To apply `Faithfulness` to evaluate an explainer, this explainer must be initialized with a network that
contains the output activation function. Otherwise, the results will not be correct. Currently only single
sample (:math:`N=1`) at each call is supported.
Currently only a single sample (:math:`N=1`) per call is supported.
Args:
explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
@ -586,7 +584,7 @@ class Faithfulness(AttributionMetric):
Examples:
>>> # init an explainer, the network should contain the output activation function.
>>> network = nn.SequentialCell([resnet50, nn.Sigmoid()])
>>> network = resnet50(20)
>>> gradient = Gradient(network)
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
>>> targets = 5
@ -610,10 +608,10 @@ class Faithfulness(AttributionMetric):
saliency = saliency.squeeze()
if len(saliency.shape) != 2:
raise ValueError('Squeezed saliency map is expected to be 2D, but received {}.'.format(len(saliency.shape)))
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=explainer.model,
model = nn.SequentialCell([explainer.model, self._activation_fn])
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=model,
targets=targets, saliency=saliency)
return faithfulness
return (1 + faithfulness) / 2
def _verify_metrics(self, metric: str):
supports = [x.__name__ for x in self._methods]
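With `activation_fn` supplied to the metric rather than baked into the network, an evaluation might look like this sketch (`resnet50`, `inputs` and `targets` are assumed as in the docstring example above):
>>> import mindspore.nn as nn
>>> from mindspore.explainer.benchmark import Faithfulness
>>> from mindspore.explainer.explanation import Gradient
>>> network = resnet50(20)  # raw logits; no output activation attached
>>> gradient = Gradient(network)
>>> faithfulness = Faithfulness(num_labels=20, activation_fn=nn.Sigmoid())
>>> score = faithfulness.evaluate(gradient, inputs, targets)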

View File

@ -17,10 +17,12 @@
from ._attribution._backprop.gradcam import GradCAM
from ._attribution._backprop.gradient import Gradient
from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop
from ._attribution._perturbation.rise import RISE
__all__ = [
'Gradient',
'Deconvolution',
'GuidedBackprop',
'GradCAM',
'RISE'
]

View File

@ -16,10 +16,12 @@
from ._backprop.gradcam import GradCAM
from ._backprop.gradient import Gradient
from ._backprop.modified_relu import Deconvolution, GuidedBackprop
from ._perturbation.rise import RISE
__all__ = [
'Gradient',
'Deconvolution',
'GuidedBackprop',
'GradCAM',
'RISE'
]

View File

@ -16,33 +16,25 @@
from typing import Callable
import mindspore as ms
from mindspore.train._utils import check_value_type
from mindspore.nn import Cell
class Attribution:
r"""
"""
Base class for attributing saliency scores.
The explainers which explanation through attributing the relevance scores
should inherit this class.
Explainers that generate explanations by attributing relevance scores should inherit this class.
Args:
network (ms.nn.Cell): The black-box model to explanation.
network (Cell): The black-box model to explain.
"""
def __init__(self, network):
self._verify_model(network)
check_value_type("network", network, Cell)
self._model = network
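# explanations run in inference mode: disable training behavior and gradient state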
self._model.set_train(False)
self._model.set_grad(False)
@staticmethod
def _verify_model(model):
"""
Verify the input `network` for __init__ function.
"""
if not isinstance(model, ms.nn.Cell):
raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.")
__call__: Callable
"""

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
""" GradCAM and GuidedGradCAM. """
"""GradCAM."""
from mindspore.ops import operations as op
@ -98,7 +98,7 @@ class GradCAM(IntermediateLayerAttribution):
"""
Hook function to deal with the backward gradient.
The arguments are set as required by Cell.register_back_hook
The arguments are set as required by `Cell.register_backward_hook`.
"""
self._intermediate_grad = grad_input

View File

@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Perturbation-based _attribution explainer. """
from .rise import RISE
__all__ = ['RISE']

View File

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base class `PerturbationAttribtuion`"""
from mindspore.train._utils import check_value_type
from mindspore.nn import Cell
from .._attribution import Attribution
from ...._operators import softmax
class PerturbationAttribution(Attribution):
"""
Base class for perturbation-based attribution methods.
All perturbation-based attribution methods extend from this class.
"""
def __init__(self,
network,
activation_fn=softmax(),
):
super(PerturbationAttribution, self).__init__(network)
check_value_type("activation_fn", activation_fn, Cell)
self._activation_fn = activation_fn
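A hypothetical subclass, only to illustrate how the stored activation is meant to be used (the class name and `net` are illustrative):
>>> from mindspore import nn
>>> class ToyPerturbation(PerturbationAttribution):
...     '''Toy explainer: scores classes on the unperturbed input.'''
...     def __call__(self, inputs, targets):
...         # logits -> probabilities via the activation stored by this base class
...         return self._activation_fn(self.model(inputs))
...
>>> explainer = ToyPerturbation(net, activation_fn=nn.Sigmoid())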

View File

@ -0,0 +1,194 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RISE."""
import math
import numpy as np
from mindspore import Tensor
from mindspore import nn
from mindspore.train._utils import check_value_type
from .perturbation import PerturbationAttribution
from .... import _operators as op
from ...._utils import resize
class RISE(PerturbationAttribution):
r"""
RISE: Randomized Input Sampling for Explanation of Black-box Models.
RISE is a perturbation-based method that generates attribution maps by sampling on multiple random binary masks.
The original image is randomly masked, and then fed into the black-box model to get predictions. The final
attribution map is the weighted sum of these random masks, with the weights being the corresponding output on the
node of interest:
.. math::
E_{RISE}(I, f)_c = \sum_{i}f_c(I\odot M_i) M_i
For more details, please refer to the original paper via: `RISE <https://arxiv.org/abs/1806.07421>`_.
Args:
network (Cell): The black-box model to be explained.
activation_fn (Cell, optional): The activation layer that transforms logits to prediction probabilities. For
single label classification tasks, `nn.Softmax` is usually applied. For multi-label classification tasks,
`nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn`, as long as
combining it with the network yields the probability of the input. Default: `nn.Softmax`.
perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the
perturbed samples. Default: 32.
Examples:
>>> from mindspore.explainer.explanation import RISE
>>> net = resnet50(10)
>>> param_dict = load_checkpoint("resnet50.ckpt")
>>> load_param_into_net(net, param_dict)
>>> # init RISE with specified activation function
>>> rise = RISE(net, activation_fn=nn.Sigmoid())
"""
def __init__(self,
network,
activation_fn=nn.Softmax(),
perturbation_per_eval=32):
super(RISE, self).__init__(network, activation_fn)
self._perturbation_per_eval = perturbation_per_eval
self._num_masks = 6000 # number of masks to be sampled
self._mask_probability = 0.2 # ratio of inputs to be masked
self._down_sample_size = 10 # the original size of binary masks
self._resize_mode = 'bilinear' # mode choice to resize the down-sized binary masks to size of the inputs
self._perturbation_mode = 'constant' # setting the perturbed pixels to a constant value
self._base_value = 0 # setting the perturbed pixels to this constant value
self._num_classes = None # placeholder of self._num_classes just for future assignment in other methods
def _generate_masks(self, data, batch_size):
"""Generate a batch of binary masks for data."""
height, width = data.shape[2], data.shape[3]
mask_size = (self._down_sample_size, self._down_sample_size)
up_size = (height + mask_size[0], width + mask_size[1])
mask = np.random.random((batch_size, 1) + mask_size) < self._mask_probability
upsample = resize(op.Tensor(mask, data.dtype), up_size,
self._resize_mode)
# The Pack operator is not available on GPU, so transfer to numpy first
upsample_np = upsample.asnumpy()
masks_lst = []
for sample in upsample_np:
shift_x = np.random.randint(0, mask_size[0] + 1)
shift_y = np.random.randint(0, mask_size[1] + 1)
masks_lst.append(sample[:, shift_x: shift_x + height, shift_y:shift_y + width])
masks = op.Tensor(np.array(masks_lst), data.dtype)
return masks
def __call__(self, inputs, targets):
"""
Generates attribution maps for inputs.
Args:
inputs (Tensor): Input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
targets (int, Tensor): The labels of interest to be explained. When `targets` is an integer,
all inputs generate attribution maps w.r.t. this integer. When `targets` is a tensor, it
should be of shape :math:`(N, ?)`, :math:`(N,)` or :math:`()`.
Returns:
Tensor, a 4D tensor of shape :math:`(N, ?, H, W)` or :math:`(N, 1, H, W)`.
Examples:
>>> # given an instance of RISE, a saliency map can be generated
>>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32)
>>> # when `targets` is an integer
>>> targets = 5
>>> saliency = rise(inputs, targets)
>>> # `targets` can also be a tensor
>>> targets = ms.Tensor([[5], [1]])
>>> saliency = rise(inputs, targets)
>>>
"""
self._verify_data(inputs, targets)
height, width = inputs.shape[2], inputs.shape[3]
batch_size = inputs.shape[0]
if self._num_classes is None:
logits = self.model(inputs)
num_classes = logits.shape[1]
self._num_classes = num_classes
# Slice assignment is not supported as an Op, so use a numpy array here
attr_np = np.zeros(shape=(batch_size, self._num_classes, height, width))
cal_times = math.ceil(self._num_masks / self._perturbation_per_eval)
for idx, data in enumerate(inputs):
bg_data = data * 0 + self._base_value
for j in range(cal_times):
bs = min(self._num_masks - j * self._perturbation_per_eval,
self._perturbation_per_eval)
data = op.reshape(data, (1, -1, height, width))
masks = self._generate_masks(data, bs)
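# keep pixels where the mask is 1; fill the rest with the constant base value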
masked_input = masks * data + (1 - masks) * bg_data
weights = self._activation_fn(self.model(masked_input))
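# collapse any extra output dims so each (sample, class) weight is a scalar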
while len(weights.shape) > 2:
weights = op.mean(weights, axis=2)
weights = op.reshape(weights,
(bs, self._num_classes, 1, 1))
attr_np[idx] += op.summation(weights * masks, axis=0).asnumpy()
attr_np = attr_np / self._num_masks
targets = self._unify_targets(inputs, targets)
attr_classes = []
for idx, target in enumerate(targets):
dtype = inputs.dtype
attr_np_idx = attr_np[idx]
attr_idx = attr_np_idx[target]
attr_classes.append(attr_idx)
return op.Tensor(attr_classes, dtype=dtype)
@staticmethod
def _verify_data(inputs, targets):
"""Verify the validity of the parsed inputs."""
check_value_type('inputs', inputs, Tensor)
if len(inputs.shape) != 4:
raise ValueError('Argument inputs must be a 4D Tensor')
check_value_type('targets', targets, (Tensor, int, tuple, list))
if isinstance(targets, Tensor):
if len(targets.shape) > 2:
raise ValueError('Dimension invalid. If `targets` is a Tensor, it should be 0D, 1D or 2D. '
'But got {}D.'.format(len(targets.shape)))
if targets.shape and len(targets) != len(inputs):
raise ValueError(
'If `targets` is a 1D or 2D Tensor, it should have the same length as inputs ({}). But got {}.'.format(
len(inputs), len(targets)))
@staticmethod
def _unify_targets(inputs, targets):
"""To unify targets to be 2D numpy.ndarray."""
if isinstance(targets, int):
return np.array([[targets] for _ in inputs]).astype(np.int)
if isinstance(targets, Tensor):
if not targets.shape:
return np.array([[targets.asnumpy()] for _ in inputs]).astype(np.int)
if len(targets.shape) == 1:
return np.array([[t.asnumpy()] for t in targets]).astype(np.int)
if len(targets.shape) == 2:
return np.array([t.asnumpy() for t in targets]).astype(np.int)
return targets
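To make the :math:`E_{RISE}` weighting concrete, here is a small NumPy sketch of the core accumulation (toy shapes, not the class internals):
>>> import numpy as np
>>> num_masks, num_classes, height, width = 8, 3, 4, 4
>>> masks = (np.random.rand(num_masks, height, width) < 0.2).astype(float)
>>> probs = np.random.rand(num_masks, num_classes)  # f_c(I * M_i) per mask and class
>>> # E_RISE(I, f)_c = sum_i f_c(I * M_i) * M_i, averaged over the sampled masks
>>> saliency = np.einsum('mc,mhw->chw', probs, masks) / num_masks
>>> saliency.shape
(3, 4, 4)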