forked from mindspore-Ecosystem/mindspore

add test cases for explainer

parent b82c4cba32
commit 7769be81ad
@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialization of tests of explanation-related classes."""
@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialization of tests of mindspore.explainer.benchmark."""
@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialization of tests in mindspore.explainer.benchmark."""
@ -0,0 +1,134 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests of Localization of mindspore.explainer.benchmark."""

from unittest.mock import patch

import numpy as np
import pytest

import mindspore as ms
from mindspore import context
from mindspore import nn
from mindspore.explainer.benchmark import Localization
from mindspore.explainer.explanation import Gradient


context.set_context(mode=context.PYNATIVE_MODE)

H, W = 4, 4
SALIENCY = ms.Tensor(np.random.rand(1, 1, H, W), ms.float32)


class CustomNet(nn.Cell):
    """Simple net for unit test."""

    def __init__(self):
        super().__init__()

    def construct(self, _):
        return ms.Tensor([[0.1, 0.9]], ms.float32)


def mock_gradient_call(_, inputs, targets):
    del inputs, targets
    return SALIENCY


class TestLocalization:
    """Test on Localization."""

    def setup_method(self):
        self.net = CustomNet()
        self.data = ms.Tensor(np.random.rand(1, 1, H, W), ms.float32)
        self.target = 1

        masks_np = np.zeros((1, 1, H, W))
        masks_np[:, :, 1:3, 1:3] = 1
        self.masks_np = masks_np
        self.masks = ms.Tensor(masks_np, ms.float32)

        self.explainer = Gradient(self.net)
        self.saliency_gt = mock_gradient_call(self.explainer, self.data, self.target)
        self.num_class = 2

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.env_onecard
    def test_pointing_game(self):
        """Test case for `metric="PointingGame"` without input saliency."""
        with patch.object(Gradient, "__call__", mock_gradient_call):
            max_pos = np.argmax(abs(self.saliency_gt.asnumpy().flatten()))
            x_gt, y_gt = max_pos // W, max_pos % W
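            # Expected PointingGame result: the mask value at the peak-saliency pixel,
            # i.e. 1 if the most salient point falls inside the ground-truth region, else 0.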
            res_gt = self.masks_np[0, 0, x_gt, y_gt]

            pg = Localization(self.num_class, metric="PointingGame")
            pg._metric_arg = 1  # make the tolerance smaller to simplify the test

            res = pg.evaluate(self.explainer, self.data, targets=self.target, mask=self.masks)
            assert np.max(np.abs(np.array([res_gt]) - res)) < 1e-5

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.env_onecard
    def test_iosr(self):
        """Test case for `metric="IoSR"` without input saliency."""
        with patch.object(Gradient, "__call__", mock_gradient_call):
            threshold = 0.5
            max_val = np.max(self.saliency_gt.asnumpy())
            sr = (self.saliency_gt.asnumpy() > (max_val * threshold)).astype(int)
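            # Expected IoSR: overlap between the thresholded salient region and the
            # ground-truth mask, divided by the salient-region area (clipped to avoid
            # division by zero).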
            res_gt = np.sum(sr * self.masks_np) / (np.sum(sr).clip(1e-10))

            iosr = Localization(self.num_class, metric="IoSR")
            iosr._metric_arg = threshold

            res = iosr.evaluate(self.explainer, self.data, targets=self.target, mask=self.masks)

            assert np.allclose(np.array([res_gt]), res)

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.env_onecard
    def test_pointing_game_with_saliency(self):
        """Test metric PointingGame with input saliency."""
        max_pos = np.argmax(abs(self.saliency_gt.asnumpy().flatten()))
        x_gt, y_gt = max_pos // W, max_pos % W
        res_gt = self.masks_np[0, 0, x_gt, y_gt]

        pg = Localization(self.num_class, metric="PointingGame")
        pg._metric_arg = 1  # make the tolerance smaller to simplify the test

        res = pg.evaluate(self.explainer, self.data, targets=self.target, mask=self.masks, saliency=self.saliency_gt)
        assert np.allclose(np.array([res_gt]), res)

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.env_onecard
    def test_iosr_with_saliency(self):
        """Test metric IoSR with input saliency map."""
        threshold = 0.5
        max_val = np.max(self.saliency_gt.asnumpy())
        sr = (self.saliency_gt.asnumpy() > (max_val * threshold)).astype(int)
        res_gt = np.sum(sr * self.masks_np) / (np.sum(sr).clip(1e-10))

        iosr = Localization(self.num_class, metric="IoSR")

        res = iosr.evaluate(self.explainer, self.data, targets=self.target, mask=self.masks, saliency=self.saliency_gt)

        assert np.allclose(np.array([res_gt]), res)
@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialization of tests of mindspore.explainer.explanation."""
@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialization of tests of explainers of mindspore.explainer.explanation."""
@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialization of tests of back-propagation-based explainers."""
@ -0,0 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests of GradCAM of mindspore.explainer.explanation."""

from unittest.mock import patch

import numpy as np
import pytest

import mindspore as ms
from mindspore import context
import mindspore.ops.operations as op
from mindspore import nn
from mindspore.explainer.explanation import GradCAM
from mindspore.explainer.explanation._attribution._backprop.gradcam import _gradcam_aggregation as aggregation


context.set_context(mode=context.PYNATIVE_MODE)


class SimpleAvgLinear(nn.Cell):
    """Simple linear model for the unit test."""

    def __init__(self):
        super().__init__()
        self.avgpool = nn.AvgPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc2 = nn.Dense(4, 3)

    def construct(self, x):
        x = self.avgpool(x)
        x = self.flatten(x)
        return self.fc2(x)


def resize_fn(attributions, inputs, mode):
    """Mocked resize function for test."""
    del inputs, mode
    return attributions


class TestGradCAM:
    """Test GradCAM."""

    def setup_method(self):
        self.net = SimpleAvgLinear()
        self.data = ms.Tensor(np.random.random(size=(1, 1, 4, 4)), ms.float32)

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.env_onecard
    def test_gradcam_attribution(self):
        """Test __call__ method in GradCAM."""
        with patch.object(GradCAM, "_resize_fn", side_effect=resize_fn):
            layer = "avgpool"

            gradcam = GradCAM(self.net, layer=layer)

            data = ms.Tensor(np.random.random(size=(1, 1, 4, 4)), ms.float32)
            num_classes = 3
            activation = self.net.avgpool(data)
            reshape = op.Reshape()
            for x in range(num_classes):
                target = ms.Tensor([x], ms.int32)
                attribution = gradcam(data, target)
                # intermediate grad should be reshape of weight of fc2
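                # Because fc2 is linear, d(logit_x)/d(pooled features) is just the x-th row
                # of fc2.weight; average-pooling that row stands in for GradCAM's
                # channel-averaged gradient, which then weights the target-layer activation.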
                intermediate_grad = self.net.fc2.weight.data[x]
                reshaped = reshape(intermediate_grad, (1, 1, 2, 2))
                gap_grad = self.net.avgpool(reshaped)
                res = aggregation(gap_grad * activation)
                assert np.allclose(res.asnumpy(), attribution.asnumpy(), atol=1e-5, rtol=1e-3)

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.env_onecard
    def test_layer_default(self):
        """Test layer argument of GradCAM."""
        with patch.object(GradCAM, "_resize_fn", side_effect=resize_fn):
            gradcam = GradCAM(self.net)
            num_classes = 3
            sum_ = op.ReduceSum()
            for x in range(num_classes):
                target = ms.Tensor([x], ms.int32)
                attribution = gradcam(self.data, target)

                # intermediate_grad should be reshape of weight of fc2
                intermediate_grad = self.net.fc2.weight.data[x]
                avggrad = float(sum_(intermediate_grad).asnumpy() / 16)
                res = aggregation(avggrad * self.data)
                assert np.allclose(res.asnumpy(), attribution.asnumpy(), atol=1e-5, rtol=1e-3)
@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests of Gradient of mindspore.explainer.explanation."""

import numpy as np
import pytest

import mindspore as ms
from mindspore import context
import mindspore.ops.operations as P
from mindspore import nn
from mindspore.explainer.explanation import Gradient


context.set_context(mode=context.PYNATIVE_MODE)


class SimpleLinear(nn.Cell):
    """Simple linear model for the unit test."""

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc2 = nn.Dense(16, 3)

    def construct(self, x):
        x = self.relu(x)
        x = self.flatten(x)
        return self.fc2(x)


class TestGradient:
    """Test Gradient."""

    def setup_method(self):
        """Setup the test case."""
        self.net = SimpleLinear()
        self.relu = P.ReLU()
        self.abs_ = P.Abs()

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.env_onecard
    def test_gradient(self):
        """Test gradient __call__ function."""
        data = (ms.Tensor(np.random.random(size=(1, 1, 4, 4)),
                          ms.float32) - 0.5) * 2
        explainer = Gradient(self.net)

        num_classes = 3
        reshape = P.Reshape()
        for x in range(num_classes):
            target = ms.Tensor([x], ms.int32)

            attribution = explainer(data, target)

            # intermediate_grad should be reshape of weight of fc2
            grad = self.net.fc2.weight.data[x]
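            # relu(data) / data is 1 where the input is positive and 0 elsewhere, i.e. the
            # ReLU derivative, so the expected gradient is the fc2 weight row gated by it.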
            grad = self.abs_(reshape(grad, (1, 1, 4, 4)) * (self.abs_(self.relu(data) / data)))
            assert np.allclose(grad.asnumpy(), attribution.asnumpy())
@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests of Deconvolution and GuidedBackprop of mindspore.explainer.explanation."""

import numpy as np
import pytest

import mindspore as ms
import mindspore.ops.operations as P
from mindspore import context
from mindspore import nn
from mindspore.explainer.explanation import Deconvolution, GuidedBackprop


context.set_context(mode=context.PYNATIVE_MODE)


class SimpleLinear(nn.Cell):
    """Simple linear model for the unit test."""

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc2 = nn.Dense(16, 3)

    def construct(self, x):
        x = self.relu(x)
        x = self.flatten(x)
        return self.fc2(x)


class TestModifiedReLU:
    """Test on modified_relu module, Deconvolution and GuidedBackprop specifically."""
    def setup_method(self):
        """Setup the test case."""
        self.net = SimpleLinear()
        self.relu = P.ReLU()
        self.abs_ = P.Abs()
        self.reshape = P.Reshape()

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.env_onecard
    def test_deconvolution(self):
        """Test deconvolution attribution."""
        data = (ms.Tensor(np.random.random(size=(1, 1, 4, 4)),
                          ms.float32) - 0.5) * 2
        deconv = Deconvolution(self.net)

        num_classes = 3
        for x in range(num_classes):
            target = ms.Tensor([x], ms.int32)

            attribution = deconv(data, target)

            # intermediate_grad should be reshape of weight of fc2
            grad = self.net.fc2.weight.data[x]
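            # Deconvolution applies ReLU to the backpropagated gradient itself rather than
            # gating it by the forward input, so the reference keeps only the positive
            # entries of the fc2 weight row and ignores the sign of the data.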
            grad = self.abs_(self.relu(self.reshape(grad, (1, 1, 4, 4))))
            assert np.allclose(attribution.asnumpy(), grad.asnumpy())

    def test_guided_backprop(self):
        """Test guided backprop attribution."""
        data = (ms.Tensor(np.random.random(size=(1, 1, 4, 4)),
                          ms.float32) - 0.5) * 2
        explainer = GuidedBackprop(self.net)

        num_classes = 3
        for x in range(num_classes):
            target = ms.Tensor([x], ms.int32)

            attribution = explainer(data, target)

            # intermediate_grad should be reshape of weight of fc2
            grad = self.net.fc2.weight.data[x]
            grad = self.reshape(grad, (1, 1, 4, 4))
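            # Guided backprop combines both gates: relu(data) / data zeroes positions where
            # the forward ReLU was inactive, and the outer relu keeps only positive gradients.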
            guided_grad = self.abs_(self.relu(grad * (self.abs_(self.relu(data) / data))))

            assert np.allclose(guided_grad.asnumpy(), attribution.asnumpy())
@ -0,0 +1,200 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests on mindspore.explainer.ImageClassificationRunner."""

import os
import shutil
from random import random
from unittest.mock import patch

import numpy as np
import pytest
from PIL import Image

from mindspore import context
import mindspore as ms
import mindspore.nn as nn
from mindspore.dataset import GeneratorDataset
from mindspore.explainer import ImageClassificationRunner
from mindspore.explainer._image_classification_runner import _normalize
from mindspore.explainer.benchmark import Faithfulness
from mindspore.explainer.explanation import Gradient
from mindspore.train.summary import SummaryRecord

CONST = random()
NUMDATA = 2

context.set_context(mode=context.PYNATIVE_MODE)

def image_label_bbox_generator():
    for i in range(NUMDATA):
        image = np.arange(i, i + 16 * 3).reshape((3, 4, 4)) / 50
        label = np.array(i)
        bbox = np.array([1, 1, 2, 2])
        yield (image, label, bbox)


class SimpleNet(nn.Cell):
    """
    Simple model for the unit test.
    """

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.reshape = ms.ops.operations.Reshape()

    def construct(self, x):
        prob = ms.Tensor([0.1, 0.9], ms.float32)
        prob = self.reshape(prob, (1, 2))
        return prob


class ActivationFn(nn.Cell):
    """
    Simple activation function for unit test.
    """

    def __init__(self):
        super(ActivationFn, self).__init__()

    def construct(self, x):
        return x


def mock_gradient_call(_, inputs, targets):
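    # Mocked Gradient.__call__: return the first input channel as the saliency map so the
    # saved heatmaps can be checked directly against the generated image data below.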
    return inputs[:, 0:1, :, :]


def mock_faithfulness_evaluate(_, explainer, inputs, targets, saliency):
    return CONST * targets


def mock_make_rgba(array):
    return array.asnumpy()


class TestRunner:
    """Test on Runner."""

    def setup_method(self):
        self.dataset = GeneratorDataset(image_label_bbox_generator, ["image", "label", "bbox"])
        self.labels = ["label_{}".format(i) for i in range(2)]
        self.network = SimpleNet()
        self.summary_dir = "summary_test_temp"
        self.explainer = [Gradient(self.network)]
        self.activation_fn = ActivationFn()
        self.benchmarkers = [Faithfulness(num_labels=len(self.labels),
                                          metric="NaiveFaithfulness",
                                          activation_fn=self.activation_fn)]

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.env_onecard
    def test_run_saliency_no_benchmark(self):
        """Test case when argument benchmarkers is not passed."""
        res = []
        runner = ImageClassificationRunner(summary_dir=self.summary_dir, data=(self.dataset, self.labels),
                                           network=self.network, activation_fn=self.activation_fn)

        def mock_summary_add_value(_, plugin, name, value):
            res.append((plugin, name, value))

        with patch.object(SummaryRecord, "add_value", mock_summary_add_value), \
                patch.object(Gradient, "__call__", mock_gradient_call):
            runner.register_saliency(self.explainer)
            runner.run()

        # test on meta data
        idx = 0
        assert res[idx][0] == "explainer"
        assert res[idx][1] == "metadata"
        assert res[idx][2].metadata.label == self.labels
        assert res[idx][2].metadata.explain_method == ["Gradient"]

        # test on inference data
        for i in range(NUMDATA):
            idx += 1
            data_np = np.arange(i, i + 3 * 16).reshape((3, 4, 4)) / 50
            assert res[idx][0] == "explainer"
            assert res[idx][1] == "sample"
            assert res[idx][2].sample_id == i
            original_path = os.path.join(self.summary_dir, res[idx][2].image_path)
            with open(original_path, "rb") as f:
                image_data = np.asarray(Image.open(f)) / 255.0
                original_image = _normalize(np.transpose(data_np, [1, 2, 0]))
                assert np.allclose(image_data, original_image, rtol=3e-2, atol=3e-2)

            idx += 1
            assert res[idx][0] == "explainer"
            assert res[idx][1] == "inference"
            assert res[idx][2].sample_id == i
            assert res[idx][2].ground_truth_label == [i]

            diff = np.array(res[idx][2].inference.ground_truth_prob) - np.array([[0.1, 0.9][i]])
            assert np.max(np.abs(diff)) < 1e-6
            assert res[idx][2].inference.predicted_label == [1]
            diff = np.array(res[idx][2].inference.predicted_prob) - np.array([0.9])
            assert np.max(np.abs(diff)) < 1e-6

        # test on explanation data
        for i in range(NUMDATA):
            idx += 1
            data_np = np.arange(i, i + 3 * 16).reshape((3, 4, 4)) / 50
            saliency_np = data_np[0, :, :]
            assert res[idx][0] == "explainer"
            assert res[idx][1] == "explanation"
            assert res[idx][2].sample_id == i
            assert res[idx][2].explanation[0].explain_method == "Gradient"

            assert res[idx][2].explanation[0].label in [i, 1]

            heatmap_path = os.path.join(self.summary_dir, res[idx][2].explanation[0].heatmap_path)
            assert os.path.exists(heatmap_path)

            with open(heatmap_path, "rb") as f:
                heatmap_data = np.asarray(Image.open(f)) / 255.0
                heatmap_image = _normalize(saliency_np)
                assert np.allclose(heatmap_data, heatmap_image, atol=3e-2, rtol=3e-2)

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.env_onecard
    def test_run_saliency_with_benchmark(self):
        """Test case when argument benchmarkers is passed."""
        res = []

        def mock_summary_add_value(_, plugin, name, value):
            res.append((plugin, name, value))

        runner = ImageClassificationRunner(summary_dir=self.summary_dir, data=(self.dataset, self.labels),
                                           network=self.network, activation_fn=self.activation_fn)

        with patch.object(SummaryRecord, "add_value", mock_summary_add_value), \
                patch.object(Gradient, "__call__", mock_gradient_call), \
                patch.object(Faithfulness, "evaluate", mock_faithfulness_evaluate):
            runner.register_saliency(self.explainer, self.benchmarkers)
            runner.run()

        idx = 3 * NUMDATA + 1  # start index of benchmark data
        assert res[idx][0] == "explainer"
        assert res[idx][1] == "benchmark"
        assert abs(res[idx][2].benchmark[0].total_score - 2 / 3 * CONST) < 1e-6
        diff = np.array(res[idx][2].benchmark[0].label_score) - np.array([i * CONST for i in range(NUMDATA)])
        assert np.max(np.abs(diff)) < 1e-6

    def teardown_method(self):
        shutil.rmtree(self.summary_dir)
@ -0,0 +1,119 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit test on mindspore.explainer._utils."""

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn

from mindspore.explainer._utils import (
    ForwardProbe,
    rank_pixels,
    retrieve_layer,
    retrieve_layer_by_name)
from mindspore.explainer.explanation._attribution._backprop.backprop_utils import GradNet, get_bp_weights


class CustomNet(nn.Cell):
    """Simple net for test."""

    def __init__(self):
        super(CustomNet, self).__init__()
        self.fc1 = nn.Dense(10, 10)
        self.fc2 = nn.Dense(10, 10)
        self.fc3 = nn.Dense(10, 10)
        self.fc4 = nn.Dense(10, 10)

    def construct(self, inputs):
        out = self.fc1(inputs)
        out = self.fc2(out)
        out = self.fc3(out)
        out = self.fc4(out)
        return out


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_rank_pixels():
    """Test on rank_pixels."""
    saliency = np.array([[4., 3., 1.], [5., 9., 1.]])
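    # Per row, rank 0 marks the largest value when descending and the smallest when ascending.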
    descending_target = np.array([[0, 1, 2], [1, 0, 2]])
    ascending_target = np.array([[2, 1, 0], [1, 2, 0]])
    descending_rank = rank_pixels(saliency)
    ascending_rank = rank_pixels(saliency, descending=False)
    assert (descending_rank - descending_target).any() == 0
    assert (ascending_rank - ascending_target).any() == 0


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_retrieve_layer_by_name():
    """Test on retrieve_layer_by_name."""
    model = CustomNet()
    target_layer_name = 'fc3'
    target_layer = retrieve_layer_by_name(model, target_layer_name)

    assert target_layer is model.fc3


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_retrieve_layer_by_name_no_name():
    """Test on retrieve layer."""
    model = CustomNet()
    target_layer = retrieve_layer_by_name(model, '')

    assert target_layer is model


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward_probe():
    """Test case for ForwardProbe."""
    model = CustomNet()
    model.set_grad()
    inputs = np.random.random((1, 10))
    inputs = ms.Tensor(inputs, ms.float32)
    gt_activation = model.fc3(model.fc2(model.fc1(inputs))).asnumpy()

    targets = 1
    weights = get_bp_weights(model, inputs, targets=targets)

    gradnet = GradNet(model)
    grad_before_probe = gradnet(inputs, weights).asnumpy()

    # Probe forward tensor
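    # The probe should capture fc3's forward output without altering the computed
    # gradients, and release the cached value once the context manager exits.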
    saliency_layer = retrieve_layer(model, 'fc3')

    with ForwardProbe(saliency_layer) as probe:
        grad_after_probe = gradnet(inputs, weights).asnumpy()
        activation = probe.value.asnumpy()

    grad_after_unprobe = gradnet(inputs, weights).asnumpy()

    assert np.array_equal(gt_activation, activation)
    assert np.array_equal(grad_before_probe, grad_after_probe)
    assert np.array_equal(grad_before_probe, grad_after_unprobe)
    assert probe.value is None