forked from mindspore-Ecosystem/mindspore
!9115 fix get_next bug in explainer._runner and NaN dataset.
From: @yuhanshi Reviewed-by: @wuxuejian,@ouwenchang Signed-off-by: @wuxuejian
This commit is contained in:
commit
627d045376
|
@ -159,13 +159,14 @@ class ExplainRunner:
|
|||
label probability distribution :math:`P(y|x)`. Default: Softmax().
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer import ExplainRunner
|
||||
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
|
||||
>>> from mindspore.nn import Sigmoid
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # obtain dataset object
|
||||
>>> dataset = get_dataset()
|
||||
>>> classes = ["cat", "dog", ...]
|
||||
>>> # load checkpoint to a network, e.g. resnet50
|
||||
>>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10
|
||||
>>> dataset = get_dataset('/path/to/Cifar10_dataset')
|
||||
>>> classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'turck']
|
||||
>>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10
|
||||
>>> param_dict = load_checkpoint("checkpoint.ckpt")
|
||||
>>> net = resnet50(len(classes))
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
|
@ -204,7 +205,7 @@ class ExplainRunner:
|
|||
check_value_type("activation_fn", activation_fn, Cell)
|
||||
|
||||
self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn])
|
||||
next_element = dataset.create_tuple_iterator().get_next()
|
||||
next_element = next(dataset.create_tuple_iterator())
|
||||
inputs, _, _ = self._unpack_next_element(next_element)
|
||||
prop_test = self._model(inputs)
|
||||
check_value_type("output of model im explainer", prop_test, ms.Tensor)
|
||||
|
@ -314,7 +315,7 @@ class ExplainRunner:
|
|||
dataset (`ds`): the user parsed dataset.
|
||||
benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers.
|
||||
"""
|
||||
next_element = dataset.create_tuple_iterator().get_next()
|
||||
next_element = next(dataset.create_tuple_iterator())
|
||||
|
||||
if len(next_element) not in [1, 2, 3]:
|
||||
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
|
||||
|
@ -611,29 +612,27 @@ class ExplainRunner:
|
|||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
for idx, inp in enumerate(inputs):
|
||||
inp = _EXPAND_DIMS(inp, 0)
|
||||
saliency_dict = saliency_dict_lst[idx]
|
||||
for label, saliency in saliency_dict.items():
|
||||
if isinstance(benchmarker, Localization):
|
||||
_, _, bboxes = self._unpack_next_element(next_element, True)
|
||||
if label in labels[idx]:
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
|
||||
saliency=saliency)
|
||||
if np.any(res == np.nan):
|
||||
res = np.zeros_like(res)
|
||||
if isinstance(benchmarker, LabelAgnosticMetric):
|
||||
res = benchmarker.evaluate(explainer, inp)
|
||||
res[np.isnan(res)] = 0.0
|
||||
benchmarker.aggregate(res)
|
||||
else:
|
||||
saliency_dict = saliency_dict_lst[idx]
|
||||
for label, saliency in saliency_dict.items():
|
||||
if isinstance(benchmarker, Localization):
|
||||
_, _, bboxes = self._unpack_next_element(next_element, True)
|
||||
if label in labels[idx]:
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
|
||||
saliency=saliency)
|
||||
res[np.isnan(res)] = 0.0
|
||||
benchmarker.aggregate(res, label)
|
||||
elif isinstance(benchmarker, LabelSensitiveMetric):
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
|
||||
res[np.isnan(res)] = 0.0
|
||||
benchmarker.aggregate(res, label)
|
||||
elif isinstance(benchmarker, LabelSensitiveMetric):
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
|
||||
if np.any(res == np.nan):
|
||||
res = np.zeros_like(res)
|
||||
benchmarker.aggregate(res, label)
|
||||
elif isinstance(benchmarker, LabelAgnosticMetric):
|
||||
res = benchmarker.evaluate(explainer, inp)
|
||||
if np.any(res == np.nan):
|
||||
res = np.zeros_like(res)
|
||||
benchmarker.aggregate(res)
|
||||
else:
|
||||
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
|
||||
'receive {}'.format(type(benchmarker)))
|
||||
else:
|
||||
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
|
||||
'receive {}'.format(type(benchmarker)))
|
||||
|
||||
def _save_original_image(self, sample_id: int, image):
|
||||
"""Save an image to summary directory."""
|
||||
|
|
Loading…
Reference in New Issue