forked from mindspore-Ecosystem/mindspore
!9731 Optimize RISE for better performance
From: @lixiaohui33
Reviewed-by: @wuxuejian, @liucunwei
Signed-off-by: @liucunwei
Commit: dea05a015b
@@ -77,16 +77,16 @@ class ImageClassificationRunner:
     >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
     >>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10
     >>> dataset = get_dataset('/path/to/Cifar10_dataset')
-    >>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'turck']
+    >>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
     >>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10
     >>> param_dict = load_checkpoint("checkpoint.ckpt")
-    >>> net = resnet50(len(classes))
+    >>> net = resnet50(len(labels))
     >>> activation_fn = Softmax()
     >>> load_param_into_net(net, param_dict)
     >>> gbp = GuidedBackprop(net)
     >>> gradient = Gradient(net)
     >>> explainers = [gbp, gradient]
-    >>> faithfulness = Faithfulness(len(labels), "NaiveFaithfulness", activation_fn)
+    >>> faithfulness = Faithfulness(len(labels), activation_fn, "NaiveFaithfulness")
     >>> benchmarkers = [faithfulness]
     >>> runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn)
     >>> runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers)
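Note on the hunk above: besides fixing the 'turck' typo and the undefined `classes` name, it reorders the `Faithfulness` arguments so the docstring matches the constructor, which (per this diff) takes the activation function before the metric name. A minimal sketch of the corrected call, with the import path assumed from MindSpore's explainer package:

import numpy as np
from mindspore.nn import Softmax
from mindspore.explainer.benchmark import Faithfulness  # assumed import path

num_labels = 10                # the ten Cifar10 classes in the example above
activation_fn = Softmax()      # turns network logits into probabilities

# Argument order as fixed by this commit: num_labels, activation_fn, metric name.
faithfulness = Faithfulness(num_labels, activation_fn, "NaiveFaithfulness")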
@@ -16,6 +16,7 @@
 import numpy as np

+from mindspore.explainer.explanation import RISE
 from .metric import LabelAgnosticMetric
 from ... import _operators as ops
 from ..._utils import calc_correlation
@@ -55,7 +56,7 @@ class ClassSensitivity(LabelAgnosticMetric):
         >>> # prepare your explainer to be evaluated, e.g., Gradient.
         >>> gradient = Gradient(network)
         >>> input_x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
-        >>> # class_sensitivity is a ClassSensitivity instance
+        >>> class_sensitivity = ClassSensitivity()
         >>> res = class_sensitivity.evaluate(gradient, input_x)
         """
         self._check_evaluate_param(explainer, inputs)
@@ -64,7 +65,12 @@ class ClassSensitivity(LabelAgnosticMetric):

         max_confidence_label = ops.argmax(outputs)
         min_confidence_label = ops.argmin(outputs)
-        max_confidence_saliency = explainer(inputs, max_confidence_label).asnumpy()
-        min_confidence_saliency = explainer(inputs, min_confidence_label).asnumpy()
+        if isinstance(explainer, RISE):
+            labels = ops.stack([max_confidence_label, min_confidence_label], axis=1)
+            full_saliency = explainer(inputs, labels)
+            max_confidence_saliency = full_saliency[:, max_confidence_label].asnumpy()
+            min_confidence_saliency = full_saliency[:, min_confidence_label].asnumpy()
+        else:
+            max_confidence_saliency = explainer(inputs, max_confidence_label).asnumpy()
+            min_confidence_saliency = explainer(inputs, min_confidence_label).asnumpy()

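The special case pays off because RISE estimates saliency by averaging model outputs over a large set of random masks; requesting the most and least confident classes in a single call reuses the same masked forward passes instead of running that expensive loop twice. A rough numpy sketch of the shape contract, with hypothetical sizes standing in for the real explainer call:

import numpy as np

batch_size, height, width = 2, 224, 224

# Hypothetical stand-in for full_saliency = explainer(inputs, labels), where
# labels stacks the top- and bottom-confidence class per sample: one masked
# evaluation yields a saliency map for each requested class.
full_saliency = np.random.rand(batch_size, 2, height, width)

max_confidence_saliency = full_saliency[:, 0]  # maps for the most confident class
min_confidence_saliency = full_saliency[:, 1]  # maps for the least confident class
assert max_confidence_saliency.shape == (batch_size, height, width)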
@@ -14,11 +14,9 @@
 # ============================================================================
 """RISE."""
 import math
-import random

 import numpy as np

-from mindspore.ops.operations import Concat
 from mindspore import Tensor
 from mindspore.train._utils import check_value_type
@@ -107,18 +105,13 @@ class RISE(PerturbationAttribution):
         up_size = (height + mask_size[0], width + mask_size[1])
         mask = np.random.random((batch_size, 1) + mask_size) < self._mask_probability
         upsample = resize(op.Tensor(mask, data.dtype), up_size,
-                          self._resize_mode)
-
-        masks_lst = []
-        for sample in upsample:
-            shift_x = random.randint(0, mask_size[0])
-            shift_y = random.randint(0, mask_size[1])
-            masks_lst.append(sample[:, shift_x: shift_x + height, shift_y: shift_y + width])
-
-        concat = Concat()
-        masks = concat(tuple(masks_lst))
-        masks = op.reshape(masks, (batch_size, -1, height, width))
+                          self._resize_mode).asnumpy()
+        shift_x = np.random.randint(0, mask_size[0] + 1, size=batch_size)
+        shift_y = np.random.randint(0, mask_size[1] + 1, size=batch_size)
+
+        # Pack operator not available for GPU, thus transfer to numpy first
+        masks = [sample[:, x_i: x_i + height, y_i: y_i + width] for sample, x_i, y_i
+                 in zip(upsample, shift_x, shift_y)]
+        masks = Tensor(np.array(masks), data.dtype)
         return masks

     def __call__(self, inputs, targets):
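This hunk carries the performance fix: instead of looping over samples in Python, drawing two `random.randint` shifts per iteration, and stitching the crops back together with a `Concat` op, the new code samples all shifts at once with `np.random.randint` and builds a single `Tensor` at the end (the in-diff comment explains the numpy round-trip: the Pack operator is unavailable on GPU). A self-contained sketch of the cropping logic with made-up sizes:

import numpy as np

batch_size, height, width = 4, 32, 32
cell = 8                                   # hypothetical mask cell size
up_h, up_w = height + cell, width + cell   # upsampled masks are one cell larger

upsample = np.random.rand(batch_size, 1, up_h, up_w)  # stand-in for the resized masks

# One random (x, y) shift per sample, drawn in a single vectorized call;
# randint's upper bound is exclusive, hence cell + 1 to allow a full-cell shift.
shift_x = np.random.randint(0, cell + 1, size=batch_size)
shift_y = np.random.randint(0, cell + 1, size=batch_size)

# Crop every upsampled mask back to (height, width) at its own offset.
masks = np.array([sample[:, x: x + height, y: y + width]
                  for sample, x, y in zip(upsample, shift_x, shift_y)])
assert masks.shape == (batch_size, 1, height, width)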
@@ -157,11 +150,8 @@ class RISE(PerturbationAttribution):

         attr_np = attr_np / self._num_masks
         targets = self._unify_targets(inputs, targets)
-        attr_classes = []
-        for idx, target in enumerate(targets):
-            attr_np_idx = attr_np[idx]
-            attr_idx = attr_np_idx[target]
-            attr_classes.append(attr_idx)
+        attr_classes = [att_i[target] for att_i, target in zip(attr_np, targets)]

         return op.Tensor(attr_classes, dtype=inputs.dtype)
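The replaced loop and the new comprehension are semantically identical: for each sample, select the attribution map of its target class. A quick equivalence check on toy data:

import numpy as np

attr_np = np.random.rand(3, 10, 8, 8)  # (batch, num_classes, H, W), toy values
targets = [4, 1, 7]                    # one target class per sample

# Old style: index loop ...
attr_loop = []
for idx, target in enumerate(targets):
    attr_loop.append(attr_np[idx][target])

# ... new style: zip-based comprehension, same per-class maps.
attr_classes = [att_i[target] for att_i, target in zip(attr_np, targets)]

assert all(np.array_equal(a, b) for a, b in zip(attr_loop, attr_classes))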