forked from mindspore-Ecosystem/mindspore
Add explainer to provide eXplainable AI tools.
This commit provides APIs for user to use the widely used attribution methods to explain DL models and the evaluation methods to quantify the explanations. With combination of MindInsight, the user can have a friendly visualization on their models.
This commit is contained in:
parent
a418280659
commit
744f094add
|
@ -118,18 +118,10 @@ message Explain {
|
|||
}
|
||||
|
||||
message Benchmark{
|
||||
message TotalScore{
|
||||
optional string benchmark_method = 1;
|
||||
optional float score = 2;
|
||||
}
|
||||
message LabelScore{
|
||||
repeated float score = 1;
|
||||
optional string benchmark_method = 2;
|
||||
}
|
||||
|
||||
optional string explain_method = 1;
|
||||
repeated TotalScore total_score = 2;
|
||||
repeated LabelScore label_score = 3;
|
||||
optional string explain_method = 2;
|
||||
optional float total_score = 3;
|
||||
repeated float label_score = 4;
|
||||
}
|
||||
|
||||
message Metadata{
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Provide ExplainRunner High-level API."""
|
||||
|
||||
from ._runner import ExplainRunner
|
||||
|
||||
__all__ = ['ExplainRunner']
|
|
@ -0,0 +1,261 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Packaged operations based on MindSpore."""
|
||||
from typing import List, Tuple, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import nn
|
||||
import mindspore.ops.operations as op
|
||||
|
||||
|
||||
_Axis = Union[int, Tuple[int, ...], List[int]]
|
||||
_Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]]
|
||||
_Number = Union[int, float, np.int, np.float]
|
||||
_Shape = Union[int, Tuple[int, ...]]
|
||||
Tensor = mindspore.Tensor
|
||||
|
||||
__all__ = [
|
||||
'absolute',
|
||||
'arange',
|
||||
'argmax',
|
||||
'argmin',
|
||||
'argsort',
|
||||
'assign',
|
||||
'intersection',
|
||||
'matmul',
|
||||
'maximum',
|
||||
'minimum',
|
||||
'mean',
|
||||
'mul',
|
||||
'sort',
|
||||
'squeeze',
|
||||
'tile',
|
||||
'reshape',
|
||||
'zeros',
|
||||
'zeros_like',
|
||||
'softmax',
|
||||
'Tensor',
|
||||
'summation'
|
||||
]
|
||||
|
||||
|
||||
def absolute(inputs: Tensor) -> Tensor:
|
||||
"""Get the absolute value of a tensor value."""
|
||||
abs_op = op.Abs()
|
||||
outputs = abs_op(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
def arange(
|
||||
start: _Number,
|
||||
end: _Number,
|
||||
step: _Number = 1,
|
||||
dtype: mindspore.dtype = None) -> Tensor:
|
||||
"""Get the arange value of tensor."""
|
||||
nums = np.arange(start=start, stop=end, step=step, dtype=np.int32)
|
||||
nums = mindspore.Tensor(nums, dtype=dtype)
|
||||
return nums
|
||||
|
||||
|
||||
def argmax(inputs: Tensor, axis: int = -1, keep_dims: bool = False) -> Tensor:
|
||||
"""Returns the indices of the maximum values along an axis."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
outputs = np.argmax(inputs_np, axis=axis)
|
||||
|
||||
if keep_dims:
|
||||
outputs = np.expand_dims(outputs, axis=axis)
|
||||
|
||||
return mindspore.Tensor(outputs, mindspore.int32)
|
||||
|
||||
|
||||
def argmin(inputs: Tensor, axis: int = -1, keep_dims: bool = False) -> Tensor:
|
||||
"""Returns the indices of the minimum values along an axis."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
outputs = np.argmin(inputs_np, axis=axis)
|
||||
|
||||
if keep_dims:
|
||||
outputs = np.expand_dims(outputs, axis=axis)
|
||||
|
||||
return mindspore.Tensor(outputs, mindspore.int32)
|
||||
|
||||
|
||||
def argsort(inputs: Tensor, axis: int = -1, descending: bool = False) -> Tensor:
|
||||
"""Returns the indices that would sort an array."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
factor = -1 if descending else 1
|
||||
indices_np = np.argsort(factor * inputs_np, axis=axis)
|
||||
indices = mindspore.Tensor(indices_np, dtype=mindspore.int32)
|
||||
return indices
|
||||
|
||||
|
||||
def assign(inputs: Tensor, idx: _Idx, value: Tensor) -> Tensor:
|
||||
"""Assign a tensor value to the given tensor and index."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
if isinstance(idx, Tensor):
|
||||
idx = idx.asnumpy()
|
||||
value_np = value.asnumpy()
|
||||
inputs_np[idx] = value_np
|
||||
outputs = mindspore.Tensor(inputs_np)
|
||||
return outputs
|
||||
|
||||
|
||||
def intersection(*inputs: Tensor) -> Tensor:
|
||||
"""Get the intersection value by the given tensor list."""
|
||||
outputs_np = np.ones_like(inputs[0])
|
||||
for inp in inputs:
|
||||
outputs_np &= inp.asnumpy()
|
||||
outputs = mindspore.Tensor(outputs_np)
|
||||
return outputs
|
||||
|
||||
|
||||
def matmul(inputs_x: Tensor, inputs_y: Tensor) -> Tensor:
|
||||
"""Multiplies matrix `inputs_x` and matrix `inputs_y`."""
|
||||
matmul_op = op.MatMul()
|
||||
outputs = matmul_op(inputs_x, inputs_y)
|
||||
return outputs
|
||||
|
||||
|
||||
def maximum(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
|
||||
"""Reduce a dimension of a tensor by the maximum value in this dimension."""
|
||||
max_op = op.ReduceMax(keep_dims)
|
||||
outputs = max_op(inputs, axis)
|
||||
return outputs
|
||||
|
||||
|
||||
def minimum(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
|
||||
"""Reduce a dimension of a tensor by the minimum value in the dimension."""
|
||||
max_op = op.ReduceMin(keep_dims)
|
||||
outputs = max_op(inputs, axis)
|
||||
return outputs
|
||||
|
||||
|
||||
def mean(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
|
||||
"""Reduce a dimension of a tensor by averaging all elements in the dimension."""
|
||||
mean_op = op.ReduceMean(keep_dims)
|
||||
outputs = mean_op(inputs, axis)
|
||||
return outputs
|
||||
|
||||
|
||||
def mul(inputs_x: Tensor, inputs_y: Tensor) -> Tensor:
|
||||
"""
|
||||
Multiplies two tensors element-wise.
|
||||
|
||||
Inputs of `input_x` and `input_y` comply with the implicit type conversion rules to make the data types consistent.
|
||||
The inputs must be two tensors or one tensor and one scalar.
|
||||
When the inputs are two tensors,
|
||||
dtypes of them cannot be both bool, and the shapes of them could be broadcast.
|
||||
When the inputs are one tensor and one scalar,
|
||||
the scalar could only be a constant.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
|
||||
a bool or a tensor whose data type is number or bool.
|
||||
- **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
|
||||
a bool when the first input is a tensor or a tensor whose data type is number or bool.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the one after broadcasting,
|
||||
and the data type is the one with higher precision or higher digits among the two inputs.
|
||||
"""
|
||||
mul_op = op.Mul()
|
||||
outputs = mul_op(inputs_x, inputs_y)
|
||||
return outputs
|
||||
|
||||
|
||||
def sort(inputs: Tensor, axis: _Axis = -1, descending: bool = False) -> Tensor:
|
||||
"""Return a sorted copy of an array."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
outputs_np = np.sort(inputs_np, axis=axis)
|
||||
if descending:
|
||||
outputs_np = np.flip(outputs_np, axis=axis)
|
||||
outputs = mindspore.Tensor(outputs_np)
|
||||
return outputs
|
||||
|
||||
|
||||
def squeeze(inputs: Tensor, axis: _Axis = ()):
|
||||
"""Returns a tensor with the same type but dimensions of 1 are removed based on `axis`."""
|
||||
squeeze_op = op.Squeeze(axis)
|
||||
outputs = squeeze_op(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
def tile(inputs: Tensor, shape: Tuple[int, ...]) -> Tensor:
|
||||
"""Replicates a tensor with given multiples times."""
|
||||
tile_op = op.Tile()
|
||||
outputs = tile_op(inputs, shape)
|
||||
return outputs
|
||||
|
||||
|
||||
def reshape(inputs: Tensor, shape: _Shape) -> Tensor:
|
||||
"""Reshapes input tensor with the same values based on a given shape tuple."""
|
||||
if isinstance(shape, int):
|
||||
shape = (shape,)
|
||||
return op.Reshape()(inputs, shape)
|
||||
|
||||
|
||||
def zeros(shape: _Shape, dtype: mindspore.dtype = None) -> Tensor:
|
||||
"""Return a new array of given shape and type, filled with zeros."""
|
||||
outputs = np.zeros(shape)
|
||||
return mindspore.Tensor(outputs, dtype=dtype)
|
||||
|
||||
|
||||
def zeros_like(inputs: Tensor, dtype: mindspore.dtype = None) -> Tensor:
|
||||
"""Return an array of zeros with the same shape and type as a given array."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
outputs_np = np.zeros_like(inputs_np)
|
||||
outputs = mindspore.Tensor(outputs_np, dtype)
|
||||
return outputs
|
||||
|
||||
|
||||
def random(shape: _Shape, dtype: mindspore.dtype = None) -> Tensor:
|
||||
"""Return random floats in the half-open interval [0.0, 1.0)."""
|
||||
outputs_np = np.random.random(shape)
|
||||
outputs = mindspore.Tensor(outputs_np, dtype)
|
||||
return outputs
|
||||
|
||||
|
||||
def randint(low: int, high: int, shape: _Shape, dtype: mindspore.dtype = mindspore.int8) -> Tensor:
|
||||
"""Return random integers from `low` (inclusive) to `high` (exclusive)."""
|
||||
outputs_np = np.random.randint(low, high, size=shape)
|
||||
outputs = mindspore.Tensor(outputs_np, dtype=dtype)
|
||||
return outputs
|
||||
|
||||
|
||||
def softmax(axis: int) -> Callable:
|
||||
"""Softmax activation function."""
|
||||
func = nn.Softmax(axis=axis)
|
||||
return func
|
||||
|
||||
|
||||
def summation(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
|
||||
"""Reduce a dimension of a tensor by summing all elements in the dimension."""
|
||||
sum_op = op.ReduceSum(keep_dims)
|
||||
outputs = sum_op(inputs, axis)
|
||||
return outputs
|
||||
|
||||
|
||||
def stack(inputs: List[Tensor], axis: int) -> Tensor:
|
||||
"""Packs a list of tensors in specified axis."""
|
||||
pack_op = op.Pack(axis)
|
||||
outputs = pack_op(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
def sqrt(inputs: Tensor) -> Tensor:
|
||||
"""Returns square root of a tensor element-wise."""
|
||||
sqrt_op = op.Sqrt()
|
||||
return sqrt_op(inputs)
|
|
@ -0,0 +1,481 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Runner."""
|
||||
from time import time
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from mindspore.train.summary_pb2 import Explain
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log
|
||||
from mindspore.ops.operations import ExpandDims
|
||||
from mindspore.train.summary._summary_adapter import _convert_image_format, _make_image
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from .benchmark import Localization
|
||||
from .benchmark._attribution.metric import AttributionMetric
|
||||
from .explanation._attribution._attribution import Attribution
|
||||
|
||||
_EXPAND_DIMS = ExpandDims()
|
||||
_CMAP_0 = np.reshape(np.array([55, 25, 86, 255]), [1, 1, 4]) / 255
|
||||
_CMAP_1 = np.reshape(np.array([255, 255, 0, 255]), [1, 1, 4]) / 255
|
||||
|
||||
|
||||
def _normalize(img_np):
|
||||
"""Normalize the image in the numpy array to be in [0, 255]. """
|
||||
max_ = img_np.max()
|
||||
min_ = img_np.min()
|
||||
normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
|
||||
return (normed * 255).astype(np.uint8)
|
||||
|
||||
|
||||
def _make_rgba(saliency):
|
||||
"""Make rgba image for saliency map."""
|
||||
saliency = saliency.asnumpy().squeeze()
|
||||
saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min()).clip(1e-10)
|
||||
rgba = np.empty((saliency.shape[0], saliency.shape[1], 4))
|
||||
rgba[:, :, :] = np.expand_dims(saliency, 2)
|
||||
rgba = rgba * _CMAP_1 + (1 - rgba) * _CMAP_0
|
||||
rgba[:, :, -1] = saliency * 1
|
||||
return rgba
|
||||
|
||||
|
||||
class ExplainRunner:
|
||||
"""
|
||||
High-level API for users to generate results with the explanation methods and the evaluation methods.
|
||||
|
||||
After generating results with the explanation methods and the evaluation methods, the results will be written into
|
||||
a specified file with 'mindspore.summary.SummaryRecord'. The stored content can be viewed using MindInsight.
|
||||
|
||||
Args:
|
||||
summary_dir (str): The directory path to save the summary files which store the generated results.
|
||||
Default: "./"
|
||||
|
||||
Examples:
|
||||
>>> # init a runner with a specified directory
|
||||
>>> summary_dir = "summary_dir"
|
||||
>>> runner = ExplainRunner(summary_dir)
|
||||
"""
|
||||
|
||||
def __init__(self, summary_dir: Optional[str] = "./"):
|
||||
self._summary_dir = summary_dir
|
||||
self._count = 0
|
||||
self._classes = None
|
||||
self._model = None
|
||||
|
||||
def run(self,
|
||||
dataset: Tuple,
|
||||
explainers: List,
|
||||
benchmarkers: Optional[List] = None):
|
||||
"""
|
||||
Genereate results and write results into the summary files in `self.summary_dir`.
|
||||
|
||||
Args:
|
||||
dataset (tuple): A tuple that contains `mindspore.dataset` object for iteration and its labels.
|
||||
- dataset[0], a `mindspore.dataset` object to provide data to explain.
|
||||
- dataset[1], a list of string that specifies the label names of the dataset.
|
||||
explainers (list): A list of explanation objects to generate _attribution results.
|
||||
benchmarkers (list): A list of benchmark objects to generate evaluation results. Default: None
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
|
||||
>>> # obtain dataset object
|
||||
>>> dataset = get_dataset()
|
||||
>>> classes = ["cat", "dog", ...]
|
||||
>>> # load checkpoint to a network, e.g. resnet50
|
||||
>>> param_dict = load_checkpoint("checkpoint.ckpt")
|
||||
>>> net = resnet50(len(classes))
|
||||
>>> load_parama_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation
|
||||
>>> model = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> gbp = GuidedBackprop(model)
|
||||
>>> gradient = Gradient(model)
|
||||
>>> runner = ExplainRunner("./")
|
||||
>>> explainers = [gbp, gradient]
|
||||
>>> runner.run((dataset, classes), explainers)
|
||||
"""
|
||||
|
||||
if not isinstance(dataset, tuple):
|
||||
raise TypeError("Argument `dataset` must be a tuple.")
|
||||
if len(dataset) != 2:
|
||||
raise ValueError("Argument `dataset` should be a tuple with length = 2.")
|
||||
|
||||
dataset, classes = dataset
|
||||
self._verify_data_form(dataset, benchmarkers)
|
||||
self._classes = classes
|
||||
|
||||
if explainers is None or not explainers:
|
||||
raise ValueError("Argument `explainers` can neither be None nor empty.")
|
||||
|
||||
for exp in explainers:
|
||||
if not isinstance(exp, Attribution) or not isinstance(explainers, list):
|
||||
raise TypeError("Argument explainers should be a list of objects of classes in "
|
||||
"`mindspore.explainer.explanation._attribution`.")
|
||||
if benchmarkers is not None:
|
||||
for bench in benchmarkers:
|
||||
if not isinstance(bench, AttributionMetric) or not isinstance(explainers, list):
|
||||
raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation"
|
||||
"`mindspore.explainer.benchmark._attribution`.")
|
||||
|
||||
self._model = explainers[0].model
|
||||
|
||||
with SummaryRecord(self._summary_dir) as summary:
|
||||
print("Start running and writing......")
|
||||
begin = time()
|
||||
print("Start writing metadata.")
|
||||
|
||||
explain = Explain()
|
||||
explain.metadata.label.extend(classes)
|
||||
exp_names = [exp.__class__.__name__ for exp in explainers]
|
||||
explain.metadata.explain_method.extend(exp_names)
|
||||
if benchmarkers is not None:
|
||||
bench_names = [bench.__class__.__name__ for bench in benchmarkers]
|
||||
explain.metadata.benchmark_method.extend(bench_names)
|
||||
|
||||
summary.add_value("explainer", "metadata", explain)
|
||||
summary.record(1)
|
||||
|
||||
print("Finish writing metadata.")
|
||||
|
||||
now = time()
|
||||
print("Start running and writing inference data......")
|
||||
imageid_labels = self._run_inference(dataset, summary)
|
||||
print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now))
|
||||
|
||||
if benchmarkers is None:
|
||||
for exp in explainers:
|
||||
start = time()
|
||||
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
|
||||
self._count = 0
|
||||
ds.config.set_seed(58)
|
||||
for idx, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
self._run_exp_step(next_element, exp, imageid_labels, summary)
|
||||
print("Finish writing {}-th explanation data. Time elapsed: {}".format(
|
||||
idx, time() - now))
|
||||
print("Finish running and writing explanation data for {}. Time elapsed: {}".format(
|
||||
exp.__class__.__name__, time() - start))
|
||||
else:
|
||||
for exp in explainers:
|
||||
explain = Explain()
|
||||
for bench in benchmarkers:
|
||||
bench.reset()
|
||||
print(f"Start running and writing explanation and benchmark data for {exp.__class__.__name__}.")
|
||||
self._count = 0
|
||||
start = time()
|
||||
ds.config.set_seed(58)
|
||||
for idx, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
|
||||
print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format(
|
||||
idx, time() - now))
|
||||
for bench in benchmarkers:
|
||||
now = time()
|
||||
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
|
||||
print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format(
|
||||
idx, bench.__class__.__name__, time() - now))
|
||||
|
||||
for bench in benchmarkers:
|
||||
benchmark = explain.benchmark.add()
|
||||
benchmark.explain_method = exp.__class__.__name__
|
||||
benchmark.benchmark_method = bench.__class__.__name__
|
||||
|
||||
benchmark.total_score = bench.performance
|
||||
benchmark.label_score.extend(bench.class_performances)
|
||||
|
||||
print("Finish running and writing explanation and benchmark data for {}. "
|
||||
"Time elapsed: {}s".format(exp.__class__.__name__, time() - start))
|
||||
summary.add_value('explainer', 'benchmark', explain)
|
||||
summary.record(1)
|
||||
print("Finish running and writing. Total time elapsed: {}s".format(time() - begin))
|
||||
|
||||
@staticmethod
|
||||
def _verify_data_form(dataset, benchmarkers):
|
||||
"""
|
||||
Verify the validity of dataset.
|
||||
|
||||
Args:
|
||||
dataset (`ds`): the user parsed dataset.
|
||||
benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers.
|
||||
"""
|
||||
next_element = dataset.create_tuple_iterator().get_next()
|
||||
|
||||
if len(next_element) not in [1, 2, 3]:
|
||||
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
|
||||
" as columns.")
|
||||
|
||||
if len(next_element) == 3:
|
||||
inputs, labels, bboxes = next_element
|
||||
if bboxes.shape[-1] != 4:
|
||||
raise ValueError("The third element of dataset should be bounding boxes with shape of "
|
||||
"[batch_size, num_ground_truth, 4].")
|
||||
else:
|
||||
if True in [isinstance(bench, Localization) for bench in benchmarkers]:
|
||||
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
|
||||
|
||||
if len(next_element) == 2:
|
||||
inputs, labels = next_element
|
||||
if len(next_element) == 1:
|
||||
inputs = next_element[0]
|
||||
|
||||
if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
|
||||
raise ValueError(
|
||||
"Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
|
||||
inputs.shape))
|
||||
if len(inputs.shape) == 3:
|
||||
log.warning(
|
||||
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
|
||||
" dimension as batch data.".format(inputs.shape))
|
||||
|
||||
if len(next_element) > 1:
|
||||
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
|
||||
raise ValueError(
|
||||
"Labels shape {} is unrecognizable: labels should not have more than two dimensions"
|
||||
" with length greater than 1.".format(labels.shape))
|
||||
|
||||
def _transform_data(self, inputs, labels, bboxes, ifbbox):
|
||||
"""
|
||||
Transform the data from one iteration of dataset to a unifying form for the follow-up operations.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): the image data
|
||||
labels (Tensor): the labels
|
||||
bboxes (Tensor): the boudnding boxes data
|
||||
ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t label
|
||||
id will be returned. If False, the returned bboxes is the the parsed bboxes.
|
||||
|
||||
Returns:
|
||||
inputs (Tensor): the image data, unified to a 4D Tensor.
|
||||
labels (List[List[int]]): the ground truth labels.
|
||||
bboxes (Union[List[Dict], None, Tensor]): the bounding boxes
|
||||
"""
|
||||
inputs = ms.Tensor(inputs, ms.float32)
|
||||
if len(inputs.shape) == 3:
|
||||
inputs = _EXPAND_DIMS(inputs, 0)
|
||||
if isinstance(labels, ms.Tensor):
|
||||
labels = ms.Tensor(labels, ms.int32)
|
||||
labels = _EXPAND_DIMS(labels, 0)
|
||||
if isinstance(bboxes, ms.Tensor):
|
||||
bboxes = ms.Tensor(bboxes, ms.int32)
|
||||
bboxes = _EXPAND_DIMS(bboxes, 0)
|
||||
|
||||
input_len = len(inputs)
|
||||
if bboxes is not None and ifbbox:
|
||||
bboxes = ms.Tensor(bboxes, ms.int32)
|
||||
masks_lst = []
|
||||
labels = labels.asnumpy().reshape([input_len, -1])
|
||||
bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
|
||||
for idx, label in enumerate(labels):
|
||||
height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
|
||||
masks = {}
|
||||
for j, label_item in enumerate(label):
|
||||
target = int(label_item)
|
||||
if -1 < target < len(self._classes):
|
||||
if target not in masks:
|
||||
mask = np.zeros((1, 1, height, width))
|
||||
else:
|
||||
mask = masks[target]
|
||||
x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
|
||||
mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
|
||||
masks[target] = mask
|
||||
|
||||
masks_lst.append(masks)
|
||||
bboxes = masks_lst
|
||||
|
||||
labels = ms.Tensor(labels, ms.int32)
|
||||
if len(labels.shape) == 1:
|
||||
labels_lst = [[int(i)] for i in labels.asnumpy()]
|
||||
else:
|
||||
labels = labels.asnumpy().reshape([input_len, -1])
|
||||
labels_lst = []
|
||||
for item in labels:
|
||||
labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._classes))))
|
||||
labels = labels_lst
|
||||
return inputs, labels, bboxes
|
||||
|
||||
def _unpack_next_element(self, next_element, ifbbox=False):
|
||||
"""
|
||||
Unpack a single iteration of dataset.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): a single element iterated from dataset object.
|
||||
ifbbox (bool): whether to preprocess bboxes in self._transform_data.
|
||||
|
||||
Returns:
|
||||
Tuple, a unified Tuple contains image_data, labels, and bounding boxes.
|
||||
"""
|
||||
if len(next_element) == 3:
|
||||
inputs, labels, bboxes = next_element
|
||||
elif len(next_element) == 2:
|
||||
inputs, labels = next_element
|
||||
bboxes = None
|
||||
else:
|
||||
inputs = next_element[0]
|
||||
labels = [[] for x in inputs]
|
||||
bboxes = None
|
||||
inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
|
||||
return inputs, labels, bboxes
|
||||
|
||||
@staticmethod
|
||||
def _make_label_batch(labels):
|
||||
"""
|
||||
Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
|
||||
length of all the rows in labels.
|
||||
|
||||
Args:
|
||||
labels (List[List]): the union labels of a data batch.
|
||||
|
||||
Returns:
|
||||
2D Tensor.
|
||||
"""
|
||||
|
||||
max_len = max([len(l) for l in labels])
|
||||
batch_labels = np.zeros((len(labels), max_len))
|
||||
|
||||
for idx, _ in enumerate(batch_labels):
|
||||
length = len(labels[idx])
|
||||
batch_labels[idx, :length] = np.array(labels[idx])
|
||||
|
||||
return ms.Tensor(batch_labels, ms.int32)
|
||||
|
||||
def _run_inference(self, dataset, summary, threshod=0.5):
|
||||
"""
|
||||
Run inference for the dataset and write the inference related data into summary.
|
||||
|
||||
Args:
|
||||
dataset (`ds`): the parsed dataset
|
||||
summary (`SummaryRecord`): the summary object to store the data
|
||||
threshold (float): the threshold for prediction.
|
||||
|
||||
Returns:
|
||||
imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels.
|
||||
"""
|
||||
imageid_labels = {}
|
||||
ds.config.set_seed(58)
|
||||
self._count = 0
|
||||
for j, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
prob = self._model(inputs).asnumpy()
|
||||
for idx, inp in enumerate(inputs):
|
||||
gt_labels = labels[idx]
|
||||
gt_probs = [float(prob[idx][i]) for i in gt_labels]
|
||||
|
||||
data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
|
||||
_, _, _, image_string = _make_image(_normalize(data_np))
|
||||
|
||||
predicted_labels = [int(i) for i in (prob[idx] > threshod).nonzero()[0]]
|
||||
predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
|
||||
|
||||
union_labs = list(set(gt_labels + predicted_labels))
|
||||
imageid_labels[str(self._count)] = union_labs
|
||||
|
||||
explain = Explain()
|
||||
explain.image_id = str(self._count)
|
||||
explain.image_data = image_string
|
||||
summary.add_value("explainer", "image", explain)
|
||||
|
||||
explain = Explain()
|
||||
explain.image_id = str(self._count)
|
||||
explain.ground_truth_label.extend(gt_labels)
|
||||
explain.inference.ground_truth_prob.extend(gt_probs)
|
||||
explain.inference.predicted_label.extend(predicted_labels)
|
||||
explain.inference.predicted_prob.extend(predicted_probs)
|
||||
summary.add_value("explainer", "inference", explain)
|
||||
|
||||
summary.record(1)
|
||||
|
||||
self._count += 1
|
||||
print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now))
|
||||
return imageid_labels
|
||||
|
||||
def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
|
||||
"""
|
||||
Run the explanation for each step and write explanation results into summary.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): data of one step
|
||||
explainer (_Attribution): an Attribution object to generate saliency maps.
|
||||
imageid_labels (dict): a dict that maps the image_id and its union labels.
|
||||
summary (SummaryRecord): the summary object to store the data
|
||||
|
||||
Returns:
|
||||
List of dict that maps label to its corresponding saliency map.
|
||||
"""
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
count = self._count
|
||||
unions = []
|
||||
for _ in range(len(labels)):
|
||||
unions_labels = imageid_labels[str(count)]
|
||||
unions.append(unions_labels)
|
||||
count += 1
|
||||
|
||||
batch_unions = self._make_label_batch(unions)
|
||||
saliency_dict_lst = []
|
||||
|
||||
batch_saliency_full = []
|
||||
for i in range(len(batch_unions[0])):
|
||||
batch_saliency = explainer(inputs, batch_unions[:, i])
|
||||
batch_saliency_full.append(batch_saliency)
|
||||
|
||||
for idx, union in enumerate(unions):
|
||||
saliency_dict = {}
|
||||
explain = Explain()
|
||||
explain.image_id = str(self._count)
|
||||
for k, lab in enumerate(union):
|
||||
saliency = batch_saliency_full[k][idx:idx + 1]
|
||||
|
||||
saliency_dict[lab] = saliency
|
||||
|
||||
saliency_np = _make_rgba(saliency)
|
||||
_, _, _, saliency_string = _make_image(_normalize(saliency_np))
|
||||
|
||||
explanation = explain.explanation.add()
|
||||
explanation.explain_method = explainer.__class__.__name__
|
||||
|
||||
explanation.label = lab
|
||||
explanation.heatmap = saliency_string
|
||||
|
||||
summary.add_value("explainer", "explanation", explain)
|
||||
summary.record(1)
|
||||
|
||||
self._count += 1
|
||||
saliency_dict_lst.append(saliency_dict)
|
||||
return saliency_dict_lst
|
||||
|
||||
def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
|
||||
"""
|
||||
Run the explanation and evaluation for each step and write explanation results into summary.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): Data of one step
|
||||
explainer (`_Attribution`): An Attribution object to generate saliency maps.
|
||||
imageid_labels (dict): A dict that maps the image_id and its union labels.
|
||||
"""
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
for idx, inp in enumerate(inputs):
|
||||
inp = _EXPAND_DIMS(inp, 0)
|
||||
saliency_dict = saliency_dict_lst[idx]
|
||||
for label, saliency in saliency_dict.items():
|
||||
if isinstance(benchmarker, Localization):
|
||||
_, _, bboxes = self._unpack_next_element(next_element, True)
|
||||
if label in labels[idx]:
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
|
||||
saliency=saliency)
|
||||
benchmarker.aggregate(res, label)
|
||||
else:
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
|
||||
benchmarker.aggregate(res, label)
|
|
@ -0,0 +1,285 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utils for MindExplain"""
|
||||
|
||||
__all__ = [
|
||||
'ForwardProbe',
|
||||
'calc_auc',
|
||||
'calc_correlation',
|
||||
'format_tensor_to_ndarray',
|
||||
'generate_one_hot',
|
||||
'rank_pixels',
|
||||
'resize',
|
||||
'retrieve_layer_by_name',
|
||||
'retrieve_layer',
|
||||
'unify_inputs',
|
||||
'unify_targets'
|
||||
]
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as op
|
||||
|
||||
_Array = np.ndarray
|
||||
_Module = nn.Cell
|
||||
_Tensor = ms.Tensor
|
||||
|
||||
|
||||
def generate_one_hot(indices, depth):
|
||||
r"""
|
||||
Simple wrap of OneHot operation, the on_value an off_value are fixed to 1.0
|
||||
and 0.0.
|
||||
"""
|
||||
on_value = ms.Tensor(1.0, ms.float32)
|
||||
off_value = ms.Tensor(0.0, ms.float32)
|
||||
weights = op.OneHot()(indices, depth, on_value, off_value)
|
||||
return weights
|
||||
|
||||
|
||||
def unify_inputs(inputs) -> tuple:
|
||||
"""Unify inputs of explainer."""
|
||||
if isinstance(inputs, tuple):
|
||||
return inputs
|
||||
if isinstance(inputs, ms.Tensor):
|
||||
inputs = (inputs,)
|
||||
elif isinstance(inputs, np.ndarray):
|
||||
inputs = (ms.Tensor(inputs),)
|
||||
else:
|
||||
raise TypeError(
|
||||
'inputs must be one of [tuple, ms.Tensor or np.ndarray], '
|
||||
'but get {}'.format(type(inputs)))
|
||||
return inputs
|
||||
|
||||
|
||||
def unify_targets(targets) -> ms.Tensor:
|
||||
"""Unify targets labels of explainer."""
|
||||
if isinstance(targets, ms.Tensor):
|
||||
return targets
|
||||
if isinstance(targets, list):
|
||||
targets = ms.Tensor(targets, dtype=ms.int32)
|
||||
if isinstance(targets, int):
|
||||
targets = ms.Tensor([targets], dtype=ms.int32)
|
||||
else:
|
||||
raise TypeError(
|
||||
'targets must be one of [int, list or ms.Tensor], '
|
||||
'but get {}'.format(type(targets)))
|
||||
return targets
|
||||
|
||||
|
||||
def retrieve_layer_by_name(model: _Module, layer_name: str):
|
||||
"""
|
||||
Retrieve the layer in the model by the given layer_name.
|
||||
|
||||
Args:
|
||||
model (_Module): model which contains the target layer
|
||||
layer_name (str): name of target layer
|
||||
|
||||
Return:
|
||||
- target_layer (_Module)
|
||||
|
||||
Raise:
|
||||
ValueError: is module with given layer_name is not found in the model,
|
||||
raise ValueError.
|
||||
|
||||
"""
|
||||
if not isinstance(layer_name, str):
|
||||
raise TypeError('layer_name should be type of str, but receive {}.'
|
||||
.format(type(layer_name)))
|
||||
|
||||
if not layer_name:
|
||||
return model
|
||||
|
||||
target_layer = None
|
||||
for name, cell in model.cells_and_names():
|
||||
if name == layer_name:
|
||||
target_layer = cell
|
||||
return target_layer
|
||||
|
||||
if target_layer is None:
|
||||
raise ValueError(
|
||||
'Cannot match {}, please provide target layer'
|
||||
'in the given model.'.format(layer_name))
|
||||
return None
|
||||
|
||||
|
||||
def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''):
|
||||
"""
|
||||
Retrieve the layer in the model.
|
||||
|
||||
'target' can be either a layer name or a Cell object. Given the layer name,
|
||||
the method will search thourgh the model and return the matched layer. If a
|
||||
Cell object is provided, it will check whether the given layer exists
|
||||
in the model. If target layer is not found in the model, ValueError will
|
||||
be raised.
|
||||
|
||||
Args:
|
||||
model (_Module): the model to retrieve the target layer
|
||||
target_layer (Union[str, _Module]): target layer to retrieve. Can be
|
||||
either string (layer name) or the Cell object. If '' is provided,
|
||||
the input model will be returned.
|
||||
|
||||
Return:
|
||||
target layer (_Module)
|
||||
"""
|
||||
if isinstance(target_layer, str):
|
||||
target_layer = retrieve_layer_by_name(model, target_layer)
|
||||
return target_layer
|
||||
|
||||
if isinstance(target_layer, _Module):
|
||||
for _, cell in model.cells_and_names():
|
||||
if target_layer is cell:
|
||||
return target_layer
|
||||
raise ValueError(
|
||||
'Model not contain cell {}, fail to probe.'.format(target_layer)
|
||||
)
|
||||
raise TypeError('layer_name must have type of str or ms.nn.Cell,'
|
||||
'but receive {}'.format(type(target_layer)))
|
||||
|
||||
|
||||
class ForwardProbe:
|
||||
"""
|
||||
Probe to capture output of specific layer in a given model.
|
||||
|
||||
Args:
|
||||
target_layer (_Module): name of target layer or just provide the
|
||||
target layer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, target_layer: _Module):
|
||||
self._target_layer = target_layer
|
||||
self._original_construct = self._target_layer.construct
|
||||
self._intermediate_tensor = None
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._intermediate_tensor
|
||||
|
||||
def __enter__(self):
|
||||
self._target_layer.construct = self._new_construct
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
self._target_layer.construct = self._original_construct
|
||||
self._intermediate_tensor = None
|
||||
return False
|
||||
|
||||
def _new_construct(self, *inputs):
|
||||
outputs = self._original_construct(*inputs)
|
||||
self._intermediate_tensor = outputs
|
||||
return outputs
|
||||
|
||||
|
||||
def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
|
||||
"""Unify `mindspore.Tensor` and `np.ndarray` to `np.ndarray`. """
|
||||
if isinstance(x, ms.Tensor):
|
||||
x = x.asnumpy()
|
||||
|
||||
if not isinstance(x, np.ndarray):
|
||||
raise TypeError('input should be one of [ms.Tensor or np.ndarray],'
|
||||
' but receive {}'.format(type(x)))
|
||||
return x
|
||||
|
||||
|
||||
def calc_correlation(x: Union[ms.Tensor, np.ndarray],
|
||||
y: Union[ms.Tensor, np.ndarray]) -> float:
|
||||
"""Calculate Pearson correlation coefficient between two arrays. """
|
||||
x = format_tensor_to_ndarray(x)
|
||||
y = format_tensor_to_ndarray(y)
|
||||
faithfulness = -np.corrcoef(x, y)[0, 1]
|
||||
|
||||
return faithfulness
|
||||
|
||||
|
||||
def calc_auc(x: _Array) -> float:
|
||||
"""Calculate the Aera under Curve."""
|
||||
# take mean for multiple patches if the model is fully convolutional model
|
||||
if len(x.shape) == 4:
|
||||
x = np.mean(np.mean(x, axis=2), axis=3)
|
||||
|
||||
auc = (x.sum() - x[0] - x[-1]) / len(x)
|
||||
return float(auc)
|
||||
|
||||
|
||||
def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
|
||||
"""
|
||||
Generate rank order fo every pixel in an 2D array.
|
||||
|
||||
The rank order start from 0 to (num_pixel-1). If descending is True, the
|
||||
rank order will generate in a descending order, otherwise in ascending
|
||||
order.
|
||||
|
||||
Example:
|
||||
x = np.array([[4., 3., 1.], [5., 9., 1.]])
|
||||
rank_pixels(x, descending=True)
|
||||
>> np.array([[2, 3, 4], [1, 0, 5]])
|
||||
rank_pixels(x, descending=False)
|
||||
>> np.array([[3, 2, 0], [4, 5, 1]])
|
||||
|
||||
"""
|
||||
if len(inputs.shape) != 2:
|
||||
raise ValueError('Only support 2D array currently')
|
||||
flatten_saliency = inputs.reshape(-1)
|
||||
factor = -1 if descending else 1
|
||||
sorted_arg = np.argsort(factor * flatten_saliency, axis=0)
|
||||
flatten_rank = np.zeros_like(sorted_arg)
|
||||
flatten_rank[sorted_arg] = np.arange(0, sorted_arg.shape[0])
|
||||
rank_map = flatten_rank.reshape(inputs.shape)
|
||||
return rank_map
|
||||
|
||||
|
||||
def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
|
||||
"""
|
||||
Resize the intermediate layer _attribution to the same size as inputs.
|
||||
|
||||
Args:
|
||||
inputs (ms.Tensor): the input tensor to be resized
|
||||
size (tupleint]): the targeted size resize to
|
||||
mode (str): the resize mode. Options: 'nearest_neighbor', 'bilinear'
|
||||
|
||||
Returns:
|
||||
outputs (ms.Tensor): the resized tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: the resize mode is not in ['nearest_neighbor',
|
||||
'bilinear'].
|
||||
"""
|
||||
h, w = size
|
||||
if mode == 'nearest_neighbor':
|
||||
resize_nn = op.ResizeNearestNeighbor((h, w))
|
||||
outputs = resize_nn(inputs)
|
||||
|
||||
elif mode == 'bilinear':
|
||||
inputs_np = inputs.asnumpy()
|
||||
inputs_np = np.transpose(inputs_np, [0, 2, 3, 1])
|
||||
array_lst = []
|
||||
for inp in inputs_np:
|
||||
array = (np.repeat(inp, 3, axis=2) * 255).astype(np.uint8)
|
||||
image = Image.fromarray(array)
|
||||
image = image.resize(size, resample=Image.BILINEAR)
|
||||
array = np.asarray(image).astype(np.float32) / 255
|
||||
array_lst.append(array[:, :, 0:1])
|
||||
|
||||
resized_np = np.transpose(array_lst, [0, 3, 1, 2])
|
||||
outputs = ms.Tensor(resized_np, inputs.dtype)
|
||||
else:
|
||||
raise ValueError('Unsupported resize mode {}'.format(mode))
|
||||
|
||||
return outputs
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Predefined XAI metrics."""
|
||||
|
||||
from ._attribution.faithfulness import Faithfulness
|
||||
from ._attribution.localization import Localization
|
||||
|
||||
__all__ = [
|
||||
"Faithfulness",
|
||||
"Localization"
|
||||
]
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Predefined XAI metrics"""
|
||||
|
||||
from .faithfulness import Faithfulness
|
||||
from .localization import Localization
|
||||
|
||||
__all__ = [
|
||||
"Faithfulness",
|
||||
"Localization"
|
||||
]
|
|
@ -0,0 +1,593 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Faithfulness"""
|
||||
import math
|
||||
from typing import Callable, Optional, Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as op
|
||||
from .metric import AttributionMetric
|
||||
from ..._utils import calc_correlation, calc_auc, format_tensor_to_ndarray, rank_pixels
|
||||
from ...explanation._attribution._attribution import Attribution as _Attribution
|
||||
|
||||
_Array = np.ndarray
|
||||
_Explainer = Union[_Attribution, Callable]
|
||||
_Label = Union[int, ms.Tensor]
|
||||
_Module = nn.Cell
|
||||
|
||||
|
||||
def _calc_feature_importance(saliency: _Array, masks: _Array) -> _Array:
|
||||
"""Calculate feature important w.r.t given masks."""
|
||||
feature_importance = []
|
||||
num_perturbations = masks.shape[0]
|
||||
for i in range(num_perturbations):
|
||||
patch_feature_importance = saliency[masks[i]].sum() / masks[i].sum()
|
||||
feature_importance.append(patch_feature_importance)
|
||||
feature_importance = np.array(feature_importance, dtype=np.float32)
|
||||
return feature_importance
|
||||
|
||||
|
||||
class _BaseReplacement:
|
||||
"""
|
||||
Base class of generator for generating different replacement for perturbations.
|
||||
|
||||
Args:
|
||||
kwargs: Optional args for generating replacement. Derived class need to
|
||||
add necessary arg names and default value to '_necessary_args'.
|
||||
If the argument has no default value, the value should be set to
|
||||
'EMPTY' to mark the required args. Initializing an object will
|
||||
check the given kwargs w.r.t '_necessary_args'.
|
||||
|
||||
Raise:
|
||||
ValueError: Raise when provided kwargs not contain necessary arg names with 'EMPTY' mark.
|
||||
"""
|
||||
_necessary_args = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._replace_args = self._necessary_args.copy()
|
||||
for key, value in self._replace_args.items():
|
||||
if key in kwargs.keys():
|
||||
self._replace_args[key] = kwargs[key]
|
||||
elif key not in kwargs.keys() and value == 'EMPTY':
|
||||
raise ValueError(f"Missing keyword arg {key} for {self.__class__.__name__}.")
|
||||
|
||||
__call__: Callable
|
||||
"""
|
||||
Generate replacement for perturbations. Derived class should overwrite this
|
||||
function to generate different replacement for perturbing.
|
||||
|
||||
Args:
|
||||
inputs (_Array): Array to be perturb.
|
||||
|
||||
Returns:
|
||||
- replacement (_Array): Array to provide alternative pixels for every
|
||||
position in the given
|
||||
inputs. The returned array should have same shape as inputs.
|
||||
"""
|
||||
|
||||
|
||||
class Constant(_BaseReplacement):
|
||||
""" Generator to provide constant-value replacement for perturbations """
|
||||
_necessary_args = {'base_value': 'EMPTY'}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
replacement = np.ones_like(inputs, dtype=np.float32)
|
||||
replacement *= self._replace_args['base_value']
|
||||
return replacement
|
||||
|
||||
|
||||
class GaussianBlur(_BaseReplacement):
|
||||
""" Generator to provided gaussian blurred inputs for perturbation. """
|
||||
_necessary_args = {'sigma': 0.7}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
sigma = self._replace_args['sigma']
|
||||
replacement = gaussian_filter(inputs, sigma=sigma)
|
||||
return replacement
|
||||
|
||||
|
||||
class Perturb:
|
||||
"""
|
||||
Perturbation generator to generate perturbations for a given array.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_mode (str): specify perturbing mode, through deleting or
|
||||
inserting pixels. Current support: ['Deletion', 'Insertion'].
|
||||
is_accumulate (bool): whether to accumulate the former perturbations to
|
||||
the later perturbations.
|
||||
perturb_pixel_per_step (int, optional): number of pixel to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculate by:
|
||||
num_image_pixel * perturb_percent / num_perturb_steps.
|
||||
Default: None
|
||||
num_perturbations (int, optional): number of perturbations. If
|
||||
num_perturbations if None, it will be calculated by:
|
||||
num_image_pixel * perturb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_mode: str,
|
||||
is_accumulate: bool,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None):
|
||||
self._perturb_percent = perturb_percent
|
||||
self._perturb_mode = perturb_mode
|
||||
self._pixel_per_step = perturb_pixel_per_step
|
||||
self._num_perturbations = num_perturbations
|
||||
self._is_accumulate = is_accumulate
|
||||
|
||||
@staticmethod
|
||||
def _assign(x: _Array, y: _Array, masks: _Array):
|
||||
"""Assign values to perturb pixels on perturbations."""
|
||||
if masks.dtype != bool:
|
||||
raise TypeError('The param "masks" should be an array of bool, but receive {}'
|
||||
.format(masks.dtype))
|
||||
for i in range(x.shape[0]):
|
||||
x[i][:, masks[i]] = y[:, masks[i]]
|
||||
|
||||
def _generate_mask(self, saliency_rank: _Array) -> _Array:
|
||||
"""Generate mask for perturbations based on given saliency ranks."""
|
||||
if len(saliency_rank.shape) != 2:
|
||||
raise ValueError(f'The param "saliency_rank" should be 2-dim, but receive {len(saliency_rank.shape)}.')
|
||||
|
||||
num_pixels = saliency_rank.shape[0] * saliency_rank.shape[1]
|
||||
if self._pixel_per_step:
|
||||
pixel_per_step = self._pixel_per_step
|
||||
num_perturbations = math.floor(
|
||||
num_pixels * self._perturb_percent / self._pixel_per_step)
|
||||
elif self._num_perturbations:
|
||||
pixel_per_step = math.floor(
|
||||
num_pixels * self._perturb_percent / self._num_perturbations)
|
||||
num_perturbations = self._num_perturbations
|
||||
else:
|
||||
raise ValueError("Must provide either pixel_per_step or num_perturbations.")
|
||||
|
||||
masks = np.zeros(
|
||||
(num_perturbations, saliency_rank.shape[0], saliency_rank.shape[1]),
|
||||
dtype=np.bool)
|
||||
low_bound = 0
|
||||
up_bound = low_bound + pixel_per_step
|
||||
factor = 0 if self._is_accumulate else 1
|
||||
|
||||
for i in range(num_perturbations):
|
||||
masks[i, ((saliency_rank >= low_bound)
|
||||
& (saliency_rank < up_bound))] = True
|
||||
low_bound = up_bound * factor
|
||||
up_bound += pixel_per_step
|
||||
|
||||
if len(masks.shape) == 3:
|
||||
return masks
|
||||
raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect 3-dim.')
|
||||
|
||||
def __call__(self,
|
||||
inputs: _Array,
|
||||
saliency: _Array,
|
||||
reference: _Array,
|
||||
return_mask: bool = False,
|
||||
) -> Union[_Array, Tuple[_Array, ...]]:
|
||||
"""
|
||||
Generate perturbations of given array.
|
||||
|
||||
Args:
|
||||
inputs (_Array): input array to perturb
|
||||
saliency (_Array): saliency map
|
||||
return_mask (bool): whether return the mask for generating
|
||||
the perturbation. The mask can be used to calculate
|
||||
average feature importance of pixels perturbed at each step.
|
||||
|
||||
Return:
|
||||
perturbations (_Array)
|
||||
masks (_Array): return when return_mask is set to True.
|
||||
"""
|
||||
if not np.array_equal(inputs.shape, reference.shape):
|
||||
raise ValueError('reference must have the same shape as inputs.')
|
||||
|
||||
saliency_rank = rank_pixels(saliency, descending=True)
|
||||
masks = self._generate_mask(saliency_rank)
|
||||
num_perturbations = masks.shape[0]
|
||||
|
||||
if self._perturb_mode == 'Insertion':
|
||||
inputs, reference = reference, inputs
|
||||
|
||||
perturbations = np.tile(
|
||||
inputs, (num_perturbations, *[1] * len(inputs.shape)))
|
||||
|
||||
Perturb._assign(perturbations, reference, masks)
|
||||
|
||||
if return_mask:
|
||||
return perturbations, masks
|
||||
return perturbations
|
||||
|
||||
|
||||
class _FaithfulnessHelper:
|
||||
"""Base class for faithfulness calculator."""
|
||||
_support = [Constant, GaussianBlur]
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_mode: str,
|
||||
perturb_method: str,
|
||||
is_accumulate: bool,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
|
||||
self._get_reference = None
|
||||
for method in self._support:
|
||||
if perturb_method == method.__name__:
|
||||
self._get_reference = method(**kwargs)
|
||||
if self._get_reference is None:
|
||||
raise ValueError(
|
||||
'The param "perturb_method" should be one of {}.'.format([x.__name__ for x in self._support]))
|
||||
|
||||
self._perturb = Perturb(perturb_percent=perturb_percent,
|
||||
perturb_mode=perturb_mode,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=is_accumulate)
|
||||
|
||||
calc_faithfulness: Callable
|
||||
"""
|
||||
Method used to calculate faithfulness for given inputs, target label,
|
||||
saliency. Derive class should implement this method.
|
||||
|
||||
Args:
|
||||
inputs (_Array): sample to calculate faithfulness score
|
||||
model (_Module): model to explanation
|
||||
targets (_Label): label to explanation on.
|
||||
saliency (_Array): Saliency map of given inputs and targets from the
|
||||
explainer.
|
||||
|
||||
Return:
|
||||
- faithfulness (float): faithfulness score
|
||||
"""
|
||||
|
||||
|
||||
class NaiveFaithfulness(_FaithfulnessHelper):
|
||||
"""
|
||||
Calculator for naive faithfulness.
|
||||
|
||||
Naive faithfulness, the metric replace several pixels on original image by
|
||||
specific method for each perturbations. The metric predicts on the perturbed
|
||||
images and record a series of probabilities. Then calculates the
|
||||
correlation between prob distribution and averaged feature importance.
|
||||
Higher correlation indicates better faithfulness.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_method (str): specify the method to replace the pixel.
|
||||
Current support: ['Constant', 'GaussianBlur']
|
||||
is_accumulate (bool): whether to accumulate the former perturbations to
|
||||
the later perturbations.
|
||||
Default: False.
|
||||
perturb_pixel_per_step (Optional[int]): number of pixel to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculate by:
|
||||
num_image_pixel * perturb_percent / num_perturb_steps.
|
||||
Default: None
|
||||
num_perturbations (Optional[int]): number of perturbations. If
|
||||
num_perturbations if None, it will be calculated by:
|
||||
num_image_pixel * perturb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
kwargs: specific perturb_method will require
|
||||
different arguments. Below lists required args for each method.
|
||||
|
||||
'Constant': base_value (int)
|
||||
'GaussianBlur': sigma (float): 0.7
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_method: str,
|
||||
is_accumulate: bool = False,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
super(NaiveFaithfulness, self).__init__(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_mode='Deletion',
|
||||
perturb_method=perturb_method,
|
||||
is_accumulate=is_accumulate,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
**kwargs)
|
||||
|
||||
def calc_faithfulness(self,
|
||||
inputs: _Array,
|
||||
model: _Module,
|
||||
targets: _Label,
|
||||
saliency: _Array) -> np.ndarray:
|
||||
"""
|
||||
Calculate naive faithfulness.
|
||||
|
||||
Args:
|
||||
inputs (_Array): sample to calculate faithfulness score
|
||||
model (_Module): model to explanation
|
||||
targets (_Label): label to explanation on.
|
||||
saliency (_Array): Saliency map of given inputs and targets from the
|
||||
explainer.
|
||||
|
||||
Return:
|
||||
- faithfulness (np.ndarray): faithfulness score
|
||||
|
||||
"""
|
||||
reference = self._get_reference(inputs)
|
||||
perturbations, masks = self._perturb(
|
||||
inputs, saliency, reference, return_mask=True)
|
||||
feature_importance = _calc_feature_importance(saliency, masks)
|
||||
|
||||
perturbations = ms.Tensor(perturbations, dtype=ms.float32)
|
||||
predictions = model(perturbations).asnumpy()[:, targets]
|
||||
faithfulness = calc_correlation(feature_importance, predictions)
|
||||
normalized_faithfulness = (faithfulness + 1) / 2
|
||||
return np.array([normalized_faithfulness], np.float)
|
||||
|
||||
|
||||
class DeletionAUC(_FaithfulnessHelper):
|
||||
""" Calculator for deletion AUC.
|
||||
|
||||
For Deletion AUC, the metric accumulative replace pixels on origin
|
||||
images through specific 'perturb_method', predict on the perturbed images
|
||||
and record series of probabilities. The metric then calculates the AUC of
|
||||
the probability variation curve during perturbations. Faithfulness is define
|
||||
as (1 - deletion_AUC). Higher score indicates better faithfulness of
|
||||
explanation.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_method (str): specify the method to replace the pixel.
|
||||
Current support: ['Constant', 'GaussianBlur']
|
||||
perturb_pixel_per_step (Optional[int]): number of pixel to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculate by:
|
||||
num_image_pixel * perturb_percent / num_perturb_steps.
|
||||
Default: None
|
||||
num_perturbations (Optional[int]): number of perturbations. If
|
||||
num_perturbations if None, it will be calculated by:
|
||||
num_image_pixel * perterb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
kwargs: specific perturb_method will require
|
||||
different arguments. Below lists required args for each method.
|
||||
|
||||
'Constant': base_value (int)
|
||||
'GaussianBlur': sigma (float): 0.7
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_method: str,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
super(DeletionAUC, self).__init__(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_mode='Deletion',
|
||||
perturb_method=perturb_method,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=True,
|
||||
**kwargs)
|
||||
|
||||
def calc_faithfulness(self,
|
||||
inputs: _Array,
|
||||
model: _Module,
|
||||
targets: _Label,
|
||||
saliency: _Array) -> np.ndarray:
|
||||
"""
|
||||
Calculate faithfulness through deletion AUC.
|
||||
|
||||
Args:
|
||||
inputs (_Array): sample to calculate faithfulness score
|
||||
model (_Module): model to explanation
|
||||
targets (_Label): label to explanation on.
|
||||
saliency (_Array): Saliency map of given inputs and targets from the
|
||||
explainer.
|
||||
|
||||
Return:
|
||||
- faithfulness (float): faithfulness score
|
||||
|
||||
"""
|
||||
reference = self._get_reference(inputs)
|
||||
perturbations = self._perturb(inputs, saliency, reference)
|
||||
perturbations = ms.Tensor(perturbations, dtype=ms.float32)
|
||||
predictions = model(perturbations).asnumpy()[:, targets]
|
||||
input_tensor = op.ExpandDims()(ms.Tensor(inputs, ms.float32), 0)
|
||||
original_output = model(input_tensor).asnumpy()[:, targets]
|
||||
|
||||
auc = calc_auc(original_output - predictions)
|
||||
return np.array([1 - auc])
|
||||
|
||||
|
||||
class InsertionAUC(_FaithfulnessHelper):
|
||||
""" Calculator for insertion AUC.
|
||||
|
||||
For Insertion AUC, the metric accumulative replace pixels of reference
|
||||
image by pixels from origin image, like inserting pixel from origin image to
|
||||
reference. The reference if generated through specific 'perturb_method'.
|
||||
The metric predicts on the perturbed images and records series of
|
||||
probabilities. The metric then calculates the AUC of the probability
|
||||
variation curve during perturbations. Faithfulness is define as (1 -
|
||||
deletion_AUC). Higher score indicates better faithfulness of explanation.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_method (str): specify the method to replace the pixel.
|
||||
Current support: ['Constant', 'GaussianBlur']
|
||||
perturb_pixel_per_step (Optional[int]): number of pixel to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculate by:
|
||||
num_image_pixel * perturb_percent / num_perturb_steps.
|
||||
Default: None
|
||||
num_perturbations (Optional[int]): number of perturbations. If
|
||||
num_perturbations if None, it will be calculated by:
|
||||
num_image_pixel * perterb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
kwargs: specific perturb_method will require
|
||||
different arguments. Below lists required args for each method.
|
||||
|
||||
'Constant': base_value (int)
|
||||
'GaussianBlur': sigma (float): 0.7
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_method: str,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
super(InsertionAUC, self).__init__(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_mode='Insertion',
|
||||
perturb_method=perturb_method,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=True,
|
||||
**kwargs)
|
||||
|
||||
def calc_faithfulness(self,
|
||||
inputs: _Array,
|
||||
model: _Module,
|
||||
targets: _Label,
|
||||
saliency: _Array) -> np.ndarray:
|
||||
"""
|
||||
Calculate faithfulness through insertion AUC.
|
||||
|
||||
Args:
|
||||
inputs (_Array): sample to calculate faithfulness score
|
||||
model (_Module): model to explanation
|
||||
targets (_Label): label to explanation on.
|
||||
saliency (_Array): Saliency map of given inputs and targets from the
|
||||
explainer.
|
||||
|
||||
Return:
|
||||
- faithfulness (float): faithfulness score
|
||||
|
||||
"""
|
||||
reference = self._get_reference(inputs)
|
||||
perturbations = self._perturb(inputs, saliency, reference)
|
||||
perturbations = ms.Tensor(perturbations, dtype=ms.float32)
|
||||
predictions = model(perturbations).asnumpy()[:, targets]
|
||||
base_tensor = op.ExpandDims()(ms.Tensor(reference, ms.float32), 0)
|
||||
base_outputs = model(base_tensor).asnumpy()[:, targets]
|
||||
|
||||
auc = calc_auc(predictions - base_outputs)
|
||||
return np.array([auc])
|
||||
|
||||
|
||||
class Faithfulness(AttributionMetric):
|
||||
"""
|
||||
Provides evaluation on faithfulness on XAI explanations.
|
||||
|
||||
Faithfulness first generate saliency map with given explainers and calculate faithfulness based on different
|
||||
faithfulness metric.
|
||||
|
||||
Args:
|
||||
num_labels (int): number of labels
|
||||
metric (str): the specifi metric to quantify faithfulness.
|
||||
Options: 'DeletionAUC', 'InsertionAUC', 'NaiveFaithfulness'.
|
||||
Default: 'NaiveFaithfulness'.
|
||||
|
||||
Examples:
|
||||
>>> # init a `Faithfulness` object
|
||||
>>> num_labels = 10
|
||||
>>> metric = "InsertionAUC"
|
||||
>>> faithfulness = Faithfulness(num_labels, metric)
|
||||
"""
|
||||
_methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC]
|
||||
|
||||
def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness"):
|
||||
super(Faithfulness, self).__init__(num_labels)
|
||||
|
||||
perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument
|
||||
perturb_method = "Constant" # perturbation method, all the perturbed pixels will be set to constant
|
||||
num_perturb_pixel_per_step = None # number of pixels for each perturbation step
|
||||
num_perturb_steps = 100 # separate the perturbation progress in to 100 steps.
|
||||
base_value = 0.0 # the pixel value set for the perturbed pixels
|
||||
|
||||
self._verify_metrics(metric)
|
||||
for method in self._methods:
|
||||
if metric == method.__name__:
|
||||
self._faithfulness_helper = method(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_method=perturb_method,
|
||||
perturb_pixel_per_step=num_perturb_pixel_per_step,
|
||||
num_perturbations=num_perturb_steps,
|
||||
base_value=base_value
|
||||
)
|
||||
|
||||
def evaluate(self, explainer, inputs, targets, saliency=None):
|
||||
"""
|
||||
Evaluate faithfulness on a single data sample.
|
||||
|
||||
Args:
|
||||
explainer (Explainer): A explainer instance object.
|
||||
The 'Explainer' object see mindspore/explainer/explanation.
|
||||
inputs (Tensor): data sample. Currently only support single sample at each call.
|
||||
targets (Union[int, Tensor]): A target label to evaluate on.
|
||||
saliency (Tensor): A saliency tensor.
|
||||
|
||||
Return:
|
||||
np.ndarray: result of faithfulness evaluated on explainer.
|
||||
|
||||
Notes:
|
||||
To apply `Faithfulness` to evaluate an explainer, this explainer must be initialize with a network that
|
||||
contains the output activation function. Otherwise, the results will not be correct.
|
||||
|
||||
Examples:
|
||||
>>> # init an explainer, the network should contain the output activation function.
|
||||
>>> network = nn.SequentialCell([resnet50, nn.Sigmoid()])
|
||||
>>> gradient = Gradient(network)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> targets = 5
|
||||
>>> # usage 1: input the explainer and the data to be explained,
|
||||
>>> # calculate the faithfulness with the specified metric
|
||||
>>> res = faithfulness.evaluate(gradient, inputs, targets)
|
||||
>>> # usage 2: input the generated saliency map
|
||||
>>> saliency = gradient(inputs, targets)
|
||||
>>> res = faithfulenss.evaluate(gradient, inputs, targets, saliency)
|
||||
"""
|
||||
|
||||
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||
|
||||
if saliency is None:
|
||||
saliency = explainer(inputs, targets)
|
||||
|
||||
inputs = format_tensor_to_ndarray(inputs)
|
||||
saliency = format_tensor_to_ndarray(saliency)
|
||||
|
||||
inputs = inputs.squeeze(axis=0)
|
||||
saliency = saliency.squeeze()
|
||||
if len(saliency.shape) != 2:
|
||||
raise ValueError('Squeezed saliency map is expected to 2D, but receive {}.'.format(len(saliency.shape)))
|
||||
|
||||
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=explainer.model,
|
||||
targets=targets, saliency=saliency)
|
||||
return faithfulness
|
||||
|
||||
def _verify_metrics(self, metric: str):
|
||||
supports = [x.__name__ for x in self._methods]
|
||||
if metric not in supports:
|
||||
raise ValueError("Metric should be one of {}.".format(supports))
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Localization metrics."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.train._utils import check_value_type
|
||||
from .metric import AttributionMetric
|
||||
from ..._operators import maximum, reshape, Tensor
|
||||
from ..._utils import format_tensor_to_ndarray
|
||||
|
||||
|
||||
def _get_max_position(saliency):
|
||||
"""Get the position of the max pixel of the saliency map."""
|
||||
saliency = saliency.asnumpy()
|
||||
w = saliency.shape[3]
|
||||
saliency = np.reshape(saliency, (len(saliency), -1))
|
||||
max_arg = np.argmax(saliency, axis=1)
|
||||
return max_arg // w, max_arg - (max_arg // w) * w
|
||||
|
||||
|
||||
def _mask_out_saliency(saliency, threshold):
|
||||
"""Keep the saliency map with value greater than threshold."""
|
||||
max_value = maximum(saliency)
|
||||
mask_out = saliency > (reshape(max_value, (len(saliency), -1, 1, 1)) * threshold)
|
||||
return mask_out
|
||||
|
||||
|
||||
class Localization(AttributionMetric):
|
||||
"""
|
||||
Provides evaluation on the localization capability of XAI methods.
|
||||
|
||||
We support two metrics for the evaluation os localization capability: "PointingGame" and "IoSR".
|
||||
For metric "PointingGame", the localization capability is calculated as the ratio of data in which the max position
|
||||
of their saliency maps lies within the bounding boxes. Specifically, for a single datum, given the saliency map and
|
||||
its bounding box, if the max point of its saliency map lies within the bounding box, the evaluation result is 1
|
||||
otherwise 0.
|
||||
|
||||
For metric "IoSR" (Intersection over Salient Region), the localization capability is calculated as the intersection
|
||||
of the bounding box and the salient region over the area of the salient region.
|
||||
|
||||
Args:
|
||||
num_labels (int): number of classes in the dataset.
|
||||
metric (str): specific metric to calculate localization capability.
|
||||
Options: "PointingGame", "IoSR".
|
||||
Default: "PointingGame".
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer.benchmark import Localization
|
||||
>>> num_labels = 100
|
||||
>>> localization = Localization(num_labels, "PointingGame")
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_labels,
|
||||
metric="PointingGame"
|
||||
):
|
||||
super(Localization, self).__init__(num_labels)
|
||||
self._verify_metrics(metric)
|
||||
self._metric = metric
|
||||
|
||||
# Arg for specific metric, for "PointingGame" it should be an integer indicating the tolerance
|
||||
# of "PointingGame", while for "IoSR" it should be a float number
|
||||
# indicating the threshold to choose salient region. Default: 25.
|
||||
if self._metric == "PointingGame":
|
||||
self._metric_arg = 15
|
||||
else:
|
||||
self._metric_arg = 0.5
|
||||
|
||||
@staticmethod
|
||||
def _verify_metrics(metric):
|
||||
"""Verify the user defined metric."""
|
||||
supports = ["PointingGame", "IoSR"]
|
||||
if metric not in supports:
|
||||
raise ValueError("Metric should be one of {}".format(supports))
|
||||
|
||||
def evaluate(self, explainer, inputs, targets, saliency=None, mask=None):
|
||||
"""
|
||||
Evaluate localization on a single data sample.
|
||||
|
||||
Args:
|
||||
explainer (Explanation): The explainer to be evaluated, see `mindspore/explainer/explanation`.
|
||||
inputs (Tensor): data sample. Currently only support single sample at each call.
|
||||
targets (int): target label to evaluate on.
|
||||
saliency (Tensor): A saliency tensor.
|
||||
mask (Union[Tensor, np.ndarray]): ground truth bounding box/masks for the inputs w.r.t targets.
|
||||
|
||||
Returns:
|
||||
np.ndarray, result of localization evaluated on explainer
|
||||
|
||||
Examples:
|
||||
>>> # init an explainer, the network should contain the output activation function.
|
||||
>>> gradient = Gradient(network)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> masks = np.zeros(1, 1, 224, 224)
|
||||
>>> masks[:, :, 65: 100, 65: 100] = 1
|
||||
>>> targets = 5
|
||||
>>> # usage 1: input the explainer and the data to be explained,
|
||||
>>> # calculate the faithfulness with the specified metric
|
||||
>>> res = localization.evaluate(gradient, inputs, targets, mask=masks)
|
||||
>>> # usage 2: input the generated saliency map
|
||||
>>> saliency = gradient(inputs, targets)
|
||||
>>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks)
|
||||
"""
|
||||
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||
|
||||
mask_np = format_tensor_to_ndarray(mask)[0]
|
||||
|
||||
if saliency is None:
|
||||
saliency = explainer(inputs, targets)
|
||||
|
||||
if self._metric == "PointingGame":
|
||||
point = _get_max_position(saliency)
|
||||
|
||||
x, y = np.meshgrid(
|
||||
(np.arange(mask_np.shape[1]) - point[0]) ** 2,
|
||||
(np.arange(mask_np.shape[2]) - point[1]) ** 2)
|
||||
max_region = (x + y) < self._metric_arg ** 2
|
||||
|
||||
# if max_region has overlap with mask_np return 1 otherwise 0.
|
||||
result = 1 if (mask_np.astype(bool) & max_region).any() else 0
|
||||
|
||||
elif self._metric == "IoSR":
|
||||
mask_out = _mask_out_saliency(saliency, self._metric_arg)
|
||||
mask_out_np = format_tensor_to_ndarray(mask_out)
|
||||
overlap = np.sum(mask_np.astype(bool) & mask_out_np.astype(bool))
|
||||
saliency_area = np.sum(mask_out_np)
|
||||
result = overlap / saliency_area.clip(min=1e-10)
|
||||
return np.array([result], np.float)
|
||||
|
||||
def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask):
|
||||
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||
check_value_type('mask', mask, (Tensor, np.ndarray))
|
||||
if len(inputs.shape) != 4:
|
||||
raise ValueError('Argument mask must be 4D Tensor')
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Base class for XAI metrics."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.train._utils import check_value_type
|
||||
from ..._operators import Tensor
|
||||
from ..._utils import format_tensor_to_ndarray
|
||||
from ...explanation._attribution._attribution import Attribution
|
||||
|
||||
|
||||
def verify_argument(inputs, arg_name):
|
||||
"""Verify the validity of the parsed arguments."""
|
||||
check_value_type(arg_name, inputs, Tensor)
|
||||
if len(inputs.shape) != 4:
|
||||
raise ValueError('Argument {} must be a 4D Tensor.'.format(arg_name))
|
||||
if len(inputs) > 1:
|
||||
raise ValueError('Support single data evaluation only, but got {}.'.format(len(inputs)))
|
||||
|
||||
|
||||
def verify_targets(targets, num_labels):
|
||||
"""Verify the validity of the parsed targets."""
|
||||
check_value_type('targets', targets, (int, Tensor))
|
||||
|
||||
if isinstance(targets, Tensor):
|
||||
if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != 1):
|
||||
raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
|
||||
'it should have the length = 1 as we only support single evaluation now.')
|
||||
targets = int(targets.asnumpy()[0]) if len(targets.shape) == 1 else int(targets.asnumpy())
|
||||
if targets > num_labels - 1 or targets < 0:
|
||||
raise ValueError('Parsed targets exceed the label range.')
|
||||
|
||||
|
||||
class AttributionMetric:
|
||||
"""Super class of XAI metric class used in classification scenarios."""
|
||||
|
||||
def __init__(self, num_labels=None):
|
||||
self._num_labels = num_labels
|
||||
self._global_results = {i: [] for i in range(num_labels)}
|
||||
|
||||
def evaluate(self, explainer, inputs, targets, saliency=None):
|
||||
"""This function evaluates on a single sample and return the result."""
|
||||
raise NotImplementedError
|
||||
|
||||
def aggregate(self, result, targets):
|
||||
"""Aggregates single result to global_results."""
|
||||
if isinstance(result, float):
|
||||
if isinstance(targets, int):
|
||||
self._global_results[targets].append(result)
|
||||
else:
|
||||
target_np = format_tensor_to_ndarray(targets)
|
||||
if len(target_np) > 1:
|
||||
raise ValueError("One result can not be aggreated to multiple targets.")
|
||||
else:
|
||||
result_np = format_tensor_to_ndarray(result)
|
||||
if isinstance(targets, int):
|
||||
for res in result_np:
|
||||
self._global_results[targets].append(float(res))
|
||||
else:
|
||||
target_np = format_tensor_to_ndarray(targets)
|
||||
if len(target_np) != len(result_np):
|
||||
raise ValueError("Length of result does not match with length of targets.")
|
||||
for tar, res in zip(target_np, result_np):
|
||||
self._global_results[int(tar)].append(float(res))
|
||||
|
||||
def reset(self):
|
||||
"""Resets global_result."""
|
||||
self._global_results = {i: [] for i in range(self._num_labels)}
|
||||
|
||||
@property
|
||||
def class_performances(self):
|
||||
"""
|
||||
Get the class performances by global result.
|
||||
|
||||
|
||||
Returns:
|
||||
(:class:`np.ndarray`): :attr:`num_labels`-dimensional vector
|
||||
containing per-class performance.
|
||||
"""
|
||||
count = np.array(
|
||||
[len(self._global_results[i]) for i in range(self._num_labels)])
|
||||
result_sum = np.array(
|
||||
[sum(self._global_results[i]) for i in range(self._num_labels)])
|
||||
return result_sum / count.clip(min=1)
|
||||
|
||||
@property
|
||||
def performance(self):
|
||||
"""
|
||||
Get the performance by global result.
|
||||
|
||||
Returns:
|
||||
(:class:`float`): mean performance.
|
||||
"""
|
||||
count = sum(
|
||||
[len(self._global_results[i]) for i in range(self._num_labels)])
|
||||
result_sum = sum(
|
||||
[sum(self._global_results[i]) for i in range(self._num_labels)])
|
||||
if count == 0:
|
||||
return 0
|
||||
return result_sum / count
|
||||
|
||||
def get_results(self):
|
||||
"""Global result of the metric can be return"""
|
||||
return self._global_results
|
||||
|
||||
def _check_evaluate_param(self, explainer, inputs, targets, saliency):
|
||||
"""Check the evaluate parameters."""
|
||||
check_value_type('explainer', explainer, Attribution)
|
||||
verify_argument(inputs, 'inputs')
|
||||
verify_targets(targets, self._num_labels)
|
||||
check_value_type('saliency', saliency, (Tensor, type(None)))
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Predefined Attribution explainers."""
|
||||
|
||||
from ._attribution._backprop.gradcam import GradCAM
|
||||
from ._attribution._backprop.gradient import Gradient
|
||||
from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop
|
||||
|
||||
__all__ = [
|
||||
'Gradient',
|
||||
'Deconvolution',
|
||||
'GuidedBackprop',
|
||||
'GradCAM',
|
||||
]
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Predefined Attribution explainers."""
|
||||
from ._backprop.gradcam import GradCAM
|
||||
from ._backprop.gradient import Gradient
|
||||
from ._backprop.modified_relu import Deconvolution, GuidedBackprop
|
||||
|
||||
__all__ = [
|
||||
'Gradient',
|
||||
'Deconvolution',
|
||||
'GuidedBackprop',
|
||||
'GradCAM',
|
||||
]
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Attribution."""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import mindspore as ms
|
||||
|
||||
class Attribution:
|
||||
r"""
|
||||
Basic class of attributing the salient score
|
||||
|
||||
The explainers which explanation through attributing the relevance scores
|
||||
should inherit this class.
|
||||
|
||||
Args:
|
||||
network (ms.nn.Cell): The black-box model to explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
self._verify_model(network)
|
||||
self._model = network
|
||||
|
||||
@staticmethod
|
||||
def _verify_model(model):
|
||||
"""
|
||||
Verify the input `network` for __init__ function.
|
||||
"""
|
||||
if not isinstance(model, ms.nn.Cell):
|
||||
raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.")
|
||||
|
||||
|
||||
__call__: Callable
|
||||
"""
|
||||
The explainers return the explanations by calling directly on the explanation.
|
||||
Derived class should overwrite this implementations for different
|
||||
algorithms.
|
||||
|
||||
Args:
|
||||
input (ms.Tensor): Input tensor to be explained.
|
||||
|
||||
Returns:
|
||||
- saliency map (ms.Tensor): saliency map of the input.
|
||||
"""
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._model
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Backprop-base _attribution explainer."""
|
||||
|
||||
from .gradient import Gradient
|
||||
from .gradcam import GradCAM
|
||||
from .modified_relu import Deconvolution, GuidedBackprop
|
||||
|
||||
__all__ = ['Gradient',
|
||||
'GradCAM',
|
||||
'Deconvolution',
|
||||
'GuidedBackprop']
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing utility functions."""
|
||||
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
from ...._utils import unify_inputs, unify_targets, generate_one_hot
|
||||
|
||||
|
||||
def compute_gradients(model, inputs, targets=None, weights=None):
|
||||
r"""
|
||||
Compute the gradient of output w.r.t input.
|
||||
|
||||
Args:
|
||||
model (`ms.nn.Cell`): Differentiable black-box model.
|
||||
inputs (`ms.Tensor`): Input to calculate gradient and explanation.
|
||||
targets (int, optional): Target label id specifying which category to compute gradient. Default: None.
|
||||
weights (`ms.Tensor`, optional): Custom weights for computing gradients. The shape of weights should match the
|
||||
model outputs. If None is provided, an one-hot weights with one in targets positions will be used instead.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
saliency map (ms.Tensor): Gradient back-propagated to the input.
|
||||
"""
|
||||
inputs = unify_inputs(inputs)
|
||||
if targets is None and weights is None:
|
||||
raise ValueError('Must provide one of targets or weights')
|
||||
if weights is None:
|
||||
targets = unify_targets(targets)
|
||||
output = model(*inputs).asnumpy()
|
||||
num_categories = output.shape[-1]
|
||||
weights = generate_one_hot(targets, num_categories)
|
||||
|
||||
grad_op = GradOperation(
|
||||
get_all=True, get_by_list=False, sens_param=True)(model)
|
||||
gradients = grad_op(*inputs, weights)
|
||||
return gradients[0]
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" GradCAM and GuidedGradCAM. """
|
||||
|
||||
from mindspore.ops import operations as op
|
||||
|
||||
from .backprop_utils import compute_gradients
|
||||
from .intermediate_layer import IntermediateLayerAttribution
|
||||
from ...._utils import ForwardProbe, retrieve_layer, unify_inputs, unify_targets
|
||||
|
||||
|
||||
|
||||
def _gradcam_aggregation(attributions):
|
||||
"""
|
||||
Aggregate the gradient and activation to get the final _attribution.
|
||||
|
||||
Args:
|
||||
attributions (Tensor): the _attribution with channel dimension.
|
||||
|
||||
Returns:
|
||||
Tensor: the _attribution with channel dimension aggregated.
|
||||
"""
|
||||
sum_ = op.ReduceSum(keep_dims=True)
|
||||
relu_ = op.ReLU()
|
||||
attributions = relu_(sum_(attributions, 1))
|
||||
return attributions
|
||||
|
||||
|
||||
class GradCAM(IntermediateLayerAttribution):
|
||||
r"""
|
||||
Provides GradCAM explanation method.
|
||||
|
||||
GradCAM generates saliency map at intermediate layer.
|
||||
..math:
|
||||
\alpha_k^c = 1/Z \sum_i \sum_j \div{\partial{y^c}}{\partial{A_{i,j}^k}}
|
||||
L_{GradCAM} = ReLu(\sum_k \alpha_k^c A^k)
|
||||
For more details, please refer to the original paper: GradCAM
|
||||
[https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf]
|
||||
|
||||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
layer (str): The layer name to generate the explanation at. Default: ''.
|
||||
If default, the explantion will be generated at the input layer.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||
>>> # you may also use the net itself.
|
||||
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> # specify a layer name to generate explanation, usually the layer can be set as the last conv layer.
|
||||
>>> layer_name = '0.layer4'
|
||||
>>> # init GradCAM with a trained network and specify the layer to obtain
|
||||
>>> gradcam = GradCAM(net, layer=layer_name)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gradcam(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
network,
|
||||
layer=""):
|
||||
super(GradCAM, self).__init__(network, layer)
|
||||
|
||||
self._saliency_cell = retrieve_layer(self._backward_model, target_layer=layer)
|
||||
self._avgpool = op.ReduceMean(keep_dims=True)
|
||||
self._intermediate_grad = None
|
||||
self._aggregation_fn = _gradcam_aggregation
|
||||
self._resize_mode = 'bilinear'
|
||||
|
||||
def _hook_cell(self):
|
||||
if self._saliency_cell:
|
||||
self._saliency_cell.register_backward_hook(self._cell_hook_fn)
|
||||
self._saliency_cell.enable_hook = True
|
||||
self._intermediate_grad = None
|
||||
|
||||
def _cell_hook_fn(self, _, grad_input, grad_output):
|
||||
"""
|
||||
Hook function to deal with the backward gradient.
|
||||
|
||||
The arguments are set as required by Cell.register_back_hook
|
||||
"""
|
||||
self._intermediate_grad = grad_input
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
"""
|
||||
Call function for `GradCAM`.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input data to be explained, 4D Tensor.
|
||||
targets (Union[Tensor, int]): The label of interest. It should be a 1D or 0D Tensor, or an integer.
|
||||
If `targets` is a 1D Tensor, its length should be the same as `inputs`.
|
||||
"""
|
||||
self._verify_data(inputs, targets)
|
||||
self._hook_cell()
|
||||
|
||||
with ForwardProbe(self._saliency_cell) as probe:
|
||||
|
||||
inputs = unify_inputs(inputs)
|
||||
targets = unify_targets(targets)
|
||||
|
||||
gradients = compute_gradients(self._backward_model, *inputs, targets)
|
||||
|
||||
# get intermediate activation
|
||||
activation = (probe.value,)
|
||||
|
||||
if self._layer == "":
|
||||
activation = inputs
|
||||
self._intermediate_grad = unify_inputs(gradients)
|
||||
if self._intermediate_grad is not None:
|
||||
# average pooling on gradients
|
||||
intermediate_grad = unify_inputs(
|
||||
self._avgpool(self._intermediate_grad[0], (2, 3)))
|
||||
else:
|
||||
raise ValueError("Gradient for intermediate layer is not "
|
||||
"obtained")
|
||||
mul = op.Mul()
|
||||
attribution = self._aggregation_fn(
|
||||
mul(*intermediate_grad, *activation))
|
||||
if self._resize:
|
||||
attribution = self._resize_fn(attribution, *inputs,
|
||||
mode=self._resize_mode)
|
||||
self._intermediate_grad = None
|
||||
|
||||
return attribution
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Gradient explainer."""
|
||||
from copy import deepcopy
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as op
|
||||
from mindspore.train._utils import check_value_type
|
||||
from ...._operators import reshape, sqrt, Tensor
|
||||
from .._attribution import Attribution
|
||||
from .backprop_utils import compute_gradients
|
||||
from ...._utils import unify_inputs, unify_targets
|
||||
|
||||
|
||||
def _get_hook(bntype, cache):
|
||||
"""Provide backward hook function for BatchNorm layer in eval mode."""
|
||||
var, gamma, eps = cache
|
||||
if bntype == "2d":
|
||||
var = reshape(var, (1, -1, 1, 1))
|
||||
gamma = reshape(gamma, (1, -1, 1, 1))
|
||||
elif bntype == "1d":
|
||||
var = reshape(var, (1, -1, 1))
|
||||
gamma = reshape(gamma, (1, -1, 1))
|
||||
|
||||
def reset_gradient(_, grad_input, grad_output):
|
||||
grad_output = grad_input[0] * gamma / sqrt(var + eps)
|
||||
return grad_output
|
||||
|
||||
return reset_gradient
|
||||
|
||||
|
||||
def _abs_max(gradients):
|
||||
"""
|
||||
Transform gradients to saliency through abs then take max along
|
||||
channels.
|
||||
"""
|
||||
gradients = op.Abs()(gradients)
|
||||
saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1)
|
||||
return saliency
|
||||
|
||||
|
||||
class Gradient(Attribution):
|
||||
r"""
|
||||
Provides Gradient explanation method.
|
||||
|
||||
Gradient is the simplest attribution method which uses the naive gradients of outputs w.r.t inputs as the
|
||||
explanation.
|
||||
|
||||
.. math::
|
||||
_attribution = \div{\delta{y}, \delta{x}}
|
||||
|
||||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||
>>> # you may also use the net itself. The saliency map might be slightly different for softmax activation.
|
||||
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> # init Gradient with a trained network.
|
||||
>>> gradient = Gradient(net)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gradient(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(Gradient, self).__init__(network)
|
||||
self._backward_model = deepcopy(network)
|
||||
self._backward_model.set_train(False)
|
||||
self._backward_model.set_grad(False)
|
||||
self._hook_bn()
|
||||
self._grad_op = compute_gradients
|
||||
self._aggregation_fn = _abs_max
|
||||
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
"""
|
||||
Call function for `Gradient`.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input data to be explained, 4D Tensor.
|
||||
targets (Union[Tensor, int]): The label of interest. It should be a 1D or 0D Tensor, or an integer.
|
||||
If `targets` is a 1D `Tensor`, its length should be the same as `inputs`.
|
||||
"""
|
||||
self._verify_data(inputs, targets)
|
||||
inputs = unify_inputs(inputs)
|
||||
targets = unify_targets(targets)
|
||||
|
||||
gradient = self._grad_op(self._backward_model, *inputs, targets)
|
||||
saliency = self._aggregation_fn(gradient)
|
||||
return saliency
|
||||
|
||||
def _hook_bn(self):
|
||||
"""Hook BatchNorm layer for `self._backward_model.`"""
|
||||
for _, cell in self._backward_model.cells_and_names():
|
||||
if isinstance(cell, nn.BatchNorm2d):
|
||||
cache = (cell.moving_variance, cell.gamma, cell.eps)
|
||||
cell.register_backward_hook(_get_hook("2d", cache=cache))
|
||||
elif isinstance(cell, nn.BatchNorm1d):
|
||||
cache = (cell.moving_variance, cell.gamma, cell.eps)
|
||||
cell.register_backward_hook(_get_hook("1d", cache=cache))
|
||||
|
||||
@staticmethod
|
||||
def _verify_data(inputs, targets):
|
||||
"""Verify the validity of the parsed inputs."""
|
||||
check_value_type('inputs', inputs, Tensor)
|
||||
if len(inputs.shape) != 4:
|
||||
raise ValueError('Argument inputs must be 4D Tensor')
|
||||
check_value_type('targets', targets, (Tensor, int))
|
||||
if isinstance(targets, Tensor):
|
||||
if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)):
|
||||
raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
|
||||
'it should have the same length as inputs.')
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Base class IntermediateLayerAttribution"""
|
||||
|
||||
from .gradient import Gradient
|
||||
from ...._utils import resize as resize_fn
|
||||
|
||||
|
||||
class IntermediateLayerAttribution(Gradient):
|
||||
"""
|
||||
Base class for generating _attribution map at intermediate layer.
|
||||
|
||||
Args:
|
||||
network (nn.Cell): DNN model to be explained.
|
||||
layer (str, optional): string that specifies the layer to generate
|
||||
intermediate _attribution. When using default value, the input layer
|
||||
will be specified. Default: ''.
|
||||
"""
|
||||
|
||||
def __init__(self, network, layer=''):
|
||||
super(IntermediateLayerAttribution, self).__init__(network)
|
||||
|
||||
# Whether resize the _attribution layer to the input size.
|
||||
self._resize = True
|
||||
# string that specifies the resize mode. Default: 'nearest_neighbor'.
|
||||
self._resize_mode = 'nearest_neighbor'
|
||||
|
||||
self._layer = layer
|
||||
|
||||
@staticmethod
|
||||
def _resize_fn(attributions, inputs, mode):
|
||||
"""Resize the intermediate layer _attribution to the same size as inputs."""
|
||||
height, width = inputs.shape[2], inputs.shape[3]
|
||||
return resize_fn(attributions, (height, width), mode)
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Explainer with modified ReLU."""
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as op
|
||||
|
||||
from .gradient import Gradient
|
||||
from ...._utils import (
|
||||
unify_inputs,
|
||||
unify_targets,
|
||||
)
|
||||
|
||||
|
||||
class ModifiedReLU(Gradient):
|
||||
"""Basic class for modified ReLU explanation."""
|
||||
|
||||
def __init__(self, network, use_relu_backprop=False):
|
||||
super(ModifiedReLU, self).__init__(network)
|
||||
self.use_relu_backprop = use_relu_backprop
|
||||
self.hooked_list = []
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
self._verify_data(inputs, targets)
|
||||
inputs = unify_inputs(inputs)
|
||||
targets = unify_targets(targets)
|
||||
|
||||
self._hook_relu_backward()
|
||||
gradients = self._grad_op(self._backward_model, inputs, targets)
|
||||
saliency = self._aggregation_fn(gradients)
|
||||
|
||||
return saliency
|
||||
|
||||
def _hook_relu_backward(self):
|
||||
"""Set backward hook for ReLU layers."""
|
||||
for _, cell in self._backward_model.cells_and_names():
|
||||
if isinstance(cell, nn.ReLU):
|
||||
cell.register_backward_hook(self._backward_hook)
|
||||
self.hooked_list.append(cell)
|
||||
|
||||
def _backward_hook(self, _, grad_inputs, grad_outputs):
|
||||
"""Hook function for ReLU layers."""
|
||||
inputs = grad_inputs if self.use_relu_backprop else grad_outputs
|
||||
relu = op.ReLU()
|
||||
if isinstance(inputs, tuple):
|
||||
return relu(*inputs)
|
||||
return relu(inputs)
|
||||
|
||||
|
||||
class Deconvolution(ModifiedReLU):
|
||||
"""
|
||||
Deconvolution explanation.
|
||||
|
||||
To use `Deconvolution`, the `ReLU` operations in the network must be implemented with `mindspore.nn.Cell` object
|
||||
rather than `mindspore.ops.Operations.ReLU`. Otherwise, the results will not be correct.
|
||||
|
||||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||
>>> # you may also use the net itself. The saliency map might be slightly different for softmax activation.
|
||||
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> # init Gradient with a trained network.
|
||||
>>> deconvolution = Deconvolution(net)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = deconvolution(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(Deconvolution, self).__init__(network, use_relu_backprop=True)
|
||||
|
||||
|
||||
class GuidedBackprop(ModifiedReLU):
|
||||
"""
|
||||
Guided-Backpropation explanation.
|
||||
|
||||
To use `GuidedBackprop`, the `ReLU` operations in the network must be implemented with `mindspore.nn.Cell` object
|
||||
rather than `mindspore.ops.Operations.ReLU`. Otherwise, the results will not be correct.
|
||||
|
||||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||
>>> # you may also use the net itself. The saliency map might be slightly different for softmax activation.
|
||||
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> # init Gradient with a trained network.
|
||||
>>> gbp = GuidedBackprop(net)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand([1, 3, 224, 224]), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gbp(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(GuidedBackprop, self).__init__(network, use_relu_backprop=False)
|
Loading…
Reference in New Issue