forked from mindspore-Ecosystem/mindspore
!7778 Fixbug: Faithfulness score NaN, add input params validity check to fix system error bug
Merge pull request !7778 from lixiaohui33/feature_explain_core
commit 25c388e01e
@@ -17,8 +17,9 @@ from time import time
 from typing import Tuple, List, Optional
 
 import numpy as np
-from mindspore.train.summary_pb2 import Explain
 
+from mindspore.train._utils import check_value_type
+from mindspore.train.summary_pb2 import Explain
 import mindspore as ms
 import mindspore.dataset as ds
 from mindspore import log
@@ -71,6 +72,7 @@ class ExplainRunner:
     """
 
     def __init__(self, summary_dir: Optional[str] = "./"):
+        check_value_type("summary_dir", summary_dir, str)
         self._summary_dir = summary_dir
         self._count = 0
         self._classes = None
@@ -123,14 +125,21 @@ class ExplainRunner:
         for exp in explainers:
             if not isinstance(exp, Attribution) or not isinstance(explainers, list):
                 raise TypeError("Argument explainers should be a list of objects of classes in "
-                                "`mindspore.explainer.explanation._attribution`.")
+                                "`mindspore.explainer.explanation`.")
         if benchmarkers is not None:
             for bench in benchmarkers:
                 if not isinstance(bench, AttributionMetric) or not isinstance(explainers, list):
                     raise TypeError("Argument benchmarkers should be a list of objects of classes in explanation"
-                                    "`mindspore.explainer.benchmark._attribution`.")
+                                    "`mindspore.explainer.benchmark`.")
 
         self._model = explainers[0].model
+        next_element = dataset.create_tuple_iterator().get_next()
+        inputs, _, _ = self._unpack_next_element(next_element)
+        prop_test = self._model(inputs)
+        check_value_type("output of model im explainer", prop_test, ms.Tensor)
+        if prop_test.shape[1] > len(self._classes):
+            raise ValueError("The dimension of model output should not exceed the length of dataset classes. Please "
+                             "check dataset classes or the black-box model in the explainer again.")
 
         with SummaryRecord(self._summary_dir) as summary:
             print("Start running and writing......")
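For context on the model-output check added above: if the black-box model emits more class scores than the dataset declares, later per-class indexing and labeling would fail, so the run is now rejected up front. A minimal standalone sketch of the same guard, using plain NumPy rather than the MindSpore API (the function and variable names here are illustrative only):

import numpy as np

def check_output_matches_classes(prob_output, classes):
    """Reject a (batch, num_scores) output that is wider than the declared class list."""
    if prob_output.ndim != 2:
        raise ValueError("Expected a 2D (batch, num_classes) model output.")
    if prob_output.shape[1] > len(classes):
        raise ValueError("The dimension of model output should not exceed the length "
                         "of dataset classes.")

probs = np.random.rand(2, 5)                                     # 2 samples scored over 5 classes
try:
    check_output_matches_classes(probs, ["cat", "dog", "bird"])  # only 3 declared classes
except ValueError as err:
    print(err)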
@@ -29,6 +29,7 @@ __all__ = [
 ]
 
 from typing import Tuple, Union
+import math
 
 import numpy as np
 from PIL import Image
@@ -204,7 +205,8 @@ def calc_correlation(x: Union[ms.Tensor, np.ndarray],
     x = format_tensor_to_ndarray(x)
     y = format_tensor_to_ndarray(y)
     faithfulness = -np.corrcoef(x, y)[0, 1]
 
+    if math.isnan(faithfulness):
+        return np.float(0)
     return faithfulness
 
 
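The NaN guard above addresses a genuine edge case: np.corrcoef divides by the standard deviations of both series, so a zero-variance input (for example, a model whose predictions do not change under perturbation) yields nan. A small sketch of that behavior and of the guard, assuming only NumPy:

import math
import numpy as np

x = np.array([0.9, 0.7, 0.5, 0.3])   # feature importance scores
y = np.full_like(x, 0.42)            # predictions that never change under perturbation

corr = -np.corrcoef(x, y)[0, 1]      # zero variance in y -> nan (NumPy also emits a RuntimeWarning)
print(corr)                          # nan

faithfulness = 0.0 if math.isnan(corr) else corr
print(faithfulness)                  # 0.0, matching the guarded return value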
@@ -232,7 +234,6 @@ def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
     >> np.array([[2, 3, 4], [1, 0, 5]])
     rank_pixels(x, descending=False)
     >> np.array([[3, 2, 0], [4, 5, 1]])
 
     """
     if len(inputs.shape) != 2:
         raise ValueError('Only support 2D array currently')
@@ -339,6 +339,7 @@ class NaiveFaithfulness(_FaithfulnessHelper):
 
         perturbations = ms.Tensor(perturbations, dtype=ms.float32)
         predictions = model(perturbations).asnumpy()[:, targets]
 
         faithfulness = calc_correlation(feature_importance, predictions)
         normalized_faithfulness = (faithfulness + 1) / 2
         return np.array([normalized_faithfulness], np.float)
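The (faithfulness + 1) / 2 step above maps the correlation from [-1, 1] onto [0, 1]; combined with the NaN guard in calc_correlation, a degenerate case (constant predictions under perturbation) now reports 0.5 instead of propagating NaN. A small numeric sketch of that pipeline, assuming only NumPy and with the guarded correlation inlined for illustration:

import math
import numpy as np

def normalized_faithfulness(feature_importance, predictions):
    # Guarded correlation (as in calc_correlation) followed by the (f + 1) / 2 normalization.
    corr = -np.corrcoef(feature_importance, predictions)[0, 1]
    if math.isnan(corr):
        corr = 0.0
    return (corr + 1) / 2

imp = np.array([0.8, 0.6, 0.4, 0.2])
print(normalized_faithfulness(imp, np.array([0.1, 0.3, 0.5, 0.7])))  # 1.0: correlation is -1, negated then normalized
print(normalized_faithfulness(imp, np.full(4, 0.5)))                 # 0.5: constant predictions hit the NaN guard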
@@ -90,7 +90,7 @@ class Localization(AttributionMetric):
         Evaluate localization on a single data sample.
 
         Args:
-            explainer (Explanation): The explainer to be evaluated, see `mindspore/explainer/explanation`.
+            explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
             inputs (Tensor): data sample. Currently only support single sample at each call.
             targets (int): target label to evaluate on.
             saliency (Tensor): A saliency tensor.
@@ -113,7 +113,7 @@ class Localization(AttributionMetric):
         >>> saliency = gradient(inputs, targets)
         >>> res = localization.evaluate(gradient, inputs, targets, saliency, mask=masks)
         """
-        self._check_evaluate_param(explainer, inputs, targets, saliency)
+        self._check_evaluate_param_with_mask(explainer, inputs, targets, saliency, mask)
 
         mask_np = format_tensor_to_ndarray(mask)[0]
 
@@ -141,6 +141,10 @@ class Localization(AttributionMetric):
 
     def _check_evaluate_param_with_mask(self, explainer, inputs, targets, saliency, mask):
         self._check_evaluate_param(explainer, inputs, targets, saliency)
-        check_value_type('mask', mask, (Tensor, np.ndarray))
+        if len(inputs.shape) != 4:
+            raise ValueError('Argument mask must be 4D Tensor')
+        if mask is None:
+            raise ValueError('To compute localization, mask must be provided.')
+        check_value_type('mask', mask, (Tensor, np.ndarray))
         if len(mask.shape) != 4 or len(mask) != len(inputs):
             raise ValueError("The input mask must be 4-dimensional (1, 1, h, w) with same length of inputs.")
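The reworked _check_evaluate_param_with_mask above expects the mask to be a 4D Tensor or ndarray of shape (1, 1, h, w) whose length matches inputs. A small sketch of preparing a binary ground-truth mask that would satisfy those shape checks, using NumPy only (the image size and region here are made up for illustration; the real call also needs an explainer, inputs, targets and saliency):

import numpy as np

h, w = 224, 224
inputs_shape = (1, 3, h, w)                 # a single NCHW image

mask = np.zeros((1, 1, h, w), dtype=np.float32)
mask[0, 0, 60:160, 60:160] = 1.0            # mark the ground-truth object region

assert len(mask.shape) == 4                 # 4-dimensional (1, 1, h, w)
assert len(mask) == inputs_shape[0]         # same length as inputs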
@@ -47,9 +47,16 @@ class AttributionMetric:
     """Super class of XAI metric class used in classification scenarios."""
 
     def __init__(self, num_labels=None):
+        self._verify_params(num_labels)
         self._num_labels = num_labels
         self._global_results = {i: [] for i in range(num_labels)}
 
+    @staticmethod
+    def _verify_params(num_labels):
+        check_value_type("num_labels", num_labels, int)
+        if num_labels < 1:
+            raise ValueError("Argument num_labels must be parsed with a integer > 0.")
+
     def evaluate(self, explainer, inputs, targets, saliency=None):
         """This function evaluates on a single sample and return the result."""
         raise NotImplementedError
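The added _verify_params rejects a missing or non-positive num_labels before the per-label result dict is built; without it, the default num_labels=None would only fail later with a bare TypeError from range(None). A standalone sketch of the same validation, with check_value_type replaced by a plain isinstance check for illustration (names are illustrative, not the MindSpore API):

def verify_num_labels(num_labels):
    # Stand-in for check_value_type("num_labels", num_labels, int).
    if not isinstance(num_labels, int):
        raise TypeError(f"num_labels should be int, but got {type(num_labels).__name__}.")
    if num_labels < 1:
        raise ValueError("Argument num_labels must be an integer > 0.")

for bad in (None, 0, "10"):
    try:
        verify_num_labels(bad)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, ":", err)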
@@ -119,5 +126,11 @@ class AttributionMetric:
         """Check the evaluate parameters."""
         check_value_type('explainer', explainer, Attribution)
         verify_argument(inputs, 'inputs')
+        output = explainer.model(inputs)
+        check_value_type("output of explainer model", output, Tensor)
+        output_dim = explainer.model(inputs).shape[1]
+        if output_dim > self._num_labels:
+            raise ValueError("The output dimension of of black-box model in explainer should not exceed the dimension "
+                             "of num_labels set in the __init__, please set num_labels larger.")
         verify_targets(targets, self._num_labels)
         check_value_type('saliency', saliency, (Tensor, type(None)))