forked from mindspore-Ecosystem/mindspore
!7656 Add explainable AI methods as a submodule of MindSpore.
Merge pull request !7656 from lixiaohui33/feature_explain_core
Commit e8b4fbbb0e
@ -118,19 +118,11 @@ message Explain {
}
|
||||
|
||||
message Benchmark{
|
||||
message TotalScore{
|
||||
optional string benchmark_method = 1;
|
||||
optional float score = 2;
|
||||
}
|
||||
message LabelScore{
|
||||
repeated float score = 1;
|
||||
optional string benchmark_method = 2;
|
||||
}
|
||||
|
||||
optional string explain_method = 1;
|
||||
repeated TotalScore total_score = 2;
|
||||
repeated LabelScore label_score = 3;
|
||||
}
|
||||
optional string benchmark_method = 1;
|
||||
optional string explain_method = 2;
|
||||
optional float total_score = 3;
|
||||
repeated float label_score = 4;
|
||||
}
|
||||
|
||||
message Metadata{
|
||||
repeated string label = 1;
|
||||
@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Provide ExplainRunner High-level API."""
|
||||
|
||||
from ._runner import ExplainRunner
|
||||
|
||||
__all__ = ['ExplainRunner']
@ -0,0 +1,261 @@
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Packaged operations based on MindSpore."""
|
||||
from typing import List, Tuple, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import nn
|
||||
import mindspore.ops.operations as op
|
||||
|
||||
|
||||
_Axis = Union[int, Tuple[int, ...], List[int]]
|
||||
_Idx = Union[int, mindspore.Tensor, Tuple[int, ...], Tuple[mindspore.Tensor, ...]]
|
||||
_Number = Union[int, float, np.int, np.float]
|
||||
_Shape = Union[int, Tuple[int, ...]]
|
||||
Tensor = mindspore.Tensor
|
||||
|
||||
__all__ = [
|
||||
'absolute',
|
||||
'arange',
|
||||
'argmax',
|
||||
'argmin',
|
||||
'argsort',
|
||||
'assign',
|
||||
'intersection',
|
||||
'matmul',
|
||||
'maximum',
|
||||
'minimum',
|
||||
'mean',
|
||||
'mul',
|
||||
'sort',
|
||||
'squeeze',
|
||||
'tile',
|
||||
'reshape',
|
||||
'zeros',
|
||||
'zeros_like',
|
||||
'softmax',
|
||||
'Tensor',
|
||||
'summation'
|
||||
]
|
||||
|
||||
|
||||
def absolute(inputs: Tensor) -> Tensor:
|
||||
"""Get the absolute value of a tensor value."""
|
||||
abs_op = op.Abs()
|
||||
outputs = abs_op(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
def arange(
|
||||
start: _Number,
|
||||
end: _Number,
|
||||
step: _Number = 1,
|
||||
dtype: mindspore.dtype = None) -> Tensor:
|
||||
"""Get the arange value of tensor."""
|
||||
nums = np.arange(start=start, stop=end, step=step, dtype=np.int32)
|
||||
nums = mindspore.Tensor(nums, dtype=dtype)
|
||||
return nums
|
||||
|
||||
|
||||
def argmax(inputs: Tensor, axis: int = -1, keep_dims: bool = False) -> Tensor:
|
||||
"""Returns the indices of the maximum values along an axis."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
outputs = np.argmax(inputs_np, axis=axis)
|
||||
|
||||
if keep_dims:
|
||||
outputs = np.expand_dims(outputs, axis=axis)
|
||||
|
||||
return mindspore.Tensor(outputs, mindspore.int32)
|
||||
|
||||
|
||||
def argmin(inputs: Tensor, axis: int = -1, keep_dims: bool = False) -> Tensor:
|
||||
"""Returns the indices of the minimum values along an axis."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
outputs = np.argmin(inputs_np, axis=axis)
|
||||
|
||||
if keep_dims:
|
||||
outputs = np.expand_dims(outputs, axis=axis)
|
||||
|
||||
return mindspore.Tensor(outputs, mindspore.int32)
|
||||
|
||||
|
||||
def argsort(inputs: Tensor, axis: int = -1, descending: bool = False) -> Tensor:
|
||||
"""Returns the indices that would sort an array."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
factor = -1 if descending else 1
|
||||
indices_np = np.argsort(factor * inputs_np, axis=axis)
|
||||
indices = mindspore.Tensor(indices_np, dtype=mindspore.int32)
|
||||
return indices
|
||||
|
||||
|
||||
def assign(inputs: Tensor, idx: _Idx, value: Tensor) -> Tensor:
|
||||
"""Assign a tensor value to the given tensor and index."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
if isinstance(idx, Tensor):
|
||||
idx = idx.asnumpy()
|
||||
value_np = value.asnumpy()
|
||||
inputs_np[idx] = value_np
|
||||
outputs = mindspore.Tensor(inputs_np)
|
||||
return outputs
|
||||
|
||||
|
||||
def intersection(*inputs: Tensor) -> Tensor:
|
||||
"""Get the intersection value by the given tensor list."""
|
||||
outputs_np = np.ones_like(inputs[0])
|
||||
for inp in inputs:
|
||||
outputs_np &= inp.asnumpy()
|
||||
outputs = mindspore.Tensor(outputs_np)
|
||||
return outputs
|
||||
|
||||
|
||||
def matmul(inputs_x: Tensor, inputs_y: Tensor) -> Tensor:
|
||||
"""Multiplies matrix `inputs_x` and matrix `inputs_y`."""
|
||||
matmul_op = op.MatMul()
|
||||
outputs = matmul_op(inputs_x, inputs_y)
|
||||
return outputs
|
||||
|
||||
|
||||
def maximum(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
|
||||
"""Reduce a dimension of a tensor by the maximum value in this dimension."""
|
||||
max_op = op.ReduceMax(keep_dims)
|
||||
outputs = max_op(inputs, axis)
|
||||
return outputs
|
||||
|
||||
|
||||
def minimum(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
|
||||
"""Reduce a dimension of a tensor by the minimum value in the dimension."""
|
||||
max_op = op.ReduceMin(keep_dims)
|
||||
outputs = max_op(inputs, axis)
|
||||
return outputs
|
||||
|
||||
|
||||
def mean(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
|
||||
"""Reduce a dimension of a tensor by averaging all elements in the dimension."""
|
||||
mean_op = op.ReduceMean(keep_dims)
|
||||
outputs = mean_op(inputs, axis)
|
||||
return outputs
|
||||
|
||||
|
||||
def mul(inputs_x: Tensor, inputs_y: Tensor) -> Tensor:
|
||||
"""
|
||||
Multiplies two tensors element-wise.
|
||||
|
||||
Inputs of `input_x` and `input_y` comply with the implicit type conversion rules to make the data types consistent.
|
||||
The inputs must be two tensors or one tensor and one scalar.
|
||||
When the inputs are two tensors,
|
||||
dtypes of them cannot be both bool, and the shapes of them could be broadcast.
|
||||
When the inputs are one tensor and one scalar,
|
||||
the scalar could only be a constant.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
|
||||
a bool or a tensor whose data type is number or bool.
|
||||
- **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
|
||||
a bool when the first input is a tensor or a tensor whose data type is number or bool.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the one after broadcasting,
|
||||
and the data type is the one with higher precision or higher digits among the two inputs.
|
||||
"""
|
||||
mul_op = op.Mul()
|
||||
outputs = mul_op(inputs_x, inputs_y)
|
||||
return outputs
|
||||
|
||||
|
||||
def sort(inputs: Tensor, axis: _Axis = -1, descending: bool = False) -> Tensor:
|
||||
"""Return a sorted copy of an array."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
outputs_np = np.sort(inputs_np, axis=axis)
|
||||
if descending:
|
||||
outputs_np = np.flip(outputs_np, axis=axis)
|
||||
outputs = mindspore.Tensor(outputs_np)
|
||||
return outputs
|
||||
|
||||
|
||||
def squeeze(inputs: Tensor, axis: _Axis = ()):
|
||||
"""Returns a tensor with the same type but dimensions of 1 are removed based on `axis`."""
|
||||
squeeze_op = op.Squeeze(axis)
|
||||
outputs = squeeze_op(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
def tile(inputs: Tensor, shape: Tuple[int, ...]) -> Tensor:
|
||||
"""Replicates a tensor with given multiples times."""
|
||||
tile_op = op.Tile()
|
||||
outputs = tile_op(inputs, shape)
|
||||
return outputs
|
||||
|
||||
|
||||
def reshape(inputs: Tensor, shape: _Shape) -> Tensor:
|
||||
"""Reshapes input tensor with the same values based on a given shape tuple."""
|
||||
if isinstance(shape, int):
|
||||
shape = (shape,)
|
||||
return op.Reshape()(inputs, shape)
|
||||
|
||||
|
||||
def zeros(shape: _Shape, dtype: mindspore.dtype = None) -> Tensor:
|
||||
"""Return a new array of given shape and type, filled with zeros."""
|
||||
outputs = np.zeros(shape)
|
||||
return mindspore.Tensor(outputs, dtype=dtype)
|
||||
|
||||
|
||||
def zeros_like(inputs: Tensor, dtype: mindspore.dtype = None) -> Tensor:
|
||||
"""Return an array of zeros with the same shape and type as a given array."""
|
||||
inputs_np = inputs.asnumpy()
|
||||
outputs_np = np.zeros_like(inputs_np)
|
||||
outputs = mindspore.Tensor(outputs_np, dtype)
|
||||
return outputs
|
||||
|
||||
|
||||
def random(shape: _Shape, dtype: mindspore.dtype = None) -> Tensor:
|
||||
"""Return random floats in the half-open interval [0.0, 1.0)."""
|
||||
outputs_np = np.random.random(shape)
|
||||
outputs = mindspore.Tensor(outputs_np, dtype)
|
||||
return outputs
|
||||
|
||||
|
||||
def randint(low: int, high: int, shape: _Shape, dtype: mindspore.dtype = mindspore.int8) -> Tensor:
|
||||
"""Return random integers from `low` (inclusive) to `high` (exclusive)."""
|
||||
outputs_np = np.random.randint(low, high, size=shape)
|
||||
outputs = mindspore.Tensor(outputs_np, dtype=dtype)
|
||||
return outputs
|
||||
|
||||
|
||||
def softmax(axis: int) -> Callable:
|
||||
"""Softmax activation function."""
|
||||
func = nn.Softmax(axis=axis)
|
||||
return func
|
||||
|
||||
|
||||
def summation(inputs: Tensor, axis: _Axis = (), keep_dims: bool = False) -> Tensor:
|
||||
"""Reduce a dimension of a tensor by summing all elements in the dimension."""
|
||||
sum_op = op.ReduceSum(keep_dims)
|
||||
outputs = sum_op(inputs, axis)
|
||||
return outputs
|
||||
|
||||
|
||||
def stack(inputs: List[Tensor], axis: int) -> Tensor:
|
||||
"""Packs a list of tensors in specified axis."""
|
||||
pack_op = op.Pack(axis)
|
||||
outputs = pack_op(inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
def sqrt(inputs: Tensor) -> Tensor:
|
||||
"""Returns square root of a tensor element-wise."""
|
||||
sqrt_op = op.Sqrt()
|
||||
return sqrt_op(inputs)
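# Usage sketch for the wrappers above (illustrative, not part of the original patch;
# assumes a MindSpore context is available at runtime):
# >>> x = Tensor(np.array([[3.0, 1.0, 2.0]], np.float32))
# >>> summation(x, axis=1)            # Tensor([6.0])
# >>> argsort(x, axis=1)              # Tensor([[1, 2, 0]])
# >>> reshape(arange(0, 6), (2, 3))   # 2x3 Tensor with values 0..5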
@ -0,0 +1,481 @@
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Runner."""
|
||||
from time import time
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from mindspore.train.summary_pb2 import Explain
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log
|
||||
from mindspore.ops.operations import ExpandDims
|
||||
from mindspore.train.summary._summary_adapter import _convert_image_format, _make_image
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from .benchmark import Localization
|
||||
from .benchmark._attribution.metric import AttributionMetric
|
||||
from .explanation._attribution._attribution import Attribution
|
||||
|
||||
_EXPAND_DIMS = ExpandDims()
|
||||
_CMAP_0 = np.reshape(np.array([55, 25, 86, 255]), [1, 1, 4]) / 255
|
||||
_CMAP_1 = np.reshape(np.array([255, 255, 0, 255]), [1, 1, 4]) / 255
|
||||
|
||||
|
||||
def _normalize(img_np):
|
||||
"""Normalize the image in the numpy array to be in [0, 255]. """
|
||||
max_ = img_np.max()
|
||||
min_ = img_np.min()
|
||||
normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
|
||||
return (normed * 255).astype(np.uint8)
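# Illustrative example (a sketch, numpy only): an image spanning [-1, 1] is mapped
# onto [0, 255] as uint8, e.g.
# >>> _normalize(np.array([-1.0, 0.0, 1.0]))   # array([  0, 127, 255], dtype=uint8)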
|
||||
|
||||
|
||||
def _make_rgba(saliency):
|
||||
"""Make rgba image for saliency map."""
|
||||
saliency = saliency.asnumpy().squeeze()
|
||||
saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min()).clip(1e-10)
|
||||
rgba = np.empty((saliency.shape[0], saliency.shape[1], 4))
|
||||
rgba[:, :, :] = np.expand_dims(saliency, 2)
|
||||
rgba = rgba * _CMAP_1 + (1 - rgba) * _CMAP_0
|
||||
rgba[:, :, -1] = saliency * 1
|
||||
return rgba
|
||||
|
||||
|
||||
class ExplainRunner:
|
||||
"""
|
||||
High-level API for users to generate results with the explanation methods and the evaluation methods.
|
||||
|
||||
After generating results with the explanation methods and the evaluation methods, the results will be written into
|
||||
a specified file with 'mindspore.summary.SummaryRecord'. The stored content can be viewed using MindInsight.
|
||||
|
||||
Args:
|
||||
summary_dir (str): The directory path to save the summary files which store the generated results.
|
||||
Default: "./"
|
||||
|
||||
Examples:
|
||||
>>> # init a runner with a specified directory
|
||||
>>> summary_dir = "summary_dir"
|
||||
>>> runner = ExplainRunner(summary_dir)
|
||||
"""
|
||||
|
||||
def __init__(self, summary_dir: Optional[str] = "./"):
|
||||
self._summary_dir = summary_dir
|
||||
self._count = 0
|
||||
self._classes = None
|
||||
self._model = None
|
||||
|
||||
def run(self,
|
||||
dataset: Tuple,
|
||||
explainers: List,
|
||||
benchmarkers: Optional[List] = None):
|
||||
"""
|
||||
Generate results and write them into the summary files in `self.summary_dir`.
|
||||
|
||||
Args:
|
||||
dataset (tuple): A tuple that contains `mindspore.dataset` object for iteration and its labels.
|
||||
- dataset[0], a `mindspore.dataset` object to provide data to explain.
|
||||
- dataset[1], a list of string that specifies the label names of the dataset.
|
||||
explainers (list): A list of explanation objects to generate _attribution results.
|
||||
benchmarkers (list): A list of benchmark objects to generate evaluation results. Default: None
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
|
||||
>>> # obtain dataset object
|
||||
>>> dataset = get_dataset()
|
||||
>>> classes = ["cat", "dog", ...]
|
||||
>>> # load checkpoint to a network, e.g. resnet50
|
||||
>>> param_dict = load_checkpoint("checkpoint.ckpt")
|
||||
>>> net = resnet50(len(classes))
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation
|
||||
>>> model = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> gbp = GuidedBackprop(model)
|
||||
>>> gradient = Gradient(model)
|
||||
>>> runner = ExplainRunner("./")
|
||||
>>> explainers = [gbp, gradient]
|
||||
>>> runner.run((dataset, classes), explainers)
|
||||
"""
|
||||
|
||||
if not isinstance(dataset, tuple):
|
||||
raise TypeError("Argument `dataset` must be a tuple.")
|
||||
if len(dataset) != 2:
|
||||
raise ValueError("Argument `dataset` should be a tuple with length = 2.")
|
||||
|
||||
dataset, classes = dataset
|
||||
self._verify_data_form(dataset, benchmarkers)
|
||||
self._classes = classes
|
||||
|
||||
if explainers is None or not explainers:
|
||||
raise ValueError("Argument `explainers` can neither be None nor empty.")
|
||||
|
||||
for exp in explainers:
|
||||
if not isinstance(exp, Attribution) or not isinstance(explainers, list):
|
||||
raise TypeError("Argument explainers should be a list of objects of classes in "
|
||||
"`mindspore.explainer.explanation._attribution`.")
|
||||
if benchmarkers is not None:
|
||||
for bench in benchmarkers:
|
||||
if not isinstance(bench, AttributionMetric) or not isinstance(benchmarkers, list):
|
||||
raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation"
|
||||
"`mindspore.explainer.benchmark._attribution`.")
|
||||
|
||||
self._model = explainers[0].model
|
||||
|
||||
with SummaryRecord(self._summary_dir) as summary:
|
||||
print("Start running and writing......")
|
||||
begin = time()
|
||||
print("Start writing metadata.")
|
||||
|
||||
explain = Explain()
|
||||
explain.metadata.label.extend(classes)
|
||||
exp_names = [exp.__class__.__name__ for exp in explainers]
|
||||
explain.metadata.explain_method.extend(exp_names)
|
||||
if benchmarkers is not None:
|
||||
bench_names = [bench.__class__.__name__ for bench in benchmarkers]
|
||||
explain.metadata.benchmark_method.extend(bench_names)
|
||||
|
||||
summary.add_value("explainer", "metadata", explain)
|
||||
summary.record(1)
|
||||
|
||||
print("Finish writing metadata.")
|
||||
|
||||
now = time()
|
||||
print("Start running and writing inference data......")
|
||||
imageid_labels = self._run_inference(dataset, summary)
|
||||
print("Finish running and writing inference data. Time elapsed: {}s".format(time() - now))
|
||||
|
||||
if benchmarkers is None:
|
||||
for exp in explainers:
|
||||
start = time()
|
||||
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
|
||||
self._count = 0
|
||||
ds.config.set_seed(58)
|
||||
for idx, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
self._run_exp_step(next_element, exp, imageid_labels, summary)
|
||||
print("Finish writing {}-th explanation data. Time elapsed: {}".format(
|
||||
idx, time() - now))
|
||||
print("Finish running and writing explanation data for {}. Time elapsed: {}".format(
|
||||
exp.__class__.__name__, time() - start))
|
||||
else:
|
||||
for exp in explainers:
|
||||
explain = Explain()
|
||||
for bench in benchmarkers:
|
||||
bench.reset()
|
||||
print(f"Start running and writing explanation and benchmark data for {exp.__class__.__name__}.")
|
||||
self._count = 0
|
||||
start = time()
|
||||
ds.config.set_seed(58)
|
||||
for idx, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
|
||||
print("Finish writing {}-th batch explanation data. Time elapsed: {}s".format(
|
||||
idx, time() - now))
|
||||
for bench in benchmarkers:
|
||||
now = time()
|
||||
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
|
||||
print("Finish running {}-th batch benchmark data for {}. Time elapsed: {}s".format(
|
||||
idx, bench.__class__.__name__, time() - now))
|
||||
|
||||
for bench in benchmarkers:
|
||||
benchmark = explain.benchmark.add()
|
||||
benchmark.explain_method = exp.__class__.__name__
|
||||
benchmark.benchmark_method = bench.__class__.__name__
|
||||
|
||||
benchmark.total_score = bench.performance
|
||||
benchmark.label_score.extend(bench.class_performances)
|
||||
|
||||
print("Finish running and writing explanation and benchmark data for {}. "
|
||||
"Time elapsed: {}s".format(exp.__class__.__name__, time() - start))
|
||||
summary.add_value('explainer', 'benchmark', explain)
|
||||
summary.record(1)
|
||||
print("Finish running and writing. Total time elapsed: {}s".format(time() - begin))
|
||||
|
||||
@staticmethod
|
||||
def _verify_data_form(dataset, benchmarkers):
|
||||
"""
|
||||
Verify the validity of dataset.
|
||||
|
||||
Args:
|
||||
dataset (`ds`): the user parsed dataset.
|
||||
benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers.
|
||||
"""
|
||||
next_element = dataset.create_tuple_iterator().get_next()
|
||||
|
||||
if len(next_element) not in [1, 2, 3]:
|
||||
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
|
||||
" as columns.")
|
||||
|
||||
if len(next_element) == 3:
|
||||
inputs, labels, bboxes = next_element
|
||||
if bboxes.shape[-1] != 4:
|
||||
raise ValueError("The third element of dataset should be bounding boxes with shape of "
|
||||
"[batch_size, num_ground_truth, 4].")
|
||||
else:
|
||||
if benchmarkers is not None and any(isinstance(bench, Localization) for bench in benchmarkers):
|
||||
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
|
||||
|
||||
if len(next_element) == 2:
|
||||
inputs, labels = next_element
|
||||
if len(next_element) == 1:
|
||||
inputs = next_element[0]
|
||||
|
||||
if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
|
||||
raise ValueError(
|
||||
"Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
|
||||
inputs.shape))
|
||||
if len(inputs.shape) == 3:
|
||||
log.warning(
|
||||
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
|
||||
" dimension as batch data.".format(inputs.shape))
|
||||
|
||||
if len(next_element) > 1:
|
||||
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
|
||||
raise ValueError(
|
||||
"Labels shape {} is unrecognizable: labels should not have more than two dimensions"
|
||||
" with length greater than 1.".format(labels.shape))
|
||||
|
||||
def _transform_data(self, inputs, labels, bboxes, ifbbox):
|
||||
"""
|
||||
Transform the data from one iteration of dataset to a unifying form for the follow-up operations.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): the image data
|
||||
labels (Tensor): the labels
|
||||
bboxes (Tensor): the bounding boxes data
|
||||
ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t label
|
||||
id will be returned. If False, the returned bboxes are the parsed bboxes.
|
||||
|
||||
Returns:
|
||||
inputs (Tensor): the image data, unified to a 4D Tensor.
|
||||
labels (List[List[int]]): the ground truth labels.
|
||||
bboxes (Union[List[Dict], None, Tensor]): the bounding boxes
|
||||
"""
|
||||
inputs = ms.Tensor(inputs, ms.float32)
|
||||
if len(inputs.shape) == 3:
|
||||
inputs = _EXPAND_DIMS(inputs, 0)
|
||||
if isinstance(labels, ms.Tensor):
|
||||
labels = ms.Tensor(labels, ms.int32)
|
||||
labels = _EXPAND_DIMS(labels, 0)
|
||||
if isinstance(bboxes, ms.Tensor):
|
||||
bboxes = ms.Tensor(bboxes, ms.int32)
|
||||
bboxes = _EXPAND_DIMS(bboxes, 0)
|
||||
|
||||
input_len = len(inputs)
|
||||
if bboxes is not None and ifbbox:
|
||||
bboxes = ms.Tensor(bboxes, ms.int32)
|
||||
masks_lst = []
|
||||
labels = labels.asnumpy().reshape([input_len, -1])
|
||||
bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
|
||||
for idx, label in enumerate(labels):
|
||||
height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
|
||||
masks = {}
|
||||
for j, label_item in enumerate(label):
|
||||
target = int(label_item)
|
||||
if -1 < target < len(self._classes):
|
||||
if target not in masks:
|
||||
mask = np.zeros((1, 1, height, width))
|
||||
else:
|
||||
mask = masks[target]
|
||||
x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
|
||||
mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
|
||||
masks[target] = mask
|
||||
|
||||
masks_lst.append(masks)
|
||||
bboxes = masks_lst
|
||||
|
||||
labels = ms.Tensor(labels, ms.int32)
|
||||
if len(labels.shape) == 1:
|
||||
labels_lst = [[int(i)] for i in labels.asnumpy()]
|
||||
else:
|
||||
labels = labels.asnumpy().reshape([input_len, -1])
|
||||
labels_lst = []
|
||||
for item in labels:
|
||||
labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._classes))))
|
||||
labels = labels_lst
|
||||
return inputs, labels, bboxes
|
||||
|
||||
def _unpack_next_element(self, next_element, ifbbox=False):
|
||||
"""
|
||||
Unpack a single iteration of dataset.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): a single element iterated from dataset object.
|
||||
ifbbox (bool): whether to preprocess bboxes in self._transform_data.
|
||||
|
||||
Returns:
|
||||
Tuple, a unified Tuple contains image_data, labels, and bounding boxes.
|
||||
"""
|
||||
if len(next_element) == 3:
|
||||
inputs, labels, bboxes = next_element
|
||||
elif len(next_element) == 2:
|
||||
inputs, labels = next_element
|
||||
bboxes = None
|
||||
else:
|
||||
inputs = next_element[0]
|
||||
labels = [[] for x in inputs]
|
||||
bboxes = None
|
||||
inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
|
||||
return inputs, labels, bboxes
|
||||
|
||||
@staticmethod
|
||||
def _make_label_batch(labels):
|
||||
"""
|
||||
Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
|
||||
length of all the rows in labels.
|
||||
|
||||
Args:
|
||||
labels (List[List]): the union labels of a data batch.
|
||||
|
||||
Returns:
|
||||
2D Tensor.
|
||||
"""
|
||||
|
||||
max_len = max([len(l) for l in labels])
|
||||
batch_labels = np.zeros((len(labels), max_len))
|
||||
|
||||
for idx, _ in enumerate(batch_labels):
|
||||
length = len(labels[idx])
|
||||
batch_labels[idx, :length] = np.array(labels[idx])
|
||||
|
||||
return ms.Tensor(batch_labels, ms.int32)
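# Worked example (illustrative): labels [[1], [0, 2]] are padded with zeros to the max
# row length and packed into an int32 Tensor [[1, 0], [0, 2]].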
|
||||
|
||||
def _run_inference(self, dataset, summary, threshold=0.5):
|
||||
"""
|
||||
Run inference for the dataset and write the inference related data into summary.
|
||||
|
||||
Args:
|
||||
dataset (`ds`): the parsed dataset
|
||||
summary (`SummaryRecord`): the summary object to store the data
|
||||
threshold (float): the threshold for prediction.
|
||||
|
||||
Returns:
|
||||
imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels.
|
||||
"""
|
||||
imageid_labels = {}
|
||||
ds.config.set_seed(58)
|
||||
self._count = 0
|
||||
for j, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
prob = self._model(inputs).asnumpy()
|
||||
for idx, inp in enumerate(inputs):
|
||||
gt_labels = labels[idx]
|
||||
gt_probs = [float(prob[idx][i]) for i in gt_labels]
|
||||
|
||||
data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
|
||||
_, _, _, image_string = _make_image(_normalize(data_np))
|
||||
|
||||
predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
|
||||
predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
|
||||
|
||||
union_labs = list(set(gt_labels + predicted_labels))
|
||||
imageid_labels[str(self._count)] = union_labs
|
||||
|
||||
explain = Explain()
|
||||
explain.image_id = str(self._count)
|
||||
explain.image_data = image_string
|
||||
summary.add_value("explainer", "image", explain)
|
||||
|
||||
explain = Explain()
|
||||
explain.image_id = str(self._count)
|
||||
explain.ground_truth_label.extend(gt_labels)
|
||||
explain.inference.ground_truth_prob.extend(gt_probs)
|
||||
explain.inference.predicted_label.extend(predicted_labels)
|
||||
explain.inference.predicted_prob.extend(predicted_probs)
|
||||
summary.add_value("explainer", "inference", explain)
|
||||
|
||||
summary.record(1)
|
||||
|
||||
self._count += 1
|
||||
print("Finish running and writing {}-th batch inference data. Time elapsed: {}s".format(j, time() - now))
|
||||
return imageid_labels
|
||||
|
||||
def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
|
||||
"""
|
||||
Run the explanation for each step and write explanation results into summary.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): data of one step
|
||||
explainer (_Attribution): an Attribution object to generate saliency maps.
|
||||
imageid_labels (dict): a dict that maps the image_id and its union labels.
|
||||
summary (SummaryRecord): the summary object to store the data
|
||||
|
||||
Returns:
|
||||
List of dict that maps label to its corresponding saliency map.
|
||||
"""
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
count = self._count
|
||||
unions = []
|
||||
for _ in range(len(labels)):
|
||||
unions_labels = imageid_labels[str(count)]
|
||||
unions.append(unions_labels)
|
||||
count += 1
|
||||
|
||||
batch_unions = self._make_label_batch(unions)
|
||||
saliency_dict_lst = []
|
||||
|
||||
batch_saliency_full = []
|
||||
for i in range(len(batch_unions[0])):
|
||||
batch_saliency = explainer(inputs, batch_unions[:, i])
|
||||
batch_saliency_full.append(batch_saliency)
|
||||
|
||||
for idx, union in enumerate(unions):
|
||||
saliency_dict = {}
|
||||
explain = Explain()
|
||||
explain.image_id = str(self._count)
|
||||
for k, lab in enumerate(union):
|
||||
saliency = batch_saliency_full[k][idx:idx + 1]
|
||||
|
||||
saliency_dict[lab] = saliency
|
||||
|
||||
saliency_np = _make_rgba(saliency)
|
||||
_, _, _, saliency_string = _make_image(_normalize(saliency_np))
|
||||
|
||||
explanation = explain.explanation.add()
|
||||
explanation.explain_method = explainer.__class__.__name__
|
||||
|
||||
explanation.label = lab
|
||||
explanation.heatmap = saliency_string
|
||||
|
||||
summary.add_value("explainer", "explanation", explain)
|
||||
summary.record(1)
|
||||
|
||||
self._count += 1
|
||||
saliency_dict_lst.append(saliency_dict)
|
||||
return saliency_dict_lst
|
||||
|
||||
def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
|
||||
"""
|
||||
Run the explanation and evaluation for each step and write explanation results into summary.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): Data of one step.
explainer (`_Attribution`): An Attribution object to generate saliency maps.
benchmarker (`AttributionMetric`): A benchmark object to evaluate the explainer.
saliency_dict_lst (list): List of dicts that map each label to its saliency map.
|
||||
"""
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
for idx, inp in enumerate(inputs):
|
||||
inp = _EXPAND_DIMS(inp, 0)
|
||||
saliency_dict = saliency_dict_lst[idx]
|
||||
for label, saliency in saliency_dict.items():
|
||||
if isinstance(benchmarker, Localization):
|
||||
_, _, bboxes = self._unpack_next_element(next_element, True)
|
||||
if label in labels[idx]:
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
|
||||
saliency=saliency)
|
||||
benchmarker.aggregate(res, label)
|
||||
else:
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
|
||||
benchmarker.aggregate(res, label)
@ -0,0 +1,285 @@
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utils for MindExplain"""
|
||||
|
||||
__all__ = [
|
||||
'ForwardProbe',
|
||||
'calc_auc',
|
||||
'calc_correlation',
|
||||
'format_tensor_to_ndarray',
|
||||
'generate_one_hot',
|
||||
'rank_pixels',
|
||||
'resize',
|
||||
'retrieve_layer_by_name',
|
||||
'retrieve_layer',
|
||||
'unify_inputs',
|
||||
'unify_targets'
|
||||
]
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as op
|
||||
|
||||
_Array = np.ndarray
|
||||
_Module = nn.Cell
|
||||
_Tensor = ms.Tensor
|
||||
|
||||
|
||||
def generate_one_hot(indices, depth):
|
||||
r"""
|
||||
Simple wrapper of the OneHot operation, the on_value and off_value are fixed to 1.0
|
||||
and 0.0.
|
||||
"""
|
||||
on_value = ms.Tensor(1.0, ms.float32)
|
||||
off_value = ms.Tensor(0.0, ms.float32)
|
||||
weights = op.OneHot()(indices, depth, on_value, off_value)
|
||||
return weights
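# Example (illustrative sketch; assumes a MindSpore context):
# >>> generate_one_hot(ms.Tensor([1, 0], ms.int32), 3)
# Tensor of shape (2, 3) with rows [0., 1., 0.] and [1., 0., 0.]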
|
||||
|
||||
|
||||
def unify_inputs(inputs) -> tuple:
|
||||
"""Unify inputs of explainer."""
|
||||
if isinstance(inputs, tuple):
|
||||
return inputs
|
||||
if isinstance(inputs, ms.Tensor):
|
||||
inputs = (inputs,)
|
||||
elif isinstance(inputs, np.ndarray):
|
||||
inputs = (ms.Tensor(inputs),)
|
||||
else:
|
||||
raise TypeError(
|
||||
'inputs must be one of [tuple, ms.Tensor or np.ndarray], '
|
||||
'but get {}'.format(type(inputs)))
|
||||
return inputs
|
||||
|
||||
|
||||
def unify_targets(targets) -> ms.Tensor:
|
||||
"""Unify targets labels of explainer."""
|
||||
if isinstance(targets, ms.Tensor):
|
||||
return targets
|
||||
if isinstance(targets, list):
|
||||
targets = ms.Tensor(targets, dtype=ms.int32)
|
||||
elif isinstance(targets, int):
|
||||
targets = ms.Tensor([targets], dtype=ms.int32)
|
||||
else:
|
||||
raise TypeError(
|
||||
'targets must be one of [int, list or ms.Tensor], '
|
||||
'but get {}'.format(type(targets)))
|
||||
return targets
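# Examples (illustrative): both helpers normalize user input for the explainers.
# >>> unify_inputs(np.zeros((1, 3, 2, 2), np.float32))   # -> (ms.Tensor of shape (1, 3, 2, 2),)
# >>> unify_targets(5)                                    # -> ms.Tensor([5], ms.int32)
# >>> unify_targets([1, 2])                               # -> ms.Tensor([1, 2], ms.int32)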
|
||||
|
||||
|
||||
def retrieve_layer_by_name(model: _Module, layer_name: str):
|
||||
"""
|
||||
Retrieve the layer in the model by the given layer_name.
|
||||
|
||||
Args:
|
||||
model (_Module): model which contains the target layer
|
||||
layer_name (str): name of target layer
|
||||
|
||||
Return:
|
||||
- target_layer (_Module)
|
||||
|
||||
Raise:
|
||||
ValueError: if the module with the given layer_name is not found in the model,
|
||||
raise ValueError.
|
||||
|
||||
"""
|
||||
if not isinstance(layer_name, str):
|
||||
raise TypeError('layer_name should be type of str, but receive {}.'
|
||||
.format(type(layer_name)))
|
||||
|
||||
if not layer_name:
|
||||
return model
|
||||
|
||||
target_layer = None
|
||||
for name, cell in model.cells_and_names():
|
||||
if name == layer_name:
|
||||
target_layer = cell
|
||||
return target_layer
|
||||
|
||||
if target_layer is None:
|
||||
raise ValueError(
|
||||
'Cannot match {}, please provide target layer '
|
||||
'in the given model.'.format(layer_name))
|
||||
return None
|
||||
|
||||
|
||||
def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''):
|
||||
"""
|
||||
Retrieve the layer in the model.
|
||||
|
||||
'target' can be either a layer name or a Cell object. Given the layer name,
|
||||
the method will search through the model and return the matched layer. If a
|
||||
Cell object is provided, it will check whether the given layer exists
|
||||
in the model. If target layer is not found in the model, ValueError will
|
||||
be raised.
|
||||
|
||||
Args:
|
||||
model (_Module): the model to retrieve the target layer
|
||||
target_layer (Union[str, _Module]): target layer to retrieve. Can be
|
||||
either string (layer name) or the Cell object. If '' is provided,
|
||||
the input model will be returned.
|
||||
|
||||
Return:
|
||||
target layer (_Module)
|
||||
"""
|
||||
if isinstance(target_layer, str):
|
||||
target_layer = retrieve_layer_by_name(model, target_layer)
|
||||
return target_layer
|
||||
|
||||
if isinstance(target_layer, _Module):
|
||||
for _, cell in model.cells_and_names():
|
||||
if target_layer is cell:
|
||||
return target_layer
|
||||
raise ValueError(
|
||||
'Model does not contain cell {}, failed to probe.'.format(target_layer)
|
||||
)
|
||||
raise TypeError('layer_name must have type of str or ms.nn.Cell, '
|
||||
'but receive {}'.format(type(target_layer)))
|
||||
|
||||
|
||||
class ForwardProbe:
|
||||
"""
|
||||
Probe to capture output of specific layer in a given model.
|
||||
|
||||
Args:
|
||||
target_layer (_Module): name of target layer or just provide the
|
||||
target layer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, target_layer: _Module):
|
||||
self._target_layer = target_layer
|
||||
self._original_construct = self._target_layer.construct
|
||||
self._intermediate_tensor = None
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._intermediate_tensor
|
||||
|
||||
def __enter__(self):
|
||||
self._target_layer.construct = self._new_construct
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
self._target_layer.construct = self._original_construct
|
||||
self._intermediate_tensor = None
|
||||
return False
|
||||
|
||||
def _new_construct(self, *inputs):
|
||||
outputs = self._original_construct(*inputs)
|
||||
self._intermediate_tensor = outputs
|
||||
return outputs
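# Usage sketch (illustrative; `model`, `inputs` and the layer name are placeholders):
# >>> layer = retrieve_layer(model, target_layer='layer4')
# >>> with ForwardProbe(layer) as probe:
# ...     _ = model(inputs)
# ...     feature_map = probe.value   # output of `layer` captured during the forward pass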
|
||||
|
||||
|
||||
def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
|
||||
"""Unify `mindspore.Tensor` and `np.ndarray` to `np.ndarray`. """
|
||||
if isinstance(x, ms.Tensor):
|
||||
x = x.asnumpy()
|
||||
|
||||
if not isinstance(x, np.ndarray):
|
||||
raise TypeError('input should be one of [ms.Tensor or np.ndarray],'
|
||||
' but receive {}'.format(type(x)))
|
||||
return x
|
||||
|
||||
|
||||
def calc_correlation(x: Union[ms.Tensor, np.ndarray],
|
||||
y: Union[ms.Tensor, np.ndarray]) -> float:
|
||||
"""Calculate Pearson correlation coefficient between two arrays. """
|
||||
x = format_tensor_to_ndarray(x)
|
||||
y = format_tensor_to_ndarray(y)
|
||||
faithfulness = -np.corrcoef(x, y)[0, 1]
|
||||
|
||||
return faithfulness
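# Worked example (illustrative): perfectly anti-correlated sequences give the maximal score.
# >>> calc_correlation(np.array([1., 2., 3.]), np.array([3., 2., 1.]))   # 1.0 (= -(-1.0))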
|
||||
|
||||
|
||||
def calc_auc(x: _Array) -> float:
|
||||
"""Calculate the Aera under Curve."""
|
||||
# take the mean over spatial patches if the model is a fully convolutional model
if len(x.shape) == 4:
x = np.mean(x, axis=(2, 3))
|
||||
|
||||
auc = (x.sum() - x[0] - x[-1]) / len(x)
|
||||
return float(auc)
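# Worked example (illustrative): for x = [0.9, 0.6, 0.4, 0.1],
# auc = (sum(x) - x[0] - x[-1]) / len(x) = (2.0 - 0.9 - 0.1) / 4 = 0.25.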
|
||||
|
||||
|
||||
def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
|
||||
"""
|
||||
Generate the rank order for every pixel in a 2D array.
|
||||
|
||||
The rank order start from 0 to (num_pixel-1). If descending is True, the
|
||||
rank order will generate in a descending order, otherwise in ascending
|
||||
order.
|
||||
|
||||
Example:
|
||||
x = np.array([[4., 3., 1.], [5., 9., 1.]])
|
||||
rank_pixels(x, descending=True)
|
||||
>> np.array([[2, 3, 4], [1, 0, 5]])
|
||||
rank_pixels(x, descending=False)
|
||||
>> np.array([[3, 2, 0], [4, 5, 1]])
|
||||
|
||||
"""
|
||||
if len(inputs.shape) != 2:
|
||||
raise ValueError('Only support 2D array currently')
|
||||
flatten_saliency = inputs.reshape(-1)
|
||||
factor = -1 if descending else 1
|
||||
sorted_arg = np.argsort(factor * flatten_saliency, axis=0)
|
||||
flatten_rank = np.zeros_like(sorted_arg)
|
||||
flatten_rank[sorted_arg] = np.arange(0, sorted_arg.shape[0])
|
||||
rank_map = flatten_rank.reshape(inputs.shape)
|
||||
return rank_map
|
||||
|
||||
|
||||
def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
|
||||
"""
|
||||
Resize the intermediate layer _attribution to the same size as inputs.
|
||||
|
||||
Args:
|
||||
inputs (ms.Tensor): the input tensor to be resized
|
||||
size (tuple[int]): the target size to resize to
|
||||
mode (str): the resize mode. Options: 'nearest_neighbor', 'bilinear'
|
||||
|
||||
Returns:
|
||||
outputs (ms.Tensor): the resized tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: the resize mode is not in ['nearest_neighbor',
|
||||
'bilinear'].
|
||||
"""
|
||||
h, w = size
|
||||
if mode == 'nearest_neighbor':
|
||||
resize_nn = op.ResizeNearestNeighbor((h, w))
|
||||
outputs = resize_nn(inputs)
|
||||
|
||||
elif mode == 'bilinear':
|
||||
inputs_np = inputs.asnumpy()
|
||||
inputs_np = np.transpose(inputs_np, [0, 2, 3, 1])
|
||||
array_lst = []
|
||||
for inp in inputs_np:
|
||||
array = (np.repeat(inp, 3, axis=2) * 255).astype(np.uint8)
|
||||
image = Image.fromarray(array)
|
||||
image = image.resize(size, resample=Image.BILINEAR)
|
||||
array = np.asarray(image).astype(np.float32) / 255
|
||||
array_lst.append(array[:, :, 0:1])
|
||||
|
||||
resized_np = np.transpose(array_lst, [0, 3, 1, 2])
|
||||
outputs = ms.Tensor(resized_np, inputs.dtype)
|
||||
else:
|
||||
raise ValueError('Unsupported resize mode {}'.format(mode))
|
||||
|
||||
return outputs
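# Usage sketch (illustrative; assumes a MindSpore context):
# >>> saliency = ms.Tensor(np.random.random((1, 1, 7, 7)), ms.float32)
# >>> resize(saliency, (224, 224), 'bilinear')          # -> Tensor of shape (1, 1, 224, 224)
# >>> resize(saliency, (224, 224), 'nearest_neighbor')  # -> Tensor of shape (1, 1, 224, 224)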
@ -0,0 +1,23 @@
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Predefined XAI metrics."""
|
||||
|
||||
from ._attribution.faithfulness import Faithfulness
|
||||
from ._attribution.localization import Localization
|
||||
|
||||
__all__ = [
|
||||
"Faithfulness",
|
||||
"Localization"
|
||||
]
@ -0,0 +1,23 @@
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Predefined XAI metrics"""
|
||||
|
||||
from .faithfulness import Faithfulness
|
||||
from .localization import Localization
|
||||
|
||||
__all__ = [
|
||||
"Faithfulness",
|
||||
"Localization"
|
||||
]
@ -0,0 +1,593 @@
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Faithfulness"""
|
||||
import math
|
||||
from typing import Callable, Optional, Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
from scipy.ndimage.filters import gaussian_filter
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as op
|
||||
from .metric import AttributionMetric
|
||||
from ..._utils import calc_correlation, calc_auc, format_tensor_to_ndarray, rank_pixels
|
||||
from ...explanation._attribution._attribution import Attribution as _Attribution
|
||||
|
||||
_Array = np.ndarray
|
||||
_Explainer = Union[_Attribution, Callable]
|
||||
_Label = Union[int, ms.Tensor]
|
||||
_Module = nn.Cell
|
||||
|
||||
|
||||
def _calc_feature_importance(saliency: _Array, masks: _Array) -> _Array:
|
||||
"""Calculate feature important w.r.t given masks."""
|
||||
feature_importance = []
|
||||
num_perturbations = masks.shape[0]
|
||||
for i in range(num_perturbations):
|
||||
patch_feature_importance = saliency[masks[i]].sum() / masks[i].sum()
|
||||
feature_importance.append(patch_feature_importance)
|
||||
feature_importance = np.array(feature_importance, dtype=np.float32)
|
||||
return feature_importance
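# Worked example (illustrative): with saliency [[0.9, 0.1], [0.4, 0.2]] and a single
# mask selecting the first row, the patch importance is (0.9 + 0.1) / 2 = 0.5.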
|
||||
|
||||
|
||||
class _BaseReplacement:
|
||||
"""
|
||||
Base class of generator for generating different replacement for perturbations.
|
||||
|
||||
Args:
|
||||
kwargs: Optional args for generating replacement. Derived classes need to
|
||||
add necessary arg names and default value to '_necessary_args'.
|
||||
If the argument has no default value, the value should be set to
|
||||
'EMPTY' to mark the required args. Initializing an object will
|
||||
check the given kwargs w.r.t '_necessary_args'.
|
||||
|
||||
Raise:
|
||||
ValueError: Raised when the provided kwargs do not contain the necessary arg names marked 'EMPTY'.
|
||||
"""
|
||||
_necessary_args = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._replace_args = self._necessary_args.copy()
|
||||
for key, value in self._replace_args.items():
|
||||
if key in kwargs.keys():
|
||||
self._replace_args[key] = kwargs[key]
|
||||
elif key not in kwargs.keys() and value == 'EMPTY':
|
||||
raise ValueError(f"Missing keyword arg {key} for {self.__class__.__name__}.")
|
||||
|
||||
__call__: Callable
|
||||
"""
|
||||
Generate replacement for perturbations. Derived classes should override this
|
||||
function to generate different replacement for perturbing.
|
||||
|
||||
Args:
|
||||
inputs (_Array): Array to be perturb.
|
||||
|
||||
Returns:
|
||||
- replacement (_Array): Array to provide alternative pixels for every
|
||||
position in the given
|
||||
inputs. The returned array should have same shape as inputs.
|
||||
"""
|
||||
|
||||
|
||||
class Constant(_BaseReplacement):
|
||||
""" Generator to provide constant-value replacement for perturbations """
|
||||
_necessary_args = {'base_value': 'EMPTY'}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
replacement = np.ones_like(inputs, dtype=np.float32)
|
||||
replacement *= self._replace_args['base_value']
|
||||
return replacement
|
||||
|
||||
|
||||
class GaussianBlur(_BaseReplacement):
|
||||
""" Generator to provided gaussian blurred inputs for perturbation. """
|
||||
_necessary_args = {'sigma': 0.7}
|
||||
|
||||
def __call__(self, inputs: _Array) -> _Array:
|
||||
sigma = self._replace_args['sigma']
|
||||
replacement = gaussian_filter(inputs, sigma=sigma)
|
||||
return replacement
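# Illustrative usage of the replacement generators above (a sketch):
# >>> img = np.random.random((3, 32, 32)).astype(np.float32)
# >>> Constant(base_value=0.0)(img)     # all-zero array with the same shape as img
# >>> GaussianBlur(sigma=1.0)(img)      # Gaussian-blurred copy of img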
|
||||
|
||||
|
||||
class Perturb:
|
||||
"""
|
||||
Perturbation generator to generate perturbations for a given array.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_mode (str): specify perturbing mode, through deleting or
|
||||
inserting pixels. Current support: ['Deletion', 'Insertion'].
|
||||
is_accumulate (bool): whether to accumulate the former perturbations to
|
||||
the later perturbations.
|
||||
perturb_pixel_per_step (int, optional): number of pixel to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculated by:
num_image_pixel * perturb_percent / num_perturb_steps.
Default: None
num_perturbations (int, optional): number of perturbations. If
num_perturbations is None, it will be calculated by:
|
||||
num_image_pixel * perturb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_mode: str,
|
||||
is_accumulate: bool,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None):
|
||||
self._perturb_percent = perturb_percent
|
||||
self._perturb_mode = perturb_mode
|
||||
self._pixel_per_step = perturb_pixel_per_step
|
||||
self._num_perturbations = num_perturbations
|
||||
self._is_accumulate = is_accumulate
|
||||
|
||||
@staticmethod
|
||||
def _assign(x: _Array, y: _Array, masks: _Array):
|
||||
"""Assign values to perturb pixels on perturbations."""
|
||||
if masks.dtype != bool:
|
||||
raise TypeError('The param "masks" should be an array of bool, but receive {}'
|
||||
.format(masks.dtype))
|
||||
for i in range(x.shape[0]):
|
||||
x[i][:, masks[i]] = y[:, masks[i]]
|
||||
|
||||
def _generate_mask(self, saliency_rank: _Array) -> _Array:
|
||||
"""Generate mask for perturbations based on given saliency ranks."""
|
||||
if len(saliency_rank.shape) != 2:
|
||||
raise ValueError(f'The param "saliency_rank" should be 2-dim, but receive {len(saliency_rank.shape)}.')
|
||||
|
||||
num_pixels = saliency_rank.shape[0] * saliency_rank.shape[1]
|
||||
if self._pixel_per_step:
|
||||
pixel_per_step = self._pixel_per_step
|
||||
num_perturbations = math.floor(
|
||||
num_pixels * self._perturb_percent / self._pixel_per_step)
|
||||
elif self._num_perturbations:
|
||||
pixel_per_step = math.floor(
|
||||
num_pixels * self._perturb_percent / self._num_perturbations)
|
||||
num_perturbations = self._num_perturbations
|
||||
else:
|
||||
raise ValueError("Must provide either pixel_per_step or num_perturbations.")
|
||||
|
||||
masks = np.zeros(
|
||||
(num_perturbations, saliency_rank.shape[0], saliency_rank.shape[1]),
|
||||
dtype=np.bool)
|
||||
low_bound = 0
|
||||
up_bound = low_bound + pixel_per_step
|
||||
factor = 0 if self._is_accumulate else 1
|
||||
|
||||
for i in range(num_perturbations):
|
||||
masks[i, ((saliency_rank >= low_bound)
|
||||
& (saliency_rank < up_bound))] = True
|
||||
low_bound = up_bound * factor
|
||||
up_bound += pixel_per_step
|
||||
|
||||
if len(masks.shape) == 3:
|
||||
return masks
|
||||
raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect 3-dim.')
|
||||
|
||||
def __call__(self,
|
||||
inputs: _Array,
|
||||
saliency: _Array,
|
||||
reference: _Array,
|
||||
return_mask: bool = False,
|
||||
) -> Union[_Array, Tuple[_Array, ...]]:
|
||||
"""
|
||||
Generate perturbations of given array.
|
||||
|
||||
Args:
|
||||
inputs (_Array): input array to perturb
|
||||
saliency (_Array): saliency map
|
||||
return_mask (bool): whether return the mask for generating
|
||||
the perturbation. The mask can be used to calculate
|
||||
average feature importance of pixels perturbed at each step.
|
||||
|
||||
Return:
|
||||
perturbations (_Array)
|
||||
masks (_Array): return when return_mask is set to True.
|
||||
"""
|
||||
if not np.array_equal(inputs.shape, reference.shape):
|
||||
raise ValueError('reference must have the same shape as inputs.')
|
||||
|
||||
saliency_rank = rank_pixels(saliency, descending=True)
|
||||
masks = self._generate_mask(saliency_rank)
|
||||
num_perturbations = masks.shape[0]
|
||||
|
||||
if self._perturb_mode == 'Insertion':
|
||||
inputs, reference = reference, inputs
|
||||
|
||||
perturbations = np.tile(
|
||||
inputs, (num_perturbations, *[1] * len(inputs.shape)))
|
||||
|
||||
Perturb._assign(perturbations, reference, masks)
|
||||
|
||||
if return_mask:
|
||||
return perturbations, masks
|
||||
return perturbations
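# Usage sketch (illustrative): delete the top-10% most salient pixels step by step.
# >>> img = np.random.random((3, 32, 32)).astype(np.float32)   # CHW image
# >>> sal = np.random.random((32, 32))                          # 2D saliency map
# >>> perturb = Perturb(perturb_percent=0.1, perturb_mode='Deletion',
# ...                   is_accumulate=True, perturb_pixel_per_step=10)
# >>> perturbed, masks = perturb(img, sal, Constant(base_value=0.0)(img), return_mask=True)
# >>> perturbed.shape, masks.shape   # (10, 3, 32, 32), (10, 32, 32)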
|
||||
|
||||
|
||||
class _FaithfulnessHelper:
|
||||
"""Base class for faithfulness calculator."""
|
||||
_support = [Constant, GaussianBlur]
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_mode: str,
|
||||
perturb_method: str,
|
||||
is_accumulate: bool,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
|
||||
self._get_reference = None
|
||||
for method in self._support:
|
||||
if perturb_method == method.__name__:
|
||||
self._get_reference = method(**kwargs)
|
||||
if self._get_reference is None:
|
||||
raise ValueError(
|
||||
'The param "perturb_method" should be one of {}.'.format([x.__name__ for x in self._support]))
|
||||
|
||||
self._perturb = Perturb(perturb_percent=perturb_percent,
|
||||
perturb_mode=perturb_mode,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=is_accumulate)
|
||||
|
||||
calc_faithfulness: Callable
|
||||
"""
|
||||
Method used to calculate faithfulness for given inputs, target label,
|
||||
saliency. Derived classes should implement this method.
|
||||
|
||||
Args:
|
||||
inputs (_Array): sample to calculate faithfulness score
|
||||
model (_Module): model to explanation
|
||||
targets (_Label): label to explanation on.
|
||||
saliency (_Array): Saliency map of given inputs and targets from the
|
||||
explainer.
|
||||
|
||||
Return:
|
||||
- faithfulness (float): faithfulness score
|
||||
"""
|
||||
|
||||
|
||||
class NaiveFaithfulness(_FaithfulnessHelper):
|
||||
"""
|
||||
Calculator for naive faithfulness.
|
||||
|
||||
For naive faithfulness, the metric replaces several pixels of the original image with a
specified method for each perturbation. The metric predicts on the perturbed
images and records a series of probabilities, then calculates the
correlation between the probability distribution and the averaged feature importance.
|
||||
Higher correlation indicates better faithfulness.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_method (str): specify the method to replace the pixel.
|
||||
Current support: ['Constant', 'GaussianBlur']
|
||||
is_accumulate (bool): whether to accumulate the former perturbations to
|
||||
the later perturbations.
|
||||
Default: False.
|
||||
perturb_pixel_per_step (Optional[int]): number of pixel to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculated by:
num_image_pixel * perturb_percent / num_perturb_steps.
Default: None
num_perturbations (Optional[int]): number of perturbations. If
num_perturbations is None, it will be calculated by:
|
||||
num_image_pixel * perturb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
kwargs: specific perturb_method will require
|
||||
different arguments. Below lists required args for each method.
|
||||
|
||||
'Constant': base_value (int)
|
||||
'GaussianBlur': sigma (float): 0.7
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_method: str,
|
||||
is_accumulate: bool = False,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
super(NaiveFaithfulness, self).__init__(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_mode='Deletion',
|
||||
perturb_method=perturb_method,
|
||||
is_accumulate=is_accumulate,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
**kwargs)
|
||||
|
||||
def calc_faithfulness(self,
|
||||
inputs: _Array,
|
||||
model: _Module,
|
||||
targets: _Label,
|
||||
saliency: _Array) -> np.ndarray:
|
||||
"""
|
||||
Calculate naive faithfulness.
|
||||
|
||||
Args:
|
||||
inputs (_Array): sample to calculate faithfulness score
|
||||
model (_Module): model to explanation
|
||||
targets (_Label): label to explanation on.
|
||||
saliency (_Array): Saliency map of given inputs and targets from the
|
||||
explainer.
|
||||
|
||||
Return:
|
||||
- faithfulness (np.ndarray): faithfulness score
|
||||
|
||||
"""
|
||||
reference = self._get_reference(inputs)
|
||||
perturbations, masks = self._perturb(
|
||||
inputs, saliency, reference, return_mask=True)
|
||||
feature_importance = _calc_feature_importance(saliency, masks)
|
||||
|
||||
perturbations = ms.Tensor(perturbations, dtype=ms.float32)
|
||||
predictions = model(perturbations).asnumpy()[:, targets]
|
||||
faithfulness = calc_correlation(feature_importance, predictions)
|
||||
normalized_faithfulness = (faithfulness + 1) / 2
|
||||
return np.array([normalized_faithfulness], np.float)
|
||||
|
||||
|
||||
class DeletionAUC(_FaithfulnessHelper):
|
||||
""" Calculator for deletion AUC.
|
||||
|
||||
For Deletion AUC, the metric accumulative replace pixels on origin
|
||||
images through specific 'perturb_method', predict on the perturbed images
|
||||
and records a series of probabilities. The metric then calculates the AUC of
|
||||
the probability variation curve during perturbations. Faithfulness is defined
|
||||
as (1 - deletion_AUC). A higher score indicates better faithfulness of the
|
||||
explanation.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_method (str): specify the method to replace the pixel.
|
||||
Currently supported: ['Constant', 'GaussianBlur'].
|
||||
perturb_pixel_per_step (Optional[int]): number of pixels to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculated by:
|
||||
num_image_pixel * perturb_percent / num_perturb_steps.
|
||||
Default: None
|
||||
num_perturbations (Optional[int]): number of perturbations. If
|
||||
num_perturbations is None, it will be calculated by:
|
||||
num_image_pixel * perturb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
kwargs: specific perturb_method will require
|
||||
different arguments. The required arguments for each method are listed below.
|
||||
|
||||
'Constant': base_value (int)
|
||||
'GaussianBlur': sigma (float): 0.7
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_method: str,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
super(DeletionAUC, self).__init__(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_mode='Deletion',
|
||||
perturb_method=perturb_method,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=True,
|
||||
**kwargs)
|
||||
|
||||
def calc_faithfulness(self,
|
||||
inputs: _Array,
|
||||
model: _Module,
|
||||
targets: _Label,
|
||||
saliency: _Array) -> np.ndarray:
|
||||
"""
|
||||
Calculate faithfulness through deletion AUC.
|
||||
|
||||
Args:
|
||||
inputs (_Array): sample to calculate faithfulness score
|
||||
model (_Module): the model to explain.
|
||||
targets (_Label): the label to explain on.
|
||||
saliency (_Array): Saliency map of given inputs and targets from the
|
||||
explainer.
|
||||
|
||||
Return:
|
||||
- faithfulness (np.ndarray): faithfulness score
|
||||
|
||||
"""
|
||||
reference = self._get_reference(inputs)
|
||||
perturbations = self._perturb(inputs, saliency, reference)
|
||||
perturbations = ms.Tensor(perturbations, dtype=ms.float32)
|
||||
predictions = model(perturbations).asnumpy()[:, targets]
|
||||
input_tensor = op.ExpandDims()(ms.Tensor(inputs, ms.float32), 0)
|
||||
original_output = model(input_tensor).asnumpy()[:, targets]
|
||||
|
||||
auc = calc_auc(original_output - predictions)
|
||||
return np.array([1 - auc])
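# --- Illustrative sketch (not part of the module): the AUC step mirrored in NumPy, assuming
# calc_auc behaves like a trapezoidal integration over the perturbation steps scaled to [0, 1].
import numpy as np

def _deletion_auc_sketch(original_prob, perturbed_probs):
    drops = original_prob - perturbed_probs         # original_output - predictions, as above
    auc = np.trapz(drops, dx=1.0 / len(drops))      # area under the drop curve
    return np.array([1 - auc])                      # same (1 - AUC) convention as DeletionAUC

print(_deletion_auc_sketch(0.9, np.array([0.6, 0.4, 0.2, 0.1])))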
|
||||
|
||||
|
||||
class InsertionAUC(_FaithfulnessHelper):
|
||||
""" Calculator for insertion AUC.
|
||||
|
||||
For Insertion AUC, the metric accumulatively replaces pixels of a reference
|
||||
image with pixels from the original image, like inserting pixels from the original image into the
|
||||
reference. The reference is generated through the specified 'perturb_method'.
|
||||
The metric predicts on the perturbed images and records a series of
|
||||
probabilities. The metric then calculates the AUC of the probability
|
||||
variation curve during perturbations. Faithfulness is defined as the
|
||||
insertion AUC. A higher score indicates better faithfulness of the explanation.
|
||||
|
||||
Args:
|
||||
perturb_percent (float): percentage of pixels to perturb
|
||||
perturb_method (str): specify the method to replace the pixel.
|
||||
Currently supported: ['Constant', 'GaussianBlur'].
|
||||
perturb_pixel_per_step (Optional[int]): number of pixels to perturb
|
||||
for each perturbation. If perturb_pixel_per_step is None, actual
|
||||
perturb_pixel_per_step will be calculated by:
|
||||
num_image_pixel * perturb_percent / num_perturb_steps.
|
||||
Default: None
|
||||
num_perturbations (Optional[int]): number of perturbations. If
|
||||
num_perturbations is None, it will be calculated by:
|
||||
num_image_pixel * perturb_percent / perturb_pixel_per_step.
|
||||
Default: None
|
||||
kwargs: specific perturb_method will require
|
||||
different arguments. The required arguments for each method are listed below.
|
||||
|
||||
'Constant': base_value (int)
|
||||
'GaussianBlur': sigma (float): 0.7
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
perturb_percent: float,
|
||||
perturb_method: str,
|
||||
perturb_pixel_per_step: Optional[int] = None,
|
||||
num_perturbations: Optional[int] = None,
|
||||
**kwargs):
|
||||
super(InsertionAUC, self).__init__(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_mode='Insertion',
|
||||
perturb_method=perturb_method,
|
||||
perturb_pixel_per_step=perturb_pixel_per_step,
|
||||
num_perturbations=num_perturbations,
|
||||
is_accumulate=True,
|
||||
**kwargs)
|
||||
|
||||
def calc_faithfulness(self,
|
||||
inputs: _Array,
|
||||
model: _Module,
|
||||
targets: _Label,
|
||||
saliency: _Array) -> np.ndarray:
|
||||
"""
|
||||
Calculate faithfulness through insertion AUC.
|
||||
|
||||
Args:
|
||||
inputs (_Array): sample to calculate faithfulness score
|
||||
model (_Module): the model to explain.
|
||||
targets (_Label): the label to explain on.
|
||||
saliency (_Array): Saliency map of given inputs and targets from the
|
||||
explainer.
|
||||
|
||||
Return:
|
||||
- faithfulness (np.ndarray): faithfulness score
|
||||
|
||||
"""
|
||||
reference = self._get_reference(inputs)
|
||||
perturbations = self._perturb(inputs, saliency, reference)
|
||||
perturbations = ms.Tensor(perturbations, dtype=ms.float32)
|
||||
predictions = model(perturbations).asnumpy()[:, targets]
|
||||
base_tensor = op.ExpandDims()(ms.Tensor(reference, ms.float32), 0)
|
||||
base_outputs = model(base_tensor).asnumpy()[:, targets]
|
||||
|
||||
auc = calc_auc(predictions - base_outputs)
|
||||
return np.array([auc])
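# --- Illustrative sketch (not part of the module): insertion AUC mirrored in NumPy, assuming the
# same trapezoidal calc_auc behaviour as in the deletion sketch above.
import numpy as np

def _insertion_auc_sketch(base_prob, perturbed_probs):
    gains = perturbed_probs - base_prob             # predictions - base_outputs, as above
    return np.array([np.trapz(gains, dx=1.0 / len(gains))])

print(_insertion_auc_sketch(0.1, np.array([0.2, 0.4, 0.7, 0.9])))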
|
||||
|
||||
|
||||
class Faithfulness(AttributionMetric):
|
||||
"""
|
||||
Provides evaluation on faithfulness on XAI explanations.
|
||||
|
||||
Faithfulness first generates the saliency map with the given explainer and then calculates faithfulness based on the chosen
|
||||
faithfulness metric.
|
||||
|
||||
Args:
|
||||
num_labels (int): number of labels
|
||||
metric (str): the specific metric to quantify faithfulness.
|
||||
Options: 'DeletionAUC', 'InsertionAUC', 'NaiveFaithfulness'.
|
||||
Default: 'NaiveFaithfulness'.
|
||||
|
||||
Examples:
|
||||
>>> # init a `Faithfulness` object
|
||||
>>> num_labels = 10
|
||||
>>> metric = "InsertionAUC"
|
||||
>>> faithfulness = Faithfulness(num_labels, metric)
|
||||
"""
|
||||
_methods = [NaiveFaithfulness, DeletionAUC, InsertionAUC]
|
||||
|
||||
def __init__(self, num_labels: int, metric: str = "NaiveFaithfulness"):
|
||||
super(Faithfulness, self).__init__(num_labels)
|
||||
|
||||
perturb_percent = 0.5 # ratio of pixels to be perturbed, future argument
|
||||
perturb_method = "Constant" # perturbation method, all the perturbed pixels will be set to constant
|
||||
num_perturb_pixel_per_step = None # number of pixels for each perturbation step
|
||||
num_perturb_steps = 100 # separate the perturbation progress into 100 steps.
|
||||
base_value = 0.0 # the pixel value set for the perturbed pixels
|
||||
|
||||
self._verify_metrics(metric)
|
||||
for method in self._methods:
|
||||
if metric == method.__name__:
|
||||
self._faithfulness_helper = method(
|
||||
perturb_percent=perturb_percent,
|
||||
perturb_method=perturb_method,
|
||||
perturb_pixel_per_step=num_perturb_pixel_per_step,
|
||||
num_perturbations=num_perturb_steps,
|
||||
base_value=base_value
|
||||
)
|
||||
|
||||
def evaluate(self, explainer, inputs, targets, saliency=None):
|
||||
"""
|
||||
Evaluate faithfulness on a single data sample.
|
||||
|
||||
Args:
|
||||
explainer (Explainer): An explainer instance.
|
||||
For the 'Explainer' object, see mindspore/explainer/explanation.
|
||||
inputs (Tensor): data sample. Currently only a single sample per call is supported.
|
||||
targets (Union[int, Tensor]): A target label to evaluate on.
|
||||
saliency (Tensor): A saliency tensor.
|
||||
|
||||
Return:
|
||||
np.ndarray: result of faithfulness evaluated on explainer.
|
||||
|
||||
Notes:
|
||||
To apply `Faithfulness` to evaluate an explainer, the explainer must be initialized with a network that
|
||||
contains the output activation function. Otherwise, the results will not be correct.
|
||||
|
||||
Examples:
|
||||
>>> # init an explainer, the network should contain the output activation function.
|
||||
>>> network = nn.SequentialCell([resnet50, nn.Sigmoid()])
|
||||
>>> gradient = Gradient(network)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> targets = 5
|
||||
>>> # usage 1: input the explainer and the data to be explained,
|
||||
>>> # calculate the faithfulness with the specified metric
|
||||
>>> res = faithfulness.evaluate(gradient, inputs, targets)
|
||||
>>> # usage 2: input the generated saliency map
|
||||
>>> saliency = gradient(inputs, targets)
|
||||
>>> res = faithfulness.evaluate(gradient, inputs, targets, saliency)
|
||||
"""
|
||||
|
||||
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||
|
||||
if saliency is None:
|
||||
saliency = explainer(inputs, targets)
|
||||
|
||||
inputs = format_tensor_to_ndarray(inputs)
|
||||
saliency = format_tensor_to_ndarray(saliency)
|
||||
|
||||
inputs = inputs.squeeze(axis=0)
|
||||
saliency = saliency.squeeze()
|
||||
if len(saliency.shape) != 2:
|
||||
raise ValueError('Squeezed saliency map is expected to be 2D, but got {}D.'.format(len(saliency.shape)))
|
||||
|
||||
faithfulness = self._faithfulness_helper.calc_faithfulness(inputs=inputs, model=explainer.model,
|
||||
targets=targets, saliency=saliency)
|
||||
return faithfulness
|
||||
|
||||
def _verify_metrics(self, metric: str):
|
||||
supports = [x.__name__ for x in self._methods]
|
||||
if metric not in supports:
|
||||
raise ValueError("Metric should be one of {}.".format(supports))
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Localization metrics."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.train._utils import check_value_type
|
||||
from .metric import AttributionMetric
|
||||
from ..._operators import maximum, reshape, Tensor
|
||||
from ..._utils import format_tensor_to_ndarray
|
||||
|
||||
|
||||
def _get_max_position(saliency):
|
||||
"""Get the position of the max pixel of the saliency map."""
|
||||
saliency = saliency.asnumpy()
|
||||
w = saliency.shape[3]
|
||||
saliency = np.reshape(saliency, (len(saliency), -1))
|
||||
max_arg = np.argmax(saliency, axis=1)
|
||||
return max_arg // w, max_arg - (max_arg // w) * w
|
||||
|
||||
|
||||
def _mask_out_saliency(saliency, threshold):
|
||||
"""Keep the saliency map with value greater than threshold."""
|
||||
max_value = maximum(saliency)
|
||||
mask_out = saliency > (reshape(max_value, (len(saliency), -1, 1, 1)) * threshold)
|
||||
return mask_out
|
||||
|
||||
|
||||
class Localization(AttributionMetric):
|
||||
"""
|
||||
Provides evaluation on the localization capability of XAI methods.
|
||||
|
||||
We support two metrics for the evaluation of localization capability: "PointingGame" and "IoSR".
|
||||
For metric "PointingGame", the localization capability is calculated as the ratio of data in which the max position
|
||||
of their saliency maps lies within the bounding boxes. Specifically, for a single datum, given the saliency map and
|
||||
its bounding box, if the max point of its saliency map lies within the bounding box, the evaluation result is 1
|
||||
otherwise 0.
|
||||
|
||||
For metric "IoSR" (Intersection over Salient Region), the localization capability is calculated as the intersection
|
||||
of the bounding box and the salient region over the area of the salient region.
|
||||
|
||||
Args:
|
||||
num_labels (int): number of classes in the dataset.
|
||||
metric (str): specific metric to calculate localization capability.
|
||||
Options: "PointingGame", "IoSR".
|
||||
Default: "PointingGame".
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer.benchmark import Localization
|
||||
>>> num_labels = 100
|
||||
>>> localization = Localization(num_labels, "PointingGame")
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_labels,
|
||||
metric="PointingGame"
|
||||
):
|
||||
super(Localization, self).__init__(num_labels)
|
||||
self._verify_metrics(metric)
|
||||
self._metric = metric
|
||||
|
||||
# Arg for specific metric, for "PointingGame" it should be an integer indicating the tolerance
|
||||
# of "PointingGame", while for "IoSR" it should be a float number
|
||||
# indicating the threshold to choose the salient region. Defaults: 15 for "PointingGame" and 0.5 for "IoSR".
|
||||
if self._metric == "PointingGame":
|
||||
self._metric_arg = 15
|
||||
else:
|
||||
self._metric_arg = 0.5
|
||||
|
||||
@staticmethod
|
||||
def _verify_metrics(metric):
|
||||
"""Verify the user defined metric."""
|
||||
supports = ["PointingGame", "IoSR"]
|
||||
if metric not in supports:
|
||||
raise ValueError("Metric should be one of {}".format(supports))
|
||||
|
||||
def evaluate(self, explainer, inputs, targets, saliency=None, mask=None):
|
||||
"""
|
||||
Evaluate localization on a single data sample.
|
||||
|
||||
Args:
|
||||
explainer (Explanation): The explainer to be evaluated, see `mindspore/explainer/explanation`.
|
||||
inputs (Tensor): data sample. Currently only a single sample per call is supported.
|
||||
targets (int): target label to evaluate on.
|
||||
saliency (Tensor): A saliency tensor.
|
||||
mask (Union[Tensor, np.ndarray]): ground truth bounding box/masks for the inputs w.r.t targets.
|
||||
|
||||
Returns:
|
||||
np.ndarray, result of localization evaluated on explainer
|
||||
|
||||
Examples:
|
||||
>>> # init an explainer, the network should contain the output activation function.
|
||||
>>> gradient = Gradient(network)
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> masks = np.zeros((1, 1, 224, 224))
|
||||
>>> masks[:, :, 65: 100, 65: 100] = 1
|
||||
>>> targets = 5
|
||||
>>> # usage 1: input the explainer and the data to be explained,
|
||||
>>> # calculate the localization capability with the specified metric
|
||||
>>> res = localization.evaluate(gradient, inputs, targets, mask=masks)
|
||||
>>> # usage 2: input the generated saliency map
|
||||
>>> saliency = gradient(inputs, targets)
|
||||
>>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks)
|
||||
"""
|
||||
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||
|
||||
mask_np = format_tensor_to_ndarray(mask)[0]
|
||||
|
||||
if saliency is None:
|
||||
saliency = explainer(inputs, targets)
|
||||
|
||||
if self._metric == "PointingGame":
|
||||
point = _get_max_position(saliency)
|
||||
|
||||
x, y = np.meshgrid(
|
||||
(np.arange(mask_np.shape[1]) - point[0]) ** 2,
|
||||
(np.arange(mask_np.shape[2]) - point[1]) ** 2)
|
||||
max_region = (x + y) < self._metric_arg ** 2
|
||||
|
||||
# if max_region has overlap with mask_np return 1 otherwise 0.
|
||||
result = 1 if (mask_np.astype(bool) & max_region).any() else 0
|
||||
|
||||
elif self._metric == "IoSR":
|
||||
mask_out = _mask_out_saliency(saliency, self._metric_arg)
|
||||
mask_out_np = format_tensor_to_ndarray(mask_out)
|
||||
overlap = np.sum(mask_np.astype(bool) & mask_out_np.astype(bool))
|
||||
saliency_area = np.sum(mask_out_np)
|
||||
result = overlap / saliency_area.clip(min=1e-10)
|
||||
return np.array([result], np.float)
|
||||
|
||||
def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask):
|
||||
self._check_evaluate_param(explainer, inputs, targets, saliency)
|
||||
check_value_type('mask', mask, (Tensor, np.ndarray))
|
||||
if len(mask.shape) != 4:
|
||||
raise ValueError('Argument mask must be a 4D Tensor.')
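# --- Illustrative sketch (not part of the module): both localization branches in plain NumPy,
# using the default tolerance (15) and threshold (0.5) with made-up saliency/mask arrays.
import numpy as np

def _pointing_game_sketch(saliency, mask, tolerance=15):
    # hit if the saliency maximum falls within `tolerance` pixels of the ground-truth region
    row, col = np.unravel_index(np.argmax(saliency), saliency.shape)
    rows, cols = np.ogrid[:mask.shape[0], :mask.shape[1]]
    hit_region = (rows - row) ** 2 + (cols - col) ** 2 < tolerance ** 2
    return 1 if (mask.astype(bool) & hit_region).any() else 0

def _iosr_sketch(saliency, mask, threshold=0.5):
    # intersection of the thresholded salient region with the mask, over the salient region area
    salient_region = saliency > saliency.max() * threshold
    overlap = np.sum(mask.astype(bool) & salient_region)
    return overlap / max(salient_region.sum(), 1e-10)

_saliency = np.random.rand(224, 224)
_mask = np.zeros((224, 224))
_mask[65:100, 65:100] = 1
print(_pointing_game_sketch(_saliency, _mask), _iosr_sketch(_saliency, _mask))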
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Base class for XAI metrics."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.train._utils import check_value_type
|
||||
from ..._operators import Tensor
|
||||
from ..._utils import format_tensor_to_ndarray
|
||||
from ...explanation._attribution._attribution import Attribution
|
||||
|
||||
|
||||
def verify_argument(inputs, arg_name):
|
||||
"""Verify the validity of the parsed arguments."""
|
||||
check_value_type(arg_name, inputs, Tensor)
|
||||
if len(inputs.shape) != 4:
|
||||
raise ValueError('Argument {} must be a 4D Tensor.'.format(arg_name))
|
||||
if len(inputs) > 1:
|
||||
raise ValueError('Support single data evaluation only, but got {}.'.format(len(inputs)))
|
||||
|
||||
|
||||
def verify_targets(targets, num_labels):
|
||||
"""Verify the validity of the parsed targets."""
|
||||
check_value_type('targets', targets, (int, Tensor))
|
||||
|
||||
if isinstance(targets, Tensor):
|
||||
if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != 1):
|
||||
raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
|
||||
'it should have the length = 1 as we only support single evaluation now.')
|
||||
targets = int(targets.asnumpy()[0]) if len(targets.shape) == 1 else int(targets.asnumpy())
|
||||
if targets > num_labels - 1 or targets < 0:
|
||||
raise ValueError('Parsed targets exceed the label range.')
|
||||
|
||||
|
||||
class AttributionMetric:
|
||||
"""Super class of XAI metric class used in classification scenarios."""
|
||||
|
||||
def __init__(self, num_labels=None):
|
||||
self._num_labels = num_labels
|
||||
self._global_results = {i: [] for i in range(num_labels)}
|
||||
|
||||
def evaluate(self, explainer, inputs, targets, saliency=None):
|
||||
"""This function evaluates on a single sample and return the result."""
|
||||
raise NotImplementedError
|
||||
|
||||
def aggregate(self, result, targets):
|
||||
"""Aggregates single result to global_results."""
|
||||
if isinstance(result, float):
|
||||
if isinstance(targets, int):
|
||||
self._global_results[targets].append(result)
|
||||
else:
|
||||
target_np = format_tensor_to_ndarray(targets)
|
||||
if len(target_np) > 1:
|
||||
raise ValueError("One result can not be aggreated to multiple targets.")
|
||||
else:
|
||||
result_np = format_tensor_to_ndarray(result)
|
||||
if isinstance(targets, int):
|
||||
for res in result_np:
|
||||
self._global_results[targets].append(float(res))
|
||||
else:
|
||||
target_np = format_tensor_to_ndarray(targets)
|
||||
if len(target_np) != len(result_np):
|
||||
raise ValueError("Length of result does not match with length of targets.")
|
||||
for tar, res in zip(target_np, result_np):
|
||||
self._global_results[int(tar)].append(float(res))
|
||||
|
||||
def reset(self):
|
||||
"""Resets global_result."""
|
||||
self._global_results = {i: [] for i in range(self._num_labels)}
|
||||
|
||||
@property
|
||||
def class_performances(self):
|
||||
"""
|
||||
Get the class performances by global result.
|
||||
|
||||
|
||||
Returns:
|
||||
(:class:`np.ndarray`): :attr:`num_labels`-dimensional vector
|
||||
containing per-class performance.
|
||||
"""
|
||||
count = np.array(
|
||||
[len(self._global_results[i]) for i in range(self._num_labels)])
|
||||
result_sum = np.array(
|
||||
[sum(self._global_results[i]) for i in range(self._num_labels)])
|
||||
return result_sum / count.clip(min=1)
|
||||
|
||||
@property
|
||||
def performance(self):
|
||||
"""
|
||||
Get the performance by global result.
|
||||
|
||||
Returns:
|
||||
(:class:`float`): mean performance.
|
||||
"""
|
||||
count = sum(
|
||||
[len(self._global_results[i]) for i in range(self._num_labels)])
|
||||
result_sum = sum(
|
||||
[sum(self._global_results[i]) for i in range(self._num_labels)])
|
||||
if count == 0:
|
||||
return 0
|
||||
return result_sum / count
|
||||
|
||||
def get_results(self):
|
||||
"""Global result of the metric can be return"""
|
||||
return self._global_results
|
||||
|
||||
def _check_evaluate_param(self, explainer, inputs, targets, saliency):
|
||||
"""Check the evaluate parameters."""
|
||||
check_value_type('explainer', explainer, Attribution)
|
||||
verify_argument(inputs, 'inputs')
|
||||
verify_targets(targets, self._num_labels)
|
||||
check_value_type('saliency', saliency, (Tensor, type(None)))
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Predefined Attribution explainers."""
|
||||
|
||||
from ._attribution._backprop.gradcam import GradCAM
|
||||
from ._attribution._backprop.gradient import Gradient
|
||||
from ._attribution._backprop.modified_relu import Deconvolution, GuidedBackprop
|
||||
|
||||
__all__ = [
|
||||
'Gradient',
|
||||
'Deconvolution',
|
||||
'GuidedBackprop',
|
||||
'GradCAM',
|
||||
]
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Predefined Attribution explainers."""
|
||||
from ._backprop.gradcam import GradCAM
|
||||
from ._backprop.gradient import Gradient
|
||||
from ._backprop.modified_relu import Deconvolution, GuidedBackprop
|
||||
|
||||
__all__ = [
|
||||
'Gradient',
|
||||
'Deconvolution',
|
||||
'GuidedBackprop',
|
||||
'GradCAM',
|
||||
]
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Attribution."""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import mindspore as ms
|
||||
|
||||
class Attribution:
|
||||
r"""
|
||||
Basic class for attributing saliency scores.
|
||||
|
||||
Explainers that generate explanations by attributing relevance scores
|
||||
should inherit this class.
|
||||
|
||||
Args:
|
||||
network (ms.nn.Cell): The black-box model to be explained.
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
self._verify_model(network)
|
||||
self._model = network
|
||||
|
||||
@staticmethod
|
||||
def _verify_model(model):
|
||||
"""
|
||||
Verify the input `network` for __init__ function.
|
||||
"""
|
||||
if not isinstance(model, ms.nn.Cell):
|
||||
raise TypeError("The parsed `network` must be a `mindspore.nn.Cell` object.")
|
||||
|
||||
|
||||
__call__: Callable
|
||||
"""
|
||||
Explainers return explanations when called directly on the inputs.
|
||||
Derived classes should override this implementation for different
|
||||
algorithms.
|
||||
|
||||
Args:
|
||||
input (ms.Tensor): Input tensor to be explained.
|
||||
|
||||
Returns:
|
||||
- saliency map (ms.Tensor): saliency map of the input.
|
||||
"""
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._model
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Backprop-base _attribution explainer."""
|
||||
|
||||
from .gradient import Gradient
|
||||
from .gradcam import GradCAM
|
||||
from .modified_relu import Deconvolution, GuidedBackprop
|
||||
|
||||
__all__ = ['Gradient',
|
||||
'GradCAM',
|
||||
'Deconvolution',
|
||||
'GuidedBackprop']
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing utility functions."""
|
||||
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
from ...._utils import unify_inputs, unify_targets, generate_one_hot
|
||||
|
||||
|
||||
def compute_gradients(model, inputs, targets=None, weights=None):
|
||||
r"""
|
||||
Compute the gradient of output w.r.t input.
|
||||
|
||||
Args:
|
||||
model (`ms.nn.Cell`): Differentiable black-box model.
|
||||
inputs (`ms.Tensor`): Input to calculate gradient and explanation.
|
||||
targets (int, optional): Target label id specifying which category to compute gradient. Default: None.
|
||||
weights (`ms.Tensor`, optional): Custom weights for computing gradients. The shape of weights should match the
|
||||
model outputs. If None is provided, a one-hot weight tensor with ones at the target positions will be used instead.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
saliency map (ms.Tensor): Gradient back-propagated to the input.
|
||||
"""
|
||||
inputs = unify_inputs(inputs)
|
||||
if targets is None and weights is None:
|
||||
raise ValueError('Must provide one of targets or weights')
|
||||
if weights is None:
|
||||
targets = unify_targets(targets)
|
||||
output = model(*inputs).asnumpy()
|
||||
num_categories = output.shape[-1]
|
||||
weights = generate_one_hot(targets, num_categories)
|
||||
|
||||
grad_op = GradOperation(
|
||||
get_all=True, get_by_list=False, sens_param=True)(model)
|
||||
gradients = grad_op(*inputs, weights)
|
||||
return gradients[0]
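# --- Illustrative usage sketch (not part of the module): exercising compute_gradients with a toy
# classifier. _TinyNet and the random input below are placeholders for illustration only.
import numpy as np
import mindspore as ms
from mindspore import nn

class _TinyNet(nn.Cell):
    def __init__(self):
        super(_TinyNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(3 * 8 * 8, 5)

    def construct(self, x):
        return self.fc(self.flatten(x))

def _compute_gradients_sketch():
    net = _TinyNet()
    inputs = ms.Tensor(np.random.rand(1, 3, 8, 8), ms.float32)
    grad = compute_gradients(net, inputs, targets=2)   # gradient of the class-2 logit w.r.t. the input
    print(grad.shape)                                   # same shape as the input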
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" GradCAM and GuidedGradCAM. """
|
||||
|
||||
from mindspore.ops import operations as op
|
||||
|
||||
from .backprop_utils import compute_gradients
|
||||
from .intermediate_layer import IntermediateLayerAttribution
|
||||
from ...._utils import ForwardProbe, retrieve_layer, unify_inputs, unify_targets
|
||||
|
||||
|
||||
|
||||
def _gradcam_aggregation(attributions):
|
||||
"""
|
||||
Aggregate the gradient and activation to get the final _attribution.
|
||||
|
||||
Args:
|
||||
attributions (Tensor): the _attribution with channel dimension.
|
||||
|
||||
Returns:
|
||||
Tensor: the _attribution with channel dimension aggregated.
|
||||
"""
|
||||
sum_ = op.ReduceSum(keep_dims=True)
|
||||
relu_ = op.ReLU()
|
||||
attributions = relu_(sum_(attributions, 1))
|
||||
return attributions
|
||||
|
||||
|
||||
class GradCAM(IntermediateLayerAttribution):
|
||||
r"""
|
||||
Provides GradCAM explanation method.
|
||||
|
||||
GradCAM generates saliency maps at an intermediate layer.
|
||||
.. math::
|
||||
\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial{y^c}}{\partial{A_{i,j}^k}}
|
||||
L_{GradCAM} = ReLU(\sum_k \alpha_k^c A^k)
|
||||
For more details, please refer to the original paper: GradCAM
|
||||
[https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf]
|
||||
|
||||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
layer (str): The layer name to generate the explanation at. Default: ''.
|
||||
If the default value is used, the explanation will be generated at the input layer.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||
>>> # you may also use the net itself.
|
||||
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> # specify a layer name to generate explanation, usually the layer can be set as the last conv layer.
|
||||
>>> layer_name = '0.layer4'
|
||||
>>> # init GradCAM with a trained network and specify the layer to obtain
|
||||
>>> gradcam = GradCAM(net, layer=layer_name)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gradcam(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
network,
|
||||
layer=""):
|
||||
super(GradCAM, self).__init__(network, layer)
|
||||
|
||||
self._saliency_cell = retrieve_layer(self._backward_model, target_layer=layer)
|
||||
self._avgpool = op.ReduceMean(keep_dims=True)
|
||||
self._intermediate_grad = None
|
||||
self._aggregation_fn = _gradcam_aggregation
|
||||
self._resize_mode = 'bilinear'
|
||||
|
||||
def _hook_cell(self):
|
||||
if self._saliency_cell:
|
||||
self._saliency_cell.register_backward_hook(self._cell_hook_fn)
|
||||
self._saliency_cell.enable_hook = True
|
||||
self._intermediate_grad = None
|
||||
|
||||
def _cell_hook_fn(self, _, grad_input, grad_output):
|
||||
"""
|
||||
Hook function to deal with the backward gradient.
|
||||
|
||||
The arguments are set as required by Cell.register_backward_hook.
|
||||
"""
|
||||
self._intermediate_grad = grad_input
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
"""
|
||||
Call function for `GradCAM`.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input data to be explained, 4D Tensor.
|
||||
targets (Union[Tensor, int]): The label of interest. It should be a 1D or 0D Tensor, or an integer.
|
||||
If `targets` is a 1D Tensor, its length should be the same as `inputs`.
|
||||
"""
|
||||
self._verify_data(inputs, targets)
|
||||
self._hook_cell()
|
||||
|
||||
with ForwardProbe(self._saliency_cell) as probe:
|
||||
|
||||
inputs = unify_inputs(inputs)
|
||||
targets = unify_targets(targets)
|
||||
|
||||
gradients = compute_gradients(self._backward_model, *inputs, targets)
|
||||
|
||||
# get intermediate activation
|
||||
activation = (probe.value,)
|
||||
|
||||
if self._layer == "":
|
||||
activation = inputs
|
||||
self._intermediate_grad = unify_inputs(gradients)
|
||||
if self._intermediate_grad is not None:
|
||||
# average pooling on gradients
|
||||
intermediate_grad = unify_inputs(
|
||||
self._avgpool(self._intermediate_grad[0], (2, 3)))
|
||||
else:
|
||||
raise ValueError("Gradient for intermediate layer is not "
|
||||
"obtained")
|
||||
mul = op.Mul()
|
||||
attribution = self._aggregation_fn(
|
||||
mul(*intermediate_grad, *activation))
|
||||
if self._resize:
|
||||
attribution = self._resize_fn(attribution, *inputs,
|
||||
mode=self._resize_mode)
|
||||
self._intermediate_grad = None
|
||||
|
||||
return attribution
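# --- Illustrative sketch (not part of the module): the GradCAM formula from the class docstring
# written in plain NumPy on made-up activations and gradients.
import numpy as np

_activation = np.random.rand(1, 64, 7, 7)                 # intermediate feature maps A^k
_gradients = np.random.rand(1, 64, 7, 7)                  # d y^c / d A^k
_alpha = _gradients.mean(axis=(2, 3), keepdims=True)      # channel weights alpha_k^c (global average pooling)
_cam = np.maximum((_alpha * _activation).sum(axis=1), 0)  # ReLU(sum_k alpha_k^c A^k)
print(_cam.shape)                                          # (1, 7, 7); resized to the input size afterwards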
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Gradient explainer."""
|
||||
from copy import deepcopy
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as op
|
||||
from mindspore.train._utils import check_value_type
|
||||
from ...._operators import reshape, sqrt, Tensor
|
||||
from .._attribution import Attribution
|
||||
from .backprop_utils import compute_gradients
|
||||
from ...._utils import unify_inputs, unify_targets
|
||||
|
||||
|
||||
def _get_hook(bntype, cache):
|
||||
"""Provide backward hook function for BatchNorm layer in eval mode."""
|
||||
var, gamma, eps = cache
|
||||
if bntype == "2d":
|
||||
var = reshape(var, (1, -1, 1, 1))
|
||||
gamma = reshape(gamma, (1, -1, 1, 1))
|
||||
elif bntype == "1d":
|
||||
var = reshape(var, (1, -1, 1))
|
||||
gamma = reshape(gamma, (1, -1, 1))
|
||||
|
||||
def reset_gradient(_, grad_input, grad_output):
|
||||
grad_output = grad_input[0] * gamma / sqrt(var + eps)
|
||||
return grad_output
|
||||
|
||||
return reset_gradient
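# --- Illustrative check (not part of the module): in eval mode BatchNorm computes
# y = gamma * (x - mean) / sqrt(var + eps) + beta, so dy/dx is the constant gamma / sqrt(var + eps),
# which is exactly the factor the hook above re-applies to the incoming gradient.
import numpy as np

_gamma, _var, _eps, _mean, _beta = 1.5, 0.25, 1e-5, 0.2, 0.0
_x = np.array([0.1, 0.2, 0.3])
_dx = 1e-6
_y0 = _gamma * (_x - _mean) / np.sqrt(_var + _eps) + _beta
_y1 = _gamma * (_x + _dx - _mean) / np.sqrt(_var + _eps) + _beta
print((_y1 - _y0) / _dx)                # numerical derivative
print(_gamma / np.sqrt(_var + _eps))    # analytic factor used by the hook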
|
||||
|
||||
|
||||
def _abs_max(gradients):
|
||||
"""
|
||||
Transform gradients to saliency through abs then take max along
|
||||
channels.
|
||||
"""
|
||||
gradients = op.Abs()(gradients)
|
||||
saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1)
|
||||
return saliency
|
||||
|
||||
|
||||
class Gradient(Attribution):
|
||||
r"""
|
||||
Provides Gradient explanation method.
|
||||
|
||||
Gradient is the simplest attribution method which uses the naive gradients of outputs w.r.t inputs as the
|
||||
explanation.
|
||||
|
||||
.. math::
|
||||
attribution = \frac{\partial{y}}{\partial{x}}
|
||||
|
||||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||
>>> # you may also use the net itself. The saliency map might be slightly different for softmax activation.
|
||||
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> # init Gradient with a trained network.
|
||||
>>> gradient = Gradient(net)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gradient(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(Gradient, self).__init__(network)
|
||||
self._backward_model = deepcopy(network)
|
||||
self._backward_model.set_train(False)
|
||||
self._backward_model.set_grad(False)
|
||||
self._hook_bn()
|
||||
self._grad_op = compute_gradients
|
||||
self._aggregation_fn = _abs_max
|
||||
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
"""
|
||||
Call function for `Gradient`.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input data to be explained, 4D Tensor.
|
||||
targets (Union[Tensor, int]): The label of interest. It should be a 1D or 0D Tensor, or an integer.
|
||||
If `targets` is a 1D `Tensor`, its length should be the same as `inputs`.
|
||||
"""
|
||||
self._verify_data(inputs, targets)
|
||||
inputs = unify_inputs(inputs)
|
||||
targets = unify_targets(targets)
|
||||
|
||||
gradient = self._grad_op(self._backward_model, *inputs, targets)
|
||||
saliency = self._aggregation_fn(gradient)
|
||||
return saliency
|
||||
|
||||
def _hook_bn(self):
|
||||
"""Hook BatchNorm layer for `self._backward_model.`"""
|
||||
for _, cell in self._backward_model.cells_and_names():
|
||||
if isinstance(cell, nn.BatchNorm2d):
|
||||
cache = (cell.moving_variance, cell.gamma, cell.eps)
|
||||
cell.register_backward_hook(_get_hook("2d", cache=cache))
|
||||
elif isinstance(cell, nn.BatchNorm1d):
|
||||
cache = (cell.moving_variance, cell.gamma, cell.eps)
|
||||
cell.register_backward_hook(_get_hook("1d", cache=cache))
|
||||
|
||||
@staticmethod
|
||||
def _verify_data(inputs, targets):
|
||||
"""Verify the validity of the parsed inputs."""
|
||||
check_value_type('inputs', inputs, Tensor)
|
||||
if len(inputs.shape) != 4:
|
||||
raise ValueError('Argument inputs must be 4D Tensor')
|
||||
check_value_type('targets', targets, (Tensor, int))
|
||||
if isinstance(targets, Tensor):
|
||||
if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != len(inputs)):
|
||||
raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
|
||||
'it should have the same length as inputs.')
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Base class IntermediateLayerAttribution"""
|
||||
|
||||
from .gradient import Gradient
|
||||
from ...._utils import resize as resize_fn
|
||||
|
||||
|
||||
class IntermediateLayerAttribution(Gradient):
|
||||
"""
|
||||
Base class for generating _attribution map at intermediate layer.
|
||||
|
||||
Args:
|
||||
network (nn.Cell): DNN model to be explained.
|
||||
layer (str, optional): string that specifies the layer to generate
|
||||
intermediate _attribution. When using default value, the input layer
|
||||
will be specified. Default: ''.
|
||||
"""
|
||||
|
||||
def __init__(self, network, layer=''):
|
||||
super(IntermediateLayerAttribution, self).__init__(network)
|
||||
|
||||
# Whether resize the _attribution layer to the input size.
|
||||
self._resize = True
|
||||
# string that specifies the resize mode. Default: 'nearest_neighbor'.
|
||||
self._resize_mode = 'nearest_neighbor'
|
||||
|
||||
self._layer = layer
|
||||
|
||||
@staticmethod
|
||||
def _resize_fn(attributions, inputs, mode):
|
||||
"""Resize the intermediate layer _attribution to the same size as inputs."""
|
||||
height, width = inputs.shape[2], inputs.shape[3]
|
||||
return resize_fn(attributions, (height, width), mode)
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Explainer with modified ReLU."""
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as op
|
||||
|
||||
from .gradient import Gradient
|
||||
from ...._utils import (
|
||||
unify_inputs,
|
||||
unify_targets,
|
||||
)
|
||||
|
||||
|
||||
class ModifiedReLU(Gradient):
|
||||
"""Basic class for modified ReLU explanation."""
|
||||
|
||||
def __init__(self, network, use_relu_backprop=False):
|
||||
super(ModifiedReLU, self).__init__(network)
|
||||
self.use_relu_backprop = use_relu_backprop
|
||||
self.hooked_list = []
|
||||
|
||||
def __call__(self, inputs, targets):
|
||||
self._verify_data(inputs, targets)
|
||||
inputs = unify_inputs(inputs)
|
||||
targets = unify_targets(targets)
|
||||
|
||||
self._hook_relu_backward()
|
||||
gradients = self._grad_op(self._backward_model, inputs, targets)
|
||||
saliency = self._aggregation_fn(gradients)
|
||||
|
||||
return saliency
|
||||
|
||||
def _hook_relu_backward(self):
|
||||
"""Set backward hook for ReLU layers."""
|
||||
for _, cell in self._backward_model.cells_and_names():
|
||||
if isinstance(cell, nn.ReLU):
|
||||
cell.register_backward_hook(self._backward_hook)
|
||||
self.hooked_list.append(cell)
|
||||
|
||||
def _backward_hook(self, _, grad_inputs, grad_outputs):
|
||||
"""Hook function for ReLU layers."""
|
||||
inputs = grad_inputs if self.use_relu_backprop else grad_outputs
|
||||
relu = op.ReLU()
|
||||
if isinstance(inputs, tuple):
|
||||
return relu(*inputs)
|
||||
return relu(inputs)
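# --- Illustrative sketch (not part of the module): the two backward rules selected by
# use_relu_backprop, mirrored on NumPy stand-ins for the hook's grad_inputs/grad_outputs.
import numpy as np

_grad_inputs = np.array([-0.1, 0.4, 0.6, -0.7])
_grad_outputs = np.array([-0.5, 0.3, -0.2, 0.8])

_deconvolution_grad = np.maximum(_grad_inputs, 0)     # use_relu_backprop=True (Deconvolution)
_guided_backprop_grad = np.maximum(_grad_outputs, 0)  # use_relu_backprop=False (GuidedBackprop)
print(_deconvolution_grad, _guided_backprop_grad)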
|
||||
|
||||
|
||||
class Deconvolution(ModifiedReLU):
|
||||
"""
|
||||
Deconvolution explanation.
|
||||
|
||||
To use `Deconvolution`, the `ReLU` operations in the network must be implemented with `mindspore.nn.Cell` object
|
||||
rather than `mindspore.ops.operations.ReLU`. Otherwise, the results will not be correct.
|
||||
|
||||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||
>>> # you may also use the net itself. The saliency map might be slightly different for softmax activation.
|
||||
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> # init Gradient with a trained network.
|
||||
>>> deconvolution = Deconvolution(net)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = deconvolution(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(Deconvolution, self).__init__(network, use_relu_backprop=True)
|
||||
|
||||
|
||||
class GuidedBackprop(ModifiedReLU):
|
||||
"""
|
||||
Guided-Backpropagation explanation.
|
||||
|
||||
To use `GuidedBackprop`, the `ReLU` operations in the network must be implemented with `mindspore.nn.Cell` object
|
||||
rather than `mindspore.ops.operations.ReLU`. Otherwise, the results will not be correct.
|
||||
|
||||
Args:
|
||||
network (Cell): The black-box model to be explained.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet50(10)
|
||||
>>> param_dict = load_checkpoint("resnet50.ckpt")
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> # bind net with its output activation if you wish, e.g. nn.Sigmoid(),
|
||||
>>> # you may also use the net itself. The saliency map might be slightly different for softmax activation.
|
||||
>>> net = nn.SequentialCell([net, nn.Sigmoid()])
|
||||
>>> # init Gradient with a trained network.
|
||||
>>> gbp = GuidedBackprop(net)
|
||||
>>> # parse data and the target label to be explained and get the saliency map
|
||||
>>> inputs = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
|
||||
>>> label = 5
|
||||
>>> saliency = gbp(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super(GuidedBackprop, self).__init__(network, use_relu_backprop=False)
|