!9115 fix get_next bug in explainer._runner and NaN dataset.

From: @yuhanshi
Reviewed-by: @wuxuejian, @ouwenchang
Signed-off-by: @wuxuejian
Committed-by: mindspore-ci-bot, 2020-11-28 08:48:45 +08:00 (via Gitee)
commit 627d045376
1 changed file with 27 additions and 28 deletions
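In short: both call sites that pull a sample from the dataset invoked `get_next()` on the tuple iterator, the bug named in the title; the fix switches to Python's built-in `next()`, the standard way to consume such an iterator. Independently, the benchmarker loop guarded against NaN scores with `np.any(res == np.nan)`, a comparison that is always False because NaN never compares equal to anything, and the fallback zeroed the whole result array instead of just the bad entries; the fix masks only the NaN entries via `np.isnan`. The rewrite also moves the `LabelAgnosticMetric` branch out of the per-label loop, and the docstring example is updated to a concrete CIFAR-10 workflow. The file touched is presumably mindspore/explainer/_runner.py, inferred from the commit title.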


@@ -159,13 +159,14 @@ class ExplainRunner:
             label probability distribution :math:`P(y|x)`. Default: Softmax().

     Examples:
         >>> from mindspore.explainer import ExplainRunner
         >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
+        >>> from mindspore.nn import Sigmoid
+        >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
-        >>> # obtain dataset object
-        >>> dataset = get_dataset()
-        >>> classes = ["cat", "dog", ...]
-        >>> # load checkpoint to a network, e.g. resnet50
+        >>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10
+        >>> dataset = get_dataset('/path/to/Cifar10_dataset')
+        >>> classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
+        >>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10
         >>> param_dict = load_checkpoint("checkpoint.ckpt")
         >>> net = resnet50(len(classes))
         >>> load_param_into_net(net, param_dict)
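The hunk ends before the example reaches the runner itself. For orientation, a hedged sketch of how the objects above are typically wired together; the summary directory and the call signature of `run()` are assumptions about this version of the API, not part of the diff:

    # Hypothetical continuation of the docstring example (not in this commit).
    # Build explainers over the loaded network, then hand the dataset,
    # the class names and the explainers to the runner.
    explainers = [GuidedBackprop(net), Gradient(net)]
    runner = ExplainRunner("./summary_dir")       # assumed constructor argument
    runner.run((dataset, classes), explainers)    # assumed signature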
@@ -204,7 +205,7 @@ class ExplainRunner:
         check_value_type("activation_fn", activation_fn, Cell)
         self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn])
-        next_element = dataset.create_tuple_iterator().get_next()
+        next_element = next(dataset.create_tuple_iterator())
         inputs, _, _ = self._unpack_next_element(next_element)
         prop_test = self._model(inputs)
         check_value_type("output of model in explainer", prop_test, ms.Tensor)
@@ -314,7 +315,7 @@ class ExplainRunner:
             dataset (`ds`): the user parsed dataset.
             benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers.
         """
-        next_element = dataset.create_tuple_iterator().get_next()
+        next_element = next(dataset.create_tuple_iterator())
         if len(next_element) not in [1, 2, 3]:
             raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
@@ -611,29 +612,27 @@ class ExplainRunner:
         inputs, labels, _ = self._unpack_next_element(next_element)
         for idx, inp in enumerate(inputs):
             inp = _EXPAND_DIMS(inp, 0)
-            saliency_dict = saliency_dict_lst[idx]
-            for label, saliency in saliency_dict.items():
-                if isinstance(benchmarker, Localization):
-                    _, _, bboxes = self._unpack_next_element(next_element, True)
-                    if label in labels[idx]:
-                        res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
-                                                   saliency=saliency)
-                        if np.any(res == np.nan):
-                            res = np.zeros_like(res)
-                        benchmarker.aggregate(res, label)
+            if isinstance(benchmarker, LabelAgnosticMetric):
+                res = benchmarker.evaluate(explainer, inp)
+                res[np.isnan(res)] = 0.0
+                benchmarker.aggregate(res)
+            else:
+                saliency_dict = saliency_dict_lst[idx]
+                for label, saliency in saliency_dict.items():
+                    if isinstance(benchmarker, Localization):
+                        _, _, bboxes = self._unpack_next_element(next_element, True)
+                        if label in labels[idx]:
+                            res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
+                                                       saliency=saliency)
+                            res[np.isnan(res)] = 0.0
+                            benchmarker.aggregate(res, label)
+                    elif isinstance(benchmarker, LabelSensitiveMetric):
+                        res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
+                        res[np.isnan(res)] = 0.0
+                        benchmarker.aggregate(res, label)
-                elif isinstance(benchmarker, LabelSensitiveMetric):
-                    res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
-                    if np.any(res == np.nan):
-                        res = np.zeros_like(res)
-                    benchmarker.aggregate(res, label)
-                elif isinstance(benchmarker, LabelAgnosticMetric):
-                    res = benchmarker.evaluate(explainer, inp)
-                    if np.any(res == np.nan):
-                        res = np.zeros_like(res)
-                    benchmarker.aggregate(res)
-                else:
-                    raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
-                                    'receive {}'.format(type(benchmarker)))
+                    else:
+                        raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but '
+                                        'received {}'.format(type(benchmarker)))
     def _save_original_image(self, sample_id: int, image):
         """Save an image to summary directory."""