forked from mindspore-Ecosystem/mindspore
add skipping of non-finite results when returning performance and fix a bug in small-size ablation
This commit is contained in:
parent
56ce0f4a27
commit
afe360f4e8
@@ -418,29 +418,24 @@ class ImageClassificationRunner:
         inputs, labels, _ = self._unpack_next_element(next_element)
         for idx, inp in enumerate(inputs):
             inp = _EXPAND_DIMS(inp, 0)
-            saliency_dict = saliency_dict_lst[idx]
-            for label, saliency in saliency_dict.items():
-                if isinstance(benchmarker, Localization):
-                    _, _, bboxes = self._unpack_next_element(next_element, True)
-                    if label in labels[idx]:
-                        res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
-                                                   saliency=saliency)
-                        if np.any(res == np.nan):
-                            res = np.zeros_like(res)
-                        benchmarker.aggregate(res, label)
-                elif isinstance(benchmarker, LabelSensitiveMetric):
-                    res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
-                    if np.any(res == np.nan):
-                        res = np.zeros_like(res)
-                    benchmarker.aggregate(res, label)
-                elif isinstance(benchmarker, LabelAgnosticMetric):
-                    res = benchmarker.evaluate(explainer, inp)
-                    if np.any(res == np.nan):
-                        res = np.zeros_like(res)
-                    benchmarker.aggregate(res)
-                else:
-                    raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
-                                    'receive {}'.format(type(benchmarker)))
+            if isinstance(benchmarker, LabelAgnosticMetric):
+                res = benchmarker.evaluate(explainer, inp)
+                benchmarker.aggregate(res)
+            else:
+                saliency_dict = saliency_dict_lst[idx]
+                for label, saliency in saliency_dict.items():
+                    if isinstance(benchmarker, Localization):
+                        _, _, bboxes = self._unpack_next_element(next_element, True)
+                        if label in labels[idx]:
+                            res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
+                                                       saliency=saliency)
+                            benchmarker.aggregate(res, label)
+                    elif isinstance(benchmarker, LabelSensitiveMetric):
+                        res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
+                        benchmarker.aggregate(res, label)
+                    else:
+                        raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
+                                        'receive {}'.format(type(benchmarker)))
 
     def _verify_data(self):
         """Verify dataset and labels."""
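Note: the `np.any(res == np.nan)` guard deleted above was dead code. In IEEE-754 arithmetic NaN never compares equal to anything, including itself, so the condition is always False and the `np.zeros_like` fallback never ran. A minimal standalone check:

    import numpy as np

    res = np.array([0.7, np.nan, 0.3])
    print(np.any(res == np.nan))  # False: NaN != NaN, so this guard can never fire
    print(np.any(np.isnan(res)))  # True: np.isnan is the element-wise test that works

Rather than repairing the guard in place, the commit drops it here and filters non-finite values at aggregation time, as the metric hunks below show.
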
@@ -382,8 +382,6 @@ class Faithfulness(LabelSensitiveMetric):
         perturb_percent = 0.5  # ratio of pixels to be perturbed, future argument
         perturb_method = "Constant"  # perturbation method, all the perturbed pixels will be set to constant
-        num_perturb_pixel_per_step = None  # number of pixels for each perturbation step
-        num_perturb_steps = 100  # separate the perturbation progress in to 100 steps.
         base_value = 0.0  # the pixel value set for the perturbed pixels
 
         check_value_type("activation_fn", activation_fn, nn.Cell)
@@ -395,8 +393,6 @@ class Faithfulness(LabelSensitiveMetric):
         self._faithfulness_helper = method(
             perturb_percent=perturb_percent,
             perturb_method=perturb_method,
-            perturb_pixel_per_step=num_perturb_pixel_per_step,
-            num_perturbations=num_perturb_steps,
             base_value=base_value
         )
@@ -15,6 +15,7 @@
 """Base class for XAI metrics."""
 
 import copy
+import math
 from typing import Callable
 
 import numpy as np
@@ -88,11 +89,12 @@ class LabelAgnosticMetric(AttributionMetric):
         Return:
             float, averaged result. If no result is aggregate in the global_results, 0.0 will be returned.
         """
-        if not self._global_results:
-            return 0.0
-        results_sum = sum(self._global_results)
-        count = len(self._global_results)
-        return results_sum / count
+        result_sum, count = 0, 0
+        for res in self._global_results:
+            if math.isfinite(res):
+                result_sum += res
+                count += 1
+        return 0. if count == 0 else result_sum / count
 
     def aggregate(self, result):
         """Aggregate single evaluation result to global results."""
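Note: the new `get_average` skips non-finite entries instead of averaging everything it collected. A minimal sketch of the difference, using plain floats:

    import math

    results = [1.0, float('nan'), 0.5]
    print(sum(results) / len(results))                     # nan: one bad entry poisons the old mean
    finite = [res for res in results if math.isfinite(res)]
    print(sum(finite) / len(finite) if finite else 0.)     # 0.75: the new behaviour
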
@@ -100,7 +102,7 @@ class LabelAgnosticMetric(AttributionMetric):
             self._global_results.append(result)
         elif isinstance(result, (ms.Tensor, np.ndarray)):
             result = format_tensor_to_ndarray(result)
-            self._global_results.append(float(result))
+            self._global_results.extend([float(res) for res in result.reshape(-1)])
         else:
             raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))
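Note: the replaced `append(float(result))` only worked for size-1 tensors, since `float()` raises a `TypeError` on anything larger, so a batched result could not be aggregated at all. The new line flattens first and extends element-wise. A sketch:

    import numpy as np

    batch_res = np.array([[0.2], [0.4], [0.6]])  # e.g. one score per sample in a batch
    # float(batch_res) would raise: only size-1 arrays can be converted to a scalar.
    scores = [float(res) for res in batch_res.reshape(-1)]
    print(scores)  # [0.2, 0.4, 0.6]
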
@@ -130,10 +132,12 @@ class LabelSensitiveMetric(AttributionMetric):
 
     @property
     def num_labels(self):
+        """Number of labels used in evaluation."""
         return self._num_labels
 
     @staticmethod
     def _verify_params(num_labels):
+        """Checks whether num_labels is valid."""
         check_value_type("num_labels", num_labels, int)
         if num_labels < 1:
             raise ValueError("Argument num_labels must be parsed with a integer > 0.")
@@ -147,17 +151,19 @@ class LabelSensitiveMetric(AttributionMetric):
                 target_np = format_tensor_to_ndarray(targets)
                 if len(target_np) > 1:
                     raise ValueError("One result can not be aggreated to multiple targets.")
-        else:
-            result_np = format_tensor_to_ndarray(result)
+        elif isinstance(result, (ms.Tensor, np.ndarray)):
+            result_np = format_tensor_to_ndarray(result).reshape(-1)
             if isinstance(targets, int):
                 for res in result_np:
                     self._global_results[targets].append(float(res))
             else:
-                target_np = format_tensor_to_ndarray(targets)
+                target_np = format_tensor_to_ndarray(targets).reshape(-1)
                 if len(target_np) != len(result_np):
                     raise ValueError("Length of result does not match with length of targets.")
                 for tar, res in zip(target_np, result_np):
                     self._global_results[int(tar)].append(float(res))
+        else:
+            raise TypeError('Result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))
 
     def reset(self):
         """Resets global_result."""
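Note: for batched results, the rewritten branch flattens both arrays and pairs each score with its target label, so one call can aggregate a whole batch. A minimal sketch of the bucketing, with a hypothetical 3-label metric:

    import numpy as np

    global_results = {label: [] for label in range(3)}  # per-label buckets, as in the metric
    result = np.array([0.9, 0.1])   # two evaluation scores...
    targets = np.array([2, 0])      # ...for labels 2 and 0 respectively
    for tar, res in zip(targets.reshape(-1), result.reshape(-1)):
        global_results[int(tar)].append(float(res))
    print(global_results)  # {0: [0.1], 1: [], 2: [0.9]}
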
@@ -168,16 +174,18 @@ class LabelSensitiveMetric(AttributionMetric):
         """
         Get the class performances by global result.
 
-
         Returns:
-            (:class:`np.ndarray`): :attr:`num_labels`-dimensional vector
-                containing per-class performance.
+            (:class:`list`): a list of performances where each value is the average score of specific class.
         """
-        count = np.array(
-            [len(self._global_results[i]) for i in range(self._num_labels)])
-        result_sum = np.array(
-            [sum(self._global_results[i]) for i in range(self._num_labels)])
-        return result_sum / count.clip(min=1)
+        results_on_labels = []
+        for label_id in range(self._num_labels):
+            sum_of_label, count_of_label = 0, 0
+            for res in self._global_results[label_id]:
+                if math.isfinite(res):
+                    sum_of_label += res
+                    count_of_label += 1
+            results_on_labels.append(0. if count_of_label == 0 else sum_of_label / count_of_label)
+        return results_on_labels
 
     @property
     def performance(self):
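Note: the old vectorized version guarded only against empty classes (`count.clip(min=1)`); a NaN inside any bucket still survived the `sum` and poisoned that class's score. The loop version averages only finite scores and reports 0.0 for empty classes. Sketch:

    import math

    global_results = {0: [0.5, float('nan')], 1: []}
    per_class = []
    for label_id in range(2):
        finite = [res for res in global_results[label_id] if math.isfinite(res)]
        per_class.append(sum(finite) / len(finite) if finite else 0.)
    print(per_class)  # [0.5, 0.0] -- the old code would have reported [nan, 0.0]
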
@@ -187,13 +195,13 @@ class LabelSensitiveMetric(AttributionMetric):
         Returns:
             (:class:`float`): mean performance.
         """
-        count = sum(
-            [len(self._global_results[i]) for i in range(self._num_labels)])
-        result_sum = sum(
-            [sum(self._global_results[i]) for i in range(self._num_labels)])
-        if count == 0:
-            return 0
-        return result_sum / count
+        result_sum, count = 0, 0
+        for label_id in range(self._num_labels):
+            for res in self._global_results[label_id]:
+                if math.isfinite(res):
+                    result_sum += res
+                    count += 1
+        return 0. if count == 0 else result_sum / count
 
     def get_results(self):
         """Global result of the metric can be return"""
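Note: `performance` flattens all labels into one mean, weighting each finite result equally rather than averaging the per-class averages, so classes with more samples count for more. A toy illustration with hypothetical 2-label buckets:

    import math

    global_results = {0: [0.75, 0.5], 1: [float('nan'), 0.25]}
    result_sum, count = 0, 0
    for label_id in range(2):
        for res in global_results[label_id]:
            if math.isfinite(res):
                result_sum += res
                count += 1
    print(0. if count == 0 else result_sum / count)
    # 0.5 -- the mean of the three finite scores, not the mean of the two class averages
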
@@ -122,8 +122,8 @@ class Robustness(LabelSensitiveMetric):
                 perturbations.append(perturbation_on_single_sample)
             perturbations = np.vstack(perturbations)
             perturbations_saliency = explainer(ms.Tensor(perturbations, ms.float32), targets).asnumpy()
-            sensitivity = np.sum((perturbations_saliency - saliency_np) ** 2,
-                                 axis=tuple(range(1, len(saliency_np.shape))))
+            sensitivity = np.sqrt(np.sum((perturbations_saliency - saliency_np) ** 2,
+                                         axis=tuple(range(1, len(saliency_np.shape)))))
             sensitivities.append(sensitivity)
         sensitivities = np.stack(sensitivities, axis=-1)
         max_sensitivity = np.max(sensitivities, axis=1) / norm
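Note: the fix adds the missing square root. Max-sensitivity is defined on the Euclidean (L2) distance between the original and perturbed saliency maps, but the old line returned the squared distance, inflating large deviations. On a toy pair of maps:

    import numpy as np

    saliency = np.zeros((1, 4))
    perturbed_saliency = np.ones((1, 4))  # differs by 1.0 at each of 4 positions
    squared = np.sum((perturbed_saliency - saliency) ** 2, axis=1)
    print(squared)           # [4.] -- what the old code reported
    print(np.sqrt(squared))  # [2.] -- the actual L2 distance
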
@@ -89,7 +89,7 @@ class Ablation:
 
 class AblationWithSaliency(Ablation):
     """
-    Perturbation generator to generate perturbations for a given array.
+    Perturbation generator to generate perturbations w.r.t a given saliency map.
 
     Args:
         perturb_percent (float): percentage of pixels to perturb
@@ -143,28 +143,20 @@ class AblationWithSaliency(Ablation):
         """
 
         batch_size = saliency.shape[0]
+        expected_num_dim = len(saliency.shape) + 1
         has_channel = num_channels is not None
         num_channels = 1 if num_channels is None else num_channels
 
         if has_channel:
             saliency = saliency.mean(axis=1)
         saliency_rank = rank_pixels(saliency, descending=True)
-
         num_pixels = reduce(lambda x, y: x * y, saliency.shape[1:])
-
-        if self._pixel_per_step:
-            pixel_per_step = self._pixel_per_step
-            num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step)
-        elif self._num_perturbations:
-            pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations)
-            num_perturbations = self._num_perturbations
-        else:
-            raise ValueError("Must provide either pixel_per_step or num_perturbations.")
+        pixel_per_step, num_perturbations = self._check_and_format_perturb_param(num_pixels)
 
         masks = np.zeros((batch_size, num_perturbations, num_channels, saliency_rank.shape[1], saliency_rank.shape[2]),
                          dtype=np.bool)
 
         # If the perturbation is added accumulately, the factor should be 0 to preserve the low bound of indexing.
         factor = 0 if self._is_accumulate else 1
 
         for i in range(batch_size):
@@ -176,7 +168,23 @@ class AblationWithSaliency(Ablation):
                 up_bound += pixel_per_step
 
         masks = masks if has_channel else np.squeeze(masks, axis=2)
-        return masks
+
+        if len(masks.shape) == expected_num_dim:
+            return masks
+        raise ValueError(f'Invalid masks shape {len(masks.shape)}, expect {expected_num_dim}-dim.')
+
+    def _check_and_format_perturb_param(self, num_pixels):
+        """
+        Check whether self._pixel_per_step and self._num_perturbation are valid. If the parameters are unreasonable,
+        this function will try to reassign the parameters and raise ValueError when reassignment fails.
+        """
+        if self._pixel_per_step:
+            pixel_per_step = self._pixel_per_step
+            num_perturbations = math.floor(num_pixels * self._perturb_percent / self._pixel_per_step)
+        elif self._num_perturbations:
+            pixel_per_step = math.floor(num_pixels * self._perturb_percent / self._num_perturbations)
+            num_perturbations = self._num_perturbations
+        else:
+            # If neither pixel_per_step nor num_perturbations is provided, num_perturbations is determined by the
+            # square root of the product of the spatial dimensions of the saliency map.
+            num_perturbations = math.floor(np.sqrt(num_pixels))
+            pixel_per_step = math.floor(num_pixels * self._perturb_percent / num_perturbations)
+
+        return pixel_per_step, num_perturbations
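Note: this helper is presumably where the "bug for small-size ablation" in the commit title is addressed. `Faithfulness` no longer pins `num_perturb_steps` to 100 (see the hunks above), and when neither parameter is given the step count now scales with the input as sqrt(num_pixels). Worked numbers for the fallback branch:

    import math

    # Fallback branch, assuming a 224x224 saliency map and perturb_percent = 0.5:
    num_pixels = 224 * 224                                     # 50176
    num_perturbations = math.floor(math.sqrt(num_pixels))      # 224
    pixel_per_step = math.floor(num_pixels * 0.5 / num_perturbations)  # 112

    # The old hard-coded 100 steps broke down on small inputs, e.g. a 10x10 map:
    print(math.floor(10 * 10 * 0.5 / 100))   # 0 pixels per step -- every mask was empty
    print(math.floor(math.sqrt(10 * 10)))    # 10 steps under the new fallback, 5 pixels each
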
@@ -14,8 +14,9 @@
 # ============================================================================
 """Occlusion explainer."""
 
+from typing import Tuple
+
 import numpy as np
-from numpy.lib.stride_tricks import as_strided
 
 import mindspore as ms
 import mindspore.nn as nn
@@ -25,24 +26,17 @@ from .replacement import Constant
 from ...._utils import abs_max
 
 
-def _generate_patches(array, window_size, stride):
-    """View as windows."""
-    if not isinstance(array, np.ndarray):
-        raise TypeError("`array` must be a numpy ndarray")
-
-    arr_shape = np.array(array.shape)
-    window_size = np.array(window_size, dtype=arr_shape.dtype)
-
-    slices = tuple(slice(None, None, st) for st in stride)
-    window_strides = np.array(array.strides)
-
+def _generate_patches(array, window_size: Tuple, strides: Tuple):
+    """Generate patches from image w.r.t given window_size and strides."""
+    window_strides = array.strides
+    slices = tuple(slice(None, None, stride) for stride in strides)
     indexing_strides = array[slices].strides
-    win_indices_shape = (((np.array(array.shape) - np.array(window_size)) // np.array(stride)) + 1)
+    win_indices_shape = (np.array(array.shape) - np.array(window_size)) // np.array(strides) + 1
 
-    new_shape = tuple(list(win_indices_shape) + list(window_size))
-    strides = tuple(list(indexing_strides) + list(window_strides))
-
-    patches = as_strided(array, shape=new_shape, strides=strides)
+    patches_shape = tuple(win_indices_shape) + window_size
+    strides_in_memory = indexing_strides + window_strides
+    patches = np.lib.stride_tricks.as_strided(array, shape=patches_shape, strides=strides_in_memory, writeable=False)
+    patches = patches.reshape((-1,) + window_size)
     return patches
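Note: a quick standalone check of the rewritten helper's stride arithmetic, mirroring the index-template trick `Occlusion` applies to it (a 4x4 template of flat indices, 2x2 windows, stride 2 should yield four non-overlapping patches):

    import numpy as np

    template = np.arange(16).reshape(4, 4)
    window_size, strides = (2, 2), (2, 2)

    window_strides = template.strides
    slices = tuple(slice(None, None, stride) for stride in strides)
    indexing_strides = template[slices].strides
    win_indices_shape = (np.array(template.shape) - np.array(window_size)) // np.array(strides) + 1

    patches = np.lib.stride_tricks.as_strided(
        template, shape=tuple(win_indices_shape) + window_size,
        strides=indexing_strides + window_strides, writeable=False)
    patches = patches.reshape((-1,) + window_size)

    print(patches.shape)  # (4, 2, 2) -- so indices.shape[0] is already the perturbation count
    print(patches[0])     # [[0 1]
                          #  [4 5]] -- the top-left window of flat indices

Because the helper now returns patches already reshaped to (-1,) + window_size, the hunk below can read the perturbation count straight off `indices.shape[0]` instead of reshaping a second time.
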
@@ -159,7 +153,7 @@ class Occlusion(PerturbationAttribution):
         total_dim = np.prod(inputs.shape[1:]).item()
         template = np.arange(total_dim).reshape(inputs.shape[1:])
         indices = _generate_patches(template, window_size, strides)
-        num_perturbations = indices.reshape((-1,) + window_size).shape[0]
+        num_perturbations = indices.shape[0]
         indices = indices.reshape(num_perturbations, -1)
 
         mask = np.zeros((num_perturbations, total_dim), dtype=np.bool)