forked from mindspore-Ecosystem/mindspore
!8844 Add rise explanation method into explainer
From: @lixiaohui33 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
8675e7fade
|
@ -18,9 +18,8 @@ from typing import List, Tuple, Union, Callable
|
|||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import nn
|
||||
import mindspore.ops.operations as op
|
||||
|
||||
from mindspore import nn
|
||||
|
||||
_Axis = Union[int, Tuple[int, ...], List[int]]
|
||||
_Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]]
|
||||
|
@ -235,7 +234,7 @@ def randint(low: int, high: int, shape: _Shape, dtype: mindspore.dtype = mindspo
|
|||
return outputs
|
||||
|
||||
|
||||
def softmax(axis: int) -> Callable:
|
||||
def softmax(axis: int = -1) -> Callable:
|
||||
"""Softmax activation function."""
|
||||
func = nn.Softmax(axis=axis)
|
||||
return func
|
||||
|
|
|
@ -20,20 +20,23 @@ from time import time
|
|||
from typing import Tuple, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import beta
|
||||
from PIL import Image
|
||||
from scipy.stats import beta
|
||||
|
||||
from mindspore.train._utils import check_value_type
|
||||
from mindspore.train.summary_pb2 import Explain
|
||||
import mindspore as ms
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log
|
||||
from mindspore.nn import Softmax, Cell
|
||||
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
|
||||
from mindspore.ops.operations import ExpandDims
|
||||
from mindspore.train._utils import check_value_type
|
||||
from mindspore.train.summary._summary_adapter import _convert_image_format
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
|
||||
from mindspore.train.summary_pb2 import Explain
|
||||
|
||||
from .benchmark import Localization
|
||||
from .benchmark._attribution.metric import AttributionMetric
|
||||
from .explanation import RISE
|
||||
from .explanation._attribution._attribution import Attribution
|
||||
|
||||
# datafile directory names
|
||||
|
@ -43,8 +46,8 @@ _HEATMAP_DIRNAME = "heatmap"
|
|||
# max. no. of sample per directory
|
||||
_SAMPLE_PER_DIR = 1000
|
||||
|
||||
|
||||
_EXPAND_DIMS = ExpandDims()
|
||||
_SEED = 58 # set a seed to fix the iterating order of the dataset
|
||||
|
||||
|
||||
def _normalize(img_np):
|
||||
|
@ -57,7 +60,7 @@ def _normalize(img_np):
|
|||
|
||||
def _np_to_image(img_np, mode):
|
||||
"""Convert numpy array to PIL image."""
|
||||
return Image.fromarray(np.uint8(img_np*255), mode=mode)
|
||||
return Image.fromarray(np.uint8(img_np * 255), mode=mode)
|
||||
|
||||
|
||||
def _calc_prob_interval(volume, probs, prob_vars):
|
||||
|
@ -89,7 +92,7 @@ def _calc_prob_interval(volume, probs, prob_vars):
|
|||
|
||||
def _get_id_dirname(sample_id: int):
|
||||
"""Get the name of parent directory of the image id."""
|
||||
return str(int(sample_id/_SAMPLE_PER_DIR)*_SAMPLE_PER_DIR)
|
||||
return str(int(sample_id / _SAMPLE_PER_DIR) * _SAMPLE_PER_DIR)
|
||||
|
||||
|
||||
def _extract_timestamp(filename: str):
|
||||
|
@ -107,6 +110,9 @@ class ExplainRunner:
|
|||
After generating results with the explanation methods and the evaluation methods, the results will be written into
|
||||
a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight.
|
||||
|
||||
Update in 2020.11: Adjust the storage structure and format of the data. Summary files generated by previous version
|
||||
will be deprecated and will not be supported in MindInsight of current version.
|
||||
|
||||
Args:
|
||||
summary_dir (str, optional): The directory path to save the summary files which store the generated results.
|
||||
Default: "./"
|
||||
|
@ -131,7 +137,8 @@ class ExplainRunner:
|
|||
dataset: Tuple,
|
||||
explainers: List,
|
||||
benchmarkers: Optional[List] = None,
|
||||
uncertainty: Optional[UncertaintyEvaluation] = None):
|
||||
uncertainty: Optional[UncertaintyEvaluation] = None,
|
||||
activation_fn: Optional[Cell] = Softmax()):
|
||||
"""
|
||||
Generates results and writes results into the summary files in `summary_dir` specified during the object
|
||||
initialization.
|
||||
|
@ -149,8 +156,12 @@ class ExplainRunner:
|
|||
Default: None
|
||||
uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
|
||||
uncertainty of samples.
|
||||
activation_fn (Cell, optional): The activation layer that transforms the output of the network to
|
||||
label probability distribution :math:`P(y|x)`. Default: Softmax().
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
|
||||
>>> from mindspore.nn import Sigmoid
|
||||
>>> # obtain dataset object
|
||||
>>> dataset = get_dataset()
|
||||
>>> classes = ["cat", "dog", ...]
|
||||
|
@ -158,13 +169,11 @@ class ExplainRunner:
|
|||
>>> param_dict = load_checkpoint("checkpoint.ckpt")
|
||||
>>> net = resnet50(len(classes))
|
||||
>>> load_parama_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation
|
||||
>>> model = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> gbp = GuidedBackprop(model)
|
||||
>>> gradient = Gradient(model)
|
||||
>>> gbp = GuidedBackprop(net)
|
||||
>>> gradient = Gradient(net)
|
||||
>>> runner = ExplainRunner("./")
|
||||
>>> explainers = [gbp, gradient]
|
||||
>>> runner.run((dataset, classes), explainers)
|
||||
>>> runner.run((dataset, classes), explainers, activation_fn=Sigmoid())
|
||||
"""
|
||||
|
||||
check_value_type("dataset", dataset, tuple)
|
||||
|
@ -181,16 +190,17 @@ class ExplainRunner:
|
|||
|
||||
for exp in explainers:
|
||||
if not isinstance(exp, Attribution):
|
||||
raise TypeError("Argument explainers should be a list of objects of classes in "
|
||||
raise TypeError("Argument `explainers` should be a list of objects of classes in "
|
||||
"`mindspore.explainer.explanation`.")
|
||||
if benchmarkers is not None:
|
||||
check_value_type("benchmarkers", benchmarkers, list)
|
||||
for bench in benchmarkers:
|
||||
if not isinstance(bench, AttributionMetric):
|
||||
raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation"
|
||||
raise TypeError("Argument `benchmarkers` should be a list of objects of classes in explanation"
|
||||
"`mindspore.explainer.benchmark`.")
|
||||
check_value_type("activation_fn", activation_fn, Cell)
|
||||
|
||||
self._model = explainers[0].model
|
||||
self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn])
|
||||
next_element = dataset.create_tuple_iterator().get_next()
|
||||
inputs, _, _ = self._unpack_next_element(next_element)
|
||||
prop_test = self._model(inputs)
|
||||
|
@ -211,9 +221,10 @@ class ExplainRunner:
|
|||
self._uncertainty = None
|
||||
|
||||
with SummaryRecord(self._summary_dir) as summary:
|
||||
spacer = '{:120}\r'
|
||||
print("Start running and writing......")
|
||||
begin = time()
|
||||
print("Start writing metadata.")
|
||||
print("Start writing metadata......")
|
||||
|
||||
self._summary_timestamp = _extract_timestamp(summary.event_file_name)
|
||||
if self._summary_timestamp is None:
|
||||
|
@ -234,42 +245,47 @@ class ExplainRunner:
|
|||
print("Finish writing metadata.")
|
||||
|
||||
now = time()
|
||||
print("Start running and writing inference data......")
|
||||
print("Start running and writing inference data.....")
|
||||
imageid_labels = self._run_inference(dataset, summary)
|
||||
print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now))
|
||||
print(spacer.format("Finish running and writing inference data. "
|
||||
"Time elapsed: {:.3f} s".format(time() - now)))
|
||||
|
||||
if benchmarkers is None or not benchmarkers:
|
||||
for exp in explainers:
|
||||
start = time()
|
||||
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
|
||||
self._count = 0
|
||||
ds.config.set_seed(58)
|
||||
ds.config.set_seed(_SEED)
|
||||
for idx, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
self._run_exp_step(next_element, exp, imageid_labels, summary)
|
||||
print("Finish writing {}-th explanation data. Time elapsed: {}".format(
|
||||
idx, time() - now))
|
||||
print("Finish running and writing explanation data for {}. Time elapsed: {}".format(
|
||||
exp.__class__.__name__, time() - start))
|
||||
print(spacer.format("Finish writing {}-th explanation data for {}. Time elapsed: "
|
||||
"{:.3f} s".format(idx, time() - now, exp.__class__.__name__)), end='')
|
||||
print(spacer.format(
|
||||
"Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
|
||||
exp.__class__.__name__, time() - start)))
|
||||
else:
|
||||
for exp in explainers:
|
||||
explain = Explain()
|
||||
for bench in benchmarkers:
|
||||
bench.reset()
|
||||
print(f"Start running and writing explanation and benchmark data for {exp.__class__.__name__}.")
|
||||
print(f"Start running and writing explanation and "
|
||||
f"benchmark data for {exp.__class__.__name__}......")
|
||||
self._count = 0
|
||||
start = time()
|
||||
ds.config.set_seed(58)
|
||||
ds.config.set_seed(_SEED)
|
||||
for idx, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
|
||||
print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format(
|
||||
idx, time() - now))
|
||||
print(spacer.format(
|
||||
"Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
|
||||
idx, exp.__class__.__name__, time() - now)), end='')
|
||||
for bench in benchmarkers:
|
||||
now = time()
|
||||
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
|
||||
print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format(
|
||||
idx, bench.__class__.__name__, time() - now))
|
||||
print(spacer.format(
|
||||
"Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
|
||||
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now)), end='')
|
||||
|
||||
for bench in benchmarkers:
|
||||
benchmark = explain.benchmark.add()
|
||||
|
@ -279,11 +295,11 @@ class ExplainRunner:
|
|||
benchmark.total_score = bench.performance
|
||||
benchmark.label_score.extend(bench.class_performances)
|
||||
|
||||
print("Finish running and writing explanation and benchmark data for {}. "
|
||||
"Time elapsed: {}s".format(exp.__class__.__name__, time() - start))
|
||||
print(spacer.format("Finish running and writing explanation and benchmark data for {}. "
|
||||
"Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start)))
|
||||
summary.add_value('explainer', 'benchmark', explain)
|
||||
summary.record(1)
|
||||
print("Finish running and writing. Total time elapsed: {}s".format(time() - begin))
|
||||
print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))
|
||||
|
||||
@staticmethod
|
||||
def _verify_data_form(dataset, benchmarkers):
|
||||
|
@ -446,8 +462,9 @@ class ExplainRunner:
|
|||
Returns:
|
||||
imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels.
|
||||
"""
|
||||
spacer = '{:120}\r'
|
||||
imageid_labels = {}
|
||||
ds.config.set_seed(58)
|
||||
ds.config.set_seed(_SEED)
|
||||
self._count = 0
|
||||
for j, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
|
@ -516,7 +533,9 @@ class ExplainRunner:
|
|||
summary.record(1)
|
||||
|
||||
self._count += 1
|
||||
print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now))
|
||||
print(spacer.format("Finish running and writing {}-th batch inference data."
|
||||
" Time elapsed: {:.3f} s".format(j, time() - now)),
|
||||
end='')
|
||||
return imageid_labels
|
||||
|
||||
def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
|
||||
|
@ -543,18 +562,22 @@ class ExplainRunner:
|
|||
batch_unions = self._make_label_batch(unions)
|
||||
saliency_dict_lst = []
|
||||
|
||||
batch_saliency_full = []
|
||||
for i in range(len(batch_unions[0])):
|
||||
batch_saliency = explainer(inputs, batch_unions[:, i])
|
||||
batch_saliency_full.append(batch_saliency)
|
||||
if isinstance(explainer, RISE):
|
||||
batch_saliency_full = explainer(inputs, batch_unions)
|
||||
else:
|
||||
batch_saliency_full = []
|
||||
for i in range(len(batch_unions[0])):
|
||||
batch_saliency = explainer(inputs, batch_unions[:, i])
|
||||
batch_saliency_full.append(batch_saliency)
|
||||
concat = ms.ops.operations.Concat(1)
|
||||
batch_saliency_full = concat(tuple(batch_saliency_full))
|
||||
|
||||
for idx, union in enumerate(unions):
|
||||
saliency_dict = {}
|
||||
explain = Explain()
|
||||
explain.sample_id = self._count
|
||||
for k, lab in enumerate(union):
|
||||
saliency = batch_saliency_full[k][idx:idx + 1]
|
||||
|
||||
saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
|
||||
saliency_dict[lab] = saliency
|
||||
|
||||
saliency_np = _normalize(saliency.asnumpy().squeeze())
|
||||
|
@ -600,7 +623,7 @@ class ExplainRunner:
|
|||
def _save_original_image(self, sample_id: int, image):
|
||||
"""Save an image to summary directory."""
|
||||
id_dirname = _get_id_dirname(sample_id)
|
||||
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
|
||||
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
|
||||
_ORIGINAL_IMAGE_DIRNAME,
|
||||
id_dirname)
|
||||
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
|
||||
|
@ -613,7 +636,7 @@ class ExplainRunner:
|
|||
def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
|
||||
"""Save heatmap image to summary directory."""
|
||||
id_dirname = _get_id_dirname(sample_id)
|
||||
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX+str(self._summary_timestamp),
|
||||
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
|
||||
_HEATMAP_DIRNAME,
|
||||
explain_method,
|
||||
id_dirname)
|
||||
|
|
|
@ -21,6 +21,7 @@ from scipy.ndimage.filters import gaussian_filter
|
|||
|
||||
from mindspore import log
|
||||
import mindspore as ms
|
||||
from mindspore.train._utils import check_value_type
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as op
|
||||
from .metric import AttributionMetric
|
||||
|
@ -140,9 +141,7 @@ class Perturb:
|
|||
@staticmethod
|
||||
def _assign(x: _Array, y: _Array, masks: _Array):
|
||||
"""Assign values to perturb pixels on perturbations."""
|
||||
if masks.dtype != bool:
|
||||
raise TypeError('The param "masks" should be an array of bool, but receive {}'
|
||||
.format(masks.dtype))
|
||||
check_value_type("masks dtype", masks.dtype, type(np.dtype(bool)))
|
||||
for i in range(x.shape[0]):
|
||||
x[i][:, masks[i]] = y[:, masks[i]]
|
||||
|
||||
|
@ -336,8 +335,7 @@ class NaiveFaithfulness(_FaithfulnessHelper):
|
|||
if not np.count_nonzero(saliency):
|
||||
log.warning("The saliency map is zero everywhere. The correlation will be set to zero.")
|
||||
correlation = 0
|
||||
normalized_faithfulness = (correlation + 1) / 2
|
||||
return np.array([normalized_faithfulness], np.float)
|
||||
return np.array([correlation], np.float)
|
||||
reference = self._get_reference(inputs)
|
||||
perturbations, masks = self._perturb(
|
||||
inputs, saliency, reference, return_mask=True)
|
||||
|
@ -347,8 +345,7 @@ class NaiveFaithfulness(_FaithfulnessHelper):
|
|||
predictions = model(perturbations).asnumpy()[:, targets]
|
||||
|
||||
faithfulness = calc_correlation(feature_importance, predictions)
|
||||
normalized_faithfulness = (faithfulness + 1) / 2
|
||||
return np.array([normalized_faithfulness], np.float)
|
||||
return np.array([faithfulness], np.float)
|
||||
|
||||
|
||||
class DeletionAUC(_FaithfulnessHelper):
|
||||
|
@ -533,6 +530,8 @@ class Faithfulness(AttributionMetric):
|
|||
metric (str, optional): The specific metric to quantify faithfulness.
|
||||
Options: "DeletionAUC", "InsertionAUC", "NaiveFaithfulness".
|
||||
Default: 'NaiveFaithfulness'.
|
||||
activation_fn (Cell, optional): The activation function that transforms the network output to a probability.
|
||||
Default: nn.Softmax().
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer.benchmark import Faithfulness
|
||||
|
@ -543,7 +542,7 @@ class Faithfulness(AttributionMetric):
|
|||
"""
|
||||
_methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC]
|
||||
|
||||
def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness"):
|
||||
def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness", activation_fn=nn.Softmax()):
|
||||
super(Faithfulness, self).__init__(num_labels)
|
||||
|
||||
perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument
|
||||
|
@ -552,6 +551,7 @@ class Faithfulness(AttributionMetric):
|
|||
num_perturb_steps = 100 # separate the perturbation progress in to 100 steps.
|
||||
base_value = 0.0 # the pixel value set for the perturbed pixels
|
||||
|
||||
self._activation_fn = activation_fn
|
||||
self._verify_metrics(metric)
|
||||
for method in self._methods:
|
||||
if metric == method.__name__:
|
||||
|
@ -568,9 +568,7 @@ class Faithfulness(AttributionMetric):
|
|||
Evaluate faithfulness on a single data sample.
|
||||
|
||||
Note:
|
||||
To apply `Faithfulness` to evaluate an explainer, this explainer must be initialized with a network that
|
||||
contains the output activation function. Otherwise, the results will not be correct. Currently only single
|
||||
sample (:math:`N=1`) at each call is supported.
|
||||
Currently only single sample (:math:`N=1`) at each call is supported.
|
||||
|
||||
Args:
|
||||
explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
|
||||
|
@ -586,7 +584,7 @@ class Faithfulness(AttributionMetric):
|
|||
|
||||
Examples:
|
||||
>>> # init an explainer, the network should contain the output activation function.
|
||||
>>> network = nn.SequentialCell([resnet50, nn.Sigmoid()])
|
||||
>>> network = resnet50(20)
|
||||
>>> gradient = Gradient(network)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> targets = 5
|
||||
|
@ -610,10 +608,10 @@ class Faithfulness(AttributionMetric):
|
|||
saliency = saliency.squeeze()
|
||||
if len(saliency.shape) != 2:
|
||||
raise ValueError('Squeezed saliency map is expected to 2D, but receive {}.'.format(len(saliency.shape)))
|
||||
|
||||
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=explainer.model,
|
||||
model = nn.SequentialCell([explainer.model, self._activation_fn])
|
||||
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=model,
|
||||
targets=targets, saliency=saliency)
|
||||
return faithfulness
|
||||
return (1 + faithfulness) / 2
|
||||
|
||||
def _verify_metrics(self, metric: str):
|
||||
supports = [x.__name__ for x in self._methods]
|
||||
|
|
|
@ -17,10 +17,12 @@
|
|||
from ._attribution._backprop.gradcam import GradCAM
|
||||
from ._attribution._backprop.gradient import Gradient
|
||||
from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop
|
||||
from ._attribution._perturbation.rise import RISE
|
||||
|
||||
__all__ = [
|
||||
'Gradient',
|
||||
'Deconvolution',
|
||||
'GuidedBackprop',
|
||||
'GradCAM',
|
||||
'RISE'
|
||||
]
|
||||
|
|
|
@ -16,10 +16,12 @@
|
|||
from ._backprop.gradcam import GradCAM
|
||||
from ._backprop.gradient import Gradient
|
||||
from ._backprop.modified_relu import Deconvolution, GuidedBackprop
|
||||
from ._perturbation.rise import RISE
|
||||
|
||||
__all__ = [
|
||||
'Gradient',
|
||||
'Deconvolution',
|
||||
'GuidedBackprop',
|
||||
'GradCAM',
|
||||
'RISE'
|
||||
]
|
||||
|
|
|
@ -16,33 +16,25 @@
|
|||
|
||||
from typing import Callable
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore.train._utils import check_value_type
|
||||
from mindspore.nn import Cell
|
||||
|
||||
class Attribution:
|
||||
r"""
|
||||
"""
|
||||
Basic class of attributing the salient score
|
||||
|
||||
The explainers which explanation through attributing the relevance scores
|
||||
should inherit this class.
|
||||
The explainers which explanation through attributing the relevance scores should inherit this class.
|
||||
|
||||
Args:
|
||||
network (ms.nn.Cell): The black-box model to explanation.
|
||||
network (Cell): The black-box model to explain.
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
self._verify_model(network)
|
||||
check_value_type("network", network, Cell)
|
||||
self._model = network
|
||||
self._model.set_train(False)
|
||||
self._model.set_grad(False)
|
||||
|
||||
@staticmethod
|
||||
def _verify_model(model):
|
||||
"""
|
||||
Verify the input `network` for __init__ function.
|
||||
"""
|
||||
if not isinstance(model, ms.nn.Cell):
|
||||
raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.")
|
||||
|
||||
|
||||
__call__: Callable
|
||||
"""
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" GradCAM and GuidedGradCAM. """
|
||||
"""GradCAM."""
|
||||
|
||||
from mindspore.ops import operations as op
|
||||
|
||||
|
@ -98,7 +98,7 @@ class GradCAM(IntermediateLayerAttribution):
|
|||
"""
|
||||
Hook function to deal with the backward gradient.
|
||||
|
||||
The arguments are set as required by Cell.register_back_hook
|
||||
The arguments are set as required by `Cell.register_backward_hook`.
|
||||
"""
|
||||
self._intermediate_grad = grad_input
|
||||
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" Perturbation-based _attribution explainer. """
|
||||
|
||||
from .rise import RISE
|
||||
|
||||
__all__ = ['RISE']
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Base class `PerturbationAttribution`"""
|
||||
|
||||
from mindspore.train._utils import check_value_type
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from .._attribution import Attribution
|
||||
from ...._operators import softmax
|
||||
|
||||
|
||||
class PerturbationAttribution(Attribution):
    """
    Common parent of the perturbation-based attribution explainers.

    Every perturbation-based attribution method derives from this class.

    Args:
        network (Cell): The black-box model to be explained; it is handed on to `Attribution`.
        activation_fn (Cell, optional): The activation layer kept on the instance for subclasses to apply to
            the network output. Default: the cell returned by `softmax()`.
    """

    def __init__(self, network, activation_fn=softmax()):
        # The parent constructor runs first, so `network` is validated before `activation_fn`.
        super().__init__(network)
        check_value_type("activation_fn", activation_fn, Cell)
        self._activation_fn = activation_fn
|
|
@ -0,0 +1,194 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""RISE."""
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import nn
|
||||
from mindspore.train._utils import check_value_type
|
||||
|
||||
from .perturbation import PerturbationAttribution
|
||||
from .... import _operators as op
|
||||
from ...._utils import resize
|
||||
|
||||
|
||||
class RISE(PerturbationAttribution):
    r"""
    RISE: Randomized Input Sampling for Explanation of Black-box Model.

    RISE is a perturbation-based method that generates attribution maps by sampling on multiple random binary masks.
    The original image is randomly masked, and then fed into the black-box model to get predictions. The final
    attribution map is the weighted sum of these random masks, with the weights being the corresponding output on the
    node of interest:

    .. math::
        E_{RISE}(I, f)_c = \sum_{i}f_c(I\odot M_i) M_i

    For more details, please refer to the original paper via: `RISE <https://arxiv.org/abs/1806.07421>`_.

    Args:
        network (Cell): The black-box model to be explained.
        activation_fn (Cell, optional): The activation layer that transforms logits to prediction probabilities. For
            single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification
            tasks, `nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long
            as, when combining this function with the network, the final output is the probability of the input.
            Default: `nn.Softmax()`.
        perturbation_per_eval (int, optional): Number of perturbed samples fed to the network in each inference
            call. Must be positive. Default: 32.

    Raises:
        ValueError: If `perturbation_per_eval` is not positive.

    Examples:
        >>> from mindspore import nn
        >>> from mindspore.explainer.explanation import RISE
        >>> net = resnet50(10)
        >>> param_dict = load_checkpoint("resnet50.ckpt")
        >>> load_param_into_net(net, param_dict)
        >>> # init RISE with specified activation function
        >>> rise = RISE(net, activation_fn=nn.Sigmoid())
    """

    def __init__(self,
                 network,
                 activation_fn=nn.Softmax(),
                 perturbation_per_eval=32):
        super(RISE, self).__init__(network, activation_fn)
        # A zero value would divide by zero in `__call__`; a negative one would silently
        # skip all sampling and return an all-zero map. Reject both up front.
        if perturbation_per_eval <= 0:
            raise ValueError('Argument `perturbation_per_eval` must be a positive integer, but got {}.'
                             .format(perturbation_per_eval))
        self._perturbation_per_eval = perturbation_per_eval

        self._num_masks = 6000  # number of masks to be sampled
        # probability that a cell of the coarse binary mask is 1; where the mask is 1 the input is kept,
        # elsewhere it is blended towards `_base_value` (see `__call__`)
        self._mask_probability = 0.2
        self._down_sample_size = 10  # the original size of binary masks
        self._resize_mode = 'bilinear'  # mode choice to resize the down-sized binary masks to size of the inputs
        self._perturbation_mode = 'constant'  # setting the perturbed pixels to a constant value
        self._base_value = 0  # setting the perturbed pixels to this constant value
        self._num_classes = None  # filled in lazily from the network output on the first `__call__`

    def _generate_masks(self, data, batch_size):
        """
        Generate a batch of masks for data.

        A coarse boolean grid of shape `(batch_size, 1, s, s)` (`s` is `_down_sample_size`) is drawn, resized to
        `(H + s, W + s)` with `_resize_mode`, and a randomly shifted `H x W` window is cropped from every
        resized mask.

        Args:
            data (Tensor): The sample being explained; only `data.shape[2]`, `data.shape[3]` and `data.dtype`
                are read.
            batch_size (int): Number of masks to generate.

        Returns:
            Tensor, masks of shape `(batch_size, 1, H, W)` with the dtype of `data`.
        """
        height, width = data.shape[2], data.shape[3]

        mask_size = (self._down_sample_size, self._down_sample_size)

        # Upsample to a size larger than the input so that a random crop offset can be applied.
        up_size = (height + mask_size[0], width + mask_size[1])
        mask = np.random.random((batch_size, 1) + mask_size) < self._mask_probability
        upsample = resize(op.Tensor(mask, data.dtype), up_size,
                          self._resize_mode)

        # Pack operator not available for GPU, thus transfer to numpy first
        upsample_np = upsample.asnumpy()
        masks_lst = []
        for sample in upsample_np:
            # Offsets are drawn from [0, s] inclusive (the upper bound of randint is exclusive).
            shift_x = np.random.randint(0, mask_size[0] + 1)
            shift_y = np.random.randint(0, mask_size[1] + 1)
            masks_lst.append(sample[:, shift_x: shift_x + height, shift_y:shift_y + width])
        masks = op.Tensor(np.array(masks_lst), data.dtype)
        return masks

    def __call__(self, inputs, targets):
        """
        Generates attribution maps for inputs.

        Args:
            inputs (Tensor): Input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
            targets (int, Tensor): The labels of interest to be explained. When `targets` is an integer,
                all of the inputs will generate attribution maps w.r.t. this integer. When `targets` is a tensor,
                it should be of shape :math:`(N, l)`, :math:`(N,)` or :math:`()`.

        Returns:
            Tensor, a 4D tensor of shape :math:`(N, l, H, W)` or :math:`(N, 1, H, W)`.

        Raises:
            ValueError: If `inputs` is not 4D, or if a tensor `targets` has more than 2 dimensions or a length
                different from `inputs`.

        Examples:
            >>> # given an instance of RISE, saliency map can be generated
            >>> inputs = ms.Tensor(np.random.rand(2, 3, 224, 224), ms.float32)
            >>> # when `targets` is an integer
            >>> targets = 5
            >>> saliency = rise(inputs, targets)
            >>> # `targets` can also be a tensor
            >>> targets = ms.Tensor([[5], [1]])
            >>> saliency = rise(inputs, targets)
        """
        self._verify_data(inputs, targets)
        height, width = inputs.shape[2], inputs.shape[3]
        batch_size = inputs.shape[0]
        # Bound here rather than inside the final loop so it is defined even when that loop runs zero times.
        dtype = inputs.dtype

        if self._num_classes is None:
            # One forward pass, only to learn the number of output classes; cached for later calls.
            logits = self.model(inputs)
            self._num_classes = logits.shape[1]

        # Due to the unsupported Op of slice assignment, we use numpy array here
        attr_np = np.zeros(shape=(batch_size, self._num_classes, height, width))

        cal_times = math.ceil(self._num_masks / self._perturbation_per_eval)

        for idx, data in enumerate(inputs):
            bg_data = data * 0 + self._base_value
            # Restore the batch axis once per sample; the reshape does not depend on the mask batch.
            data = op.reshape(data, (1, -1, height, width))
            for j in range(cal_times):
                # The last batch may be smaller than `_perturbation_per_eval`.
                bs = min(self._num_masks - j * self._perturbation_per_eval,
                         self._perturbation_per_eval)
                masks = self._generate_masks(data, bs)

                masked_input = masks * data + (1 - masks) * bg_data
                weights = self._activation_fn(self.model(masked_input))
                # Collapse any extra trailing axes of the network output down to (bs, num_classes).
                while len(weights.shape) > 2:
                    weights = op.mean(weights, axis=2)
                weights = op.reshape(weights,
                                     (bs, self._num_classes, 1, 1))

                attr_np[idx] += op.summation(weights * masks, axis=0).asnumpy()

        attr_np = attr_np / self._num_masks
        targets = self._unify_targets(inputs, targets)
        # Pick, for every sample, the maps of the classes requested for that sample.
        attr_classes = [attr_np[idx][target] for idx, target in enumerate(targets)]

        return op.Tensor(attr_classes, dtype=dtype)

    @staticmethod
    def _verify_data(inputs, targets):
        """
        Verify the validity of the parsed inputs.

        Args:
            inputs: Must be a 4D `Tensor`.
            targets: Must be a `Tensor`, int, tuple or list. A tensor must be 0D, 1D or 2D and, unless 0D,
                have the same length as `inputs`.

        Raises:
            ValueError: If the dimension or length requirements above are violated. Type violations are
                reported by `check_value_type`.
        """
        check_value_type('inputs', inputs, Tensor)
        if len(inputs.shape) != 4:
            raise ValueError('Argument inputs must be 4D Tensor')
        check_value_type('targets', targets, (Tensor, int, tuple, list))
        if isinstance(targets, Tensor):
            if len(targets.shape) > 2:
                raise ValueError('Dimension invalid. If `targets` is a Tensor, it should be 0D, 1D or 2D. '
                                 'But got {}D.'.format(len(targets.shape)))
            if targets.shape and len(targets) != len(inputs):
                raise ValueError(
                    'If `targets` is a 2D, 1D Tensor, it should have the same length as inputs {}. But got {}'.format(
                        len(inputs), len(targets)))

    @staticmethod
    def _unify_targets(inputs, targets):
        """
        Unify int or Tensor targets to a 2D integer numpy.ndarray with one row per input sample.

        Tuple or list targets are returned unchanged.
        NOTE(review): `_verify_data` accepts tuple/list targets, but they are not normalised here -- confirm
        whether callers rely on that.
        """
        # The builtin `int` replaces the `np.int` alias, which newer NumPy releases no longer provide.
        if isinstance(targets, int):
            return np.array([[targets] for _ in inputs]).astype(int)
        if isinstance(targets, Tensor):
            if not targets.shape:
                return np.array([[targets.asnumpy()] for _ in inputs]).astype(int)
            if len(targets.shape) == 1:
                return np.array([[t.asnumpy()] for t in targets]).astype(int)
            if len(targets.shape) == 2:
                return np.array([t.asnumpy() for t in targets]).astype(int)
        return targets
|
Loading…
Reference in New Issue