Optimize RISE for better performance

From: @lixiaohui33
Reviewed-by: @wuxuejian,@liucunwei
Signed-off-by: @liucunwei
This commit is contained in:
mindspore-ci-bot 2020-12-10 09:19:52 +08:00 committed by Gitee
commit dea05a015b
3 changed files with 21 additions and 25 deletions
mindspore/explainer
_image_classification_runner.py
benchmark/_attribution
explanation/_attribution/_perturbation

View File

@ -77,16 +77,16 @@ class ImageClassificationRunner:
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10
>>> dataset = get_dataset('/path/to/Cifar10_dataset')
>>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'turck']
>>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
>>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10
>>> param_dict = load_checkpoint("checkpoint.ckpt")
>>> net = resnet50(len(classes))
>>> net = resnet50(len(labels))
>>> activation_fn = Softmax()
>>> load_param_into_net(net, param_dict)
>>> gbp = GuidedBackprop(net)
>>> gradient = Gradient(net)
>>> explainers = [gbp, gradient]
>>> faithfulness = Faithfulness(len(labels), "NaiveFaithfulness", activation_fn)
>>> faithfulness = Faithfulness(len(labels), activation_fn, "NaiveFaithfulness")
>>> benchmarkers = [faithfulness]
>>> runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn)
>>> runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers)

View File

@ -16,6 +16,7 @@
import numpy as np
from mindspore.explainer.explanation import RISE
from .metric import LabelAgnosticMetric
from ... import _operators as ops
from ..._utils import calc_correlation
@ -55,7 +56,7 @@ class ClassSensitivity(LabelAgnosticMetric):
>>> # prepare your explainer to be evaluated, e.g., Gradient.
>>> gradient = Gradient(network)
>>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
>>> # class_sensitivity is a ClassSensitivity instance
>>> class_sensitivity = ClassSensitivity()
>>> res = class_sensitivity.evaluate(gradient, input_x)
"""
self._check_evaluate_param(explainer, inputs)
@ -64,7 +65,12 @@ class ClassSensitivity(LabelAgnosticMetric):
max_confidence_label = ops.argmax(outputs)
min_confidence_label = ops.argmin(outputs)
if isinstance(explainer, RISE):
labels = ops.stack([max_confidence_label, min_confidence_label], axis=1)
full_saliency = explainer(inputs, labels)
max_confidence_saliency = full_saliency[:, max_confidence_label].asnumpy()
min_confidence_saliency = full_saliency[:, min_confidence_label].asnumpy()
else:
max_confidence_saliency = explainer(inputs, max_confidence_label).asnumpy()
min_confidence_saliency = explainer(inputs, min_confidence_label).asnumpy()

View File

@ -14,11 +14,9 @@
# ============================================================================
"""RISE."""
import math
import random
import numpy as np
from mindspore.ops.operations import Concat
from mindspore import Tensor
from mindspore.train._utils import check_value_type
@ -107,18 +105,13 @@ class RISE(PerturbationAttribution):
up_size = (height + mask_size[0], width + mask_size[1])
mask = np.random.random((batch_size, 1) + mask_size) < self._mask_probability
upsample = resize(op.Tensor(mask, data.dtype), up_size,
self._resize_mode)
self._resize_mode).asnumpy()
shift_x = np.random.randint(0, mask_size[0] + 1, size=batch_size)
shift_y = np.random.randint(0, mask_size[1] + 1, size=batch_size)
# Pack operator not available for GPU, thus transfer to numpy first
masks_lst = []
for sample in upsample:
shift_x = random.randint(0, mask_size[0])
shift_y = random.randint(0, mask_size[1])
masks_lst.append(sample[:, shift_x: shift_x + height, shift_y:shift_y + width])
concat = Concat()
masks = concat(tuple(masks_lst))
masks = op.reshape(masks, (batch_size, -1, height, width))
masks = [sample[:, x_i: x_i + height, y_i: y_i + width] for sample, x_i, y_i
in zip(upsample, shift_x, shift_y)]
masks = Tensor(np.array(masks), data.dtype)
return masks
def __call__(self, inputs, targets):
@ -157,11 +150,8 @@ class RISE(PerturbationAttribution):
attr_np = attr_np / self._num_masks
targets = self._unify_targets(inputs, targets)
attr_classes = []
for idx, target in enumerate(targets):
attr_np_idx = attr_np[idx]
attr_idx = attr_np_idx[target]
attr_classes.append(attr_idx)
attr_classes = [att_i[target] for att_i, target in zip(attr_np, targets)]
return op.Tensor(attr_classes, dtype=inputs.dtype)