!7778 Fixbug: Faithfulness score NaN, add input params validity check to fix system error bug

Merge pull request !7778 from lixiaohui33/feature_explain_core
This commit is contained in:
mindspore-ci-bot 2020-10-29 11:52:34 +08:00 committed by Gitee
commit 25c388e01e
5 changed files with 36 additions and 8 deletions

View File

@ -17,8 +17,9 @@ from time import time
from typing import Tuple, List, Optional
import numpy as np
from mindspore.train.summary_pb2 import Explain
from mindspore.train._utils import check_value_type
from mindspore.train.summary_pb2 import Explain
import mindspore as ms
import mindspore.dataset as ds
from mindspore import log
@ -71,6 +72,7 @@ class ExplainRunner:
"""
def __init__(self, summary_dir: Optional[str] = "./"):
check_value_type("summary_dir", summary_dir, str)
self._summary_dir = summary_dir
self._count = 0
self._classes = None
@ -123,14 +125,21 @@ class ExplainRunner:
for exp in explainers:
if not isinstance(exp, Attribution) or not isinstance(explainers, list):
raise TypeError("Argument explainers should be a list of objects of classes in "
"`mindspore.explainer.explanation._attribution`.")
"`mindspore.explainer.explanation`.")
if benchmarkers is not None:
for bench in benchmarkers:
if not isinstance(bench, AttributionMetric) or not isinstance(explainers, list):
raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation"
"`mindspore.explainer.benchmark._attribution`.")
"`mindspore.explainer.benchmark`.")
self._model = explainers[0].model
next_element = dataset.create_tuple_iterator().get_next()
inputs, _, _ = self._unpack_next_element(next_element)
prop_test = self._model(inputs)
check_value_type("output of model im explainer", prop_test, ms.Tensor)
if prop_test.shape[1] > len(self._classes):
raise ValueError("The dimension of model output should not exceed the length of dataset classes. Please "
"check dataset classes or the black-box model in the explainer again.")
with SummaryRecord(self._summary_dir) as summary:
print("Start running and writing......")

View File

@ -29,6 +29,7 @@ __all__ = [
]
from typing import Tuple, Union
import math
import numpy as np
from PIL import Image
@ -204,7 +205,8 @@ def calc_correlation(x: Union[ms.Tensor, np.ndarray],
x = format_tensor_to_ndarray(x)
y = format_tensor_to_ndarray(y)
faithfulness = -np.corrcoef(x, y)[0, 1]
if math.isnan(faithfulness):
return np.float(0)
return faithfulness
@ -232,7 +234,6 @@ def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
>> np.array([[2, 3, 4], [1, 0, 5]])
rank_pixels(x, descending=False)
>> np.array([[3, 2, 0], [4, 5, 1]])
"""
if len(inputs.shape) != 2:
raise ValueError('Only support 2D array currently')

View File

@ -339,6 +339,7 @@ class NaiveFaithfulness(_FaithfulnessHelper):
perturbations = ms.Tensor(perturbations, dtype=ms.float32)
predictions = model(perturbations).asnumpy()[:, targets]
faithfulness = calc_correlation(feature_importance, predictions)
normalized_faithfulness = (faithfulness + 1) / 2
return np.array([normalized_faithfulness], np.float)

View File

@ -90,7 +90,7 @@ class Localization(AttributionMetric):
Evaluate localization on a single data sample.
Args:
explainer (Explanation): The explainer to be evaluated, see `mindspore/explainer/explanation`.
explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
inputs (Tensor): data sample. Currently only support single sample at each call.
targets (int): target label to evaluate on.
saliency (Tensor): A saliency tensor.
@ -113,7 +113,7 @@ class Localization(AttributionMetric):
>>> saliency = gradient(inputs, targets)
>>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks)
"""
self._check_evaluate_param(explainer, inputs, targets, saliency)
self._check_evaluate_param_with_mask(explainer, inputs, targets, saliency, mask)
mask_np = format_tensor_to_ndarray(mask)[0]
@ -141,6 +141,10 @@ class Localization(AttributionMetric):
def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask):
self._check_evaluate_param(explainer, inputs, targets, saliency)
check_value_type('mask', mask, (Tensor, np.ndarray))
if len(inputs.shape) != 4:
raise ValueError('Argument mask must be 4D Tensor')
if mask is None:
raise ValueError('To compute localization, mask must be provided.')
check_value_type('mask', mask, (Tensor, np.ndarray))
if len(mask.shape) != 4 or len(mask) != len(inputs):
raise ValueError("The input mask must be 4-dimensional (1, 1, h, w) with same length of inputs.")

View File

@ -47,9 +47,16 @@ class AttributionMetric:
"""Super class of XAI metric class used in classification scenarios."""
def __init__(self, num_labels=None):
self._verify_params(num_labels)
self._num_labels = num_labels
self._global_results = {i: [] for i in range(num_labels)}
@staticmethod
def _verify_params(num_labels):
check_value_type("num_labels", num_labels, int)
if num_labels < 1:
raise ValueError("Argument num_labels must be parsed with a integer > 0.")
def evaluate(self, explainer, inputs, targets, saliency=None):
"""This function evaluates on a single sample and return the result."""
raise NotImplementedError
@ -119,5 +126,11 @@ class AttributionMetric:
"""Check the evaluate parameters."""
check_value_type('explainer', explainer, Attribution)
verify_argument(inputs, 'inputs')
output = explainer.model(inputs)
check_value_type("output of explainer model", output, Tensor)
output_dim = explainer.model(inputs).shape[1]
if output_dim > self._num_labels:
raise ValueError("The output dimension of of black-box model in explainer should not exceed the dimension "
"of num_labels set in the __init__, please set num_labels larger.")
verify_targets(targets, self._num_labels)
check_value_type('saliency', saliency, (Tensor, type(None)))