forked from mindspore-Ecosystem/mindspore
!9275 Reform explain runner api
From: @ngtony Reviewed-by: Signed-off-by:
commit 4b539980bf
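Summary of the API change: the old `ExplainRunner(summary_dir)` + `run((dataset, classes), explainers, ...)` entry point is replaced by `ImageClassificationRunner`, which takes the data, network and activation function at construction time and exposes `register_saliency()` plus a no-argument `run()`. A minimal usage sketch based on the docstring example in this commit (the `get_dataset` and `resnet50` helpers are placeholders from that example, not part of this change):

from mindspore.explainer import ImageClassificationRunner
from mindspore.explainer.explanation import GuidedBackprop, Gradient
from mindspore.explainer.benchmark import Faithfulness
from mindspore.nn import Softmax
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Placeholder helpers from the docstring example: supply your own dataset pipeline and network constructor.
dataset = get_dataset('/path/to/Cifar10_dataset')
labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
net = resnet50(len(labels))
load_param_into_net(net, load_checkpoint("checkpoint.ckpt"))
activation_fn = Softmax()

# The explainers must wrap the exact same network instance that is passed to the runner.
explainers = [GuidedBackprop(net), Gradient(net)]
benchmarkers = [Faithfulness(len(labels), "NaiveFaithfulness", activation_fn)]

runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn)
runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers)
runner.run()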
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Provide ExplainRunner High-level API."""
+"""Provides explanation runner high-level APIs."""
 
-from ._runner import ExplainRunner
+from ._image_classification_runner import ImageClassificationRunner
 
-__all__ = ['ExplainRunner']
+__all__ = ['ImageClassificationRunner']

@@ -0,0 +1,699 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Image Classification Runner."""
|
||||
import os
|
||||
import re
|
||||
from time import time
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log
|
||||
from mindspore.dataset.engine.datasets import Dataset
|
||||
from mindspore.nn import Cell, SequentialCell
|
||||
from mindspore.ops.operations import ExpandDims
|
||||
from mindspore.train._utils import check_value_type
|
||||
from mindspore.train.summary._summary_adapter import _convert_image_format
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from mindspore.train.summary_pb2 import Explain
|
||||
from .benchmark import Localization
|
||||
from .explanation import RISE
|
||||
from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric
|
||||
from .explanation._attribution.attribution import Attribution
|
||||
|
||||
_EXPAND_DIMS = ExpandDims()
|
||||
|
||||
|
||||
def _normalize(img_np):
|
||||
"""Normalize the numpy image to the range of [0, 1]. """
|
||||
max_ = img_np.max()
|
||||
min_ = img_np.min()
|
||||
normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
|
||||
return normed
|
||||
|
||||
|
||||
def _np_to_image(img_np, mode):
|
||||
"""Convert numpy array to PIL image."""
|
||||
return Image.fromarray(np.uint8(img_np * 255), mode=mode)
|
||||
|
||||
|
||||
class ImageClassificationRunner:
|
||||
"""
|
||||
A high-level API for users to generate and store results of the explanation methods and the evaluation methods.
|
||||
|
||||
Update in 2020.11: Adjusted the storage structure and format of the data. Summary files generated by previous
versions are deprecated and will not be supported by the current version of MindInsight.
|
||||
|
||||
Args:
|
||||
summary_dir (str): The directory path to save the summary files which store the generated results.
|
||||
data (tuple[Dataset, list[str]]): Tuple of the dataset and the corresponding class label list. The dataset
should provide [images], [images, labels] or [images, labels, bboxes] as columns. The label list must
have exactly the same length and order as the network outputs.
|
||||
network (Cell): The network (with logit outputs) to be explained.
|
||||
activation_fn (Cell): The activation function for converting network's output to probabilities.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer import ImageClassificationRunner
|
||||
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
|
||||
>>> from mindspore.explainer.benchmark import Faithfulness
|
||||
>>> from mindspore.nn import Softmax
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10
|
||||
>>> dataset = get_dataset('/path/to/Cifar10_dataset')
|
||||
>>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
|
||||
>>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10
|
||||
>>> param_dict = load_checkpoint("checkpoint.ckpt")
|
||||
>>> net = resnet50(len(labels))
|
||||
>>> activation_fn = Softmax()
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> gbp = GuidedBackprop(net)
|
||||
>>> gradient = Gradient(net)
|
||||
>>> explainers = [gbp, gradient]
|
||||
>>> faithfulness = Faithfulness(len(labels), "NaiveFaithfulness", activation_fn)
|
||||
>>> benchmarkers = [faithfulness]
|
||||
>>> runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn)
|
||||
>>> runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers)
|
||||
>>> runner.run()
|
||||
"""
|
||||
|
||||
# datafile directory names
|
||||
_DATAFILE_DIRNAME_PREFIX = "_explain_"
|
||||
_ORIGINAL_IMAGE_DIRNAME = "origin_images"
|
||||
_HEATMAP_DIRNAME = "heatmap"
|
||||
# max. number of samples per directory
|
||||
_SAMPLE_PER_DIR = 1000
|
||||
# seed for fixing the iterating order of the dataset
|
||||
_DATASET_SEED = 58
|
||||
# printing spacer
|
||||
_SPACER = "{:120}\r"
|
||||
# file permission for writing files
|
||||
_FILE_MODE = 0o600
|
||||
|
||||
def __init__(self,
|
||||
summary_dir,
|
||||
data,
|
||||
network,
|
||||
activation_fn):
|
||||
|
||||
check_value_type("data", data, tuple)
|
||||
if len(data) != 2:
|
||||
raise ValueError("Argument data is not a tuple with 2 elements")
|
||||
check_value_type("data[0]", data[0], Dataset)
|
||||
check_value_type("data[1]", data[1], list)
|
||||
if not all(isinstance(ele, str) for ele in data[1]):
|
||||
raise ValueError("Argument data[1] is not list of str.")
|
||||
|
||||
check_value_type("summary_dir", summary_dir, str)
|
||||
check_value_type("network", network, Cell)
|
||||
check_value_type("activation_fn", activation_fn, Cell)
|
||||
|
||||
self._summary_dir = summary_dir
|
||||
self._dataset = data[0]
|
||||
self._labels = data[1]
|
||||
self._network = network
|
||||
self._explainers = None
|
||||
self._benchmarkers = None
|
||||
self._summary_timestamp = None
|
||||
self._sample_index = -1
|
||||
|
||||
self._full_network = SequentialCell([self._network, activation_fn])
|
||||
|
||||
self._verify_data_n_settings(check_data_n_network=True)
|
||||
|
||||
def register_saliency(self,
|
||||
explainers,
|
||||
benchmarkers=None):
|
||||
"""
|
||||
Register saliency explanation instances.
|
||||
|
||||
Note:
|
||||
This function can not be invoked more than once on each runner.
|
||||
|
||||
Args:
|
||||
explainers (list[Attribution]): The explainers to be evaluated,
see `mindspore.explainer.explanation`. All explainers' classes must be distinct and their networks
must be the exact same instance as the runner's network.
|
||||
benchmarkers (list[AttributionMetric], optional): The benchmarkers for scoring the explainers,
see `mindspore.explainer.benchmark`. All benchmarkers' classes must be distinct.
|
||||
|
||||
Raises:
|
||||
ValueError: Raised for any problem with the values of the data or settings.
TypeError: Raised for any problem with the types of the data or settings.
RuntimeError: Raised if this function has already been invoked.
|
||||
"""
|
||||
check_value_type("explainers", explainers, list)
|
||||
if not all(isinstance(ele, Attribution) for ele in explainers):
|
||||
raise TypeError("Argument explainers is not list of mindspore.explainer.explanation .")
|
||||
|
||||
if not explainers:
|
||||
raise ValueError("Argument explainers is empty.")
|
||||
|
||||
if benchmarkers:
|
||||
check_value_type("benchmarkers", benchmarkers, list)
|
||||
if not all(isinstance(ele, AttributionMetric) for ele in benchmarkers):
|
||||
raise TypeError("Argument benchmarkers is not list of mindspore.explainer.benchmark .")
|
||||
|
||||
if self._explainers is not None:
|
||||
raise RuntimeError("Function register_saliency() was invoked already.")
|
||||
|
||||
self._explainers = explainers
|
||||
self._benchmarkers = benchmarkers
|
||||
|
||||
try:
|
||||
self._verify_data_n_settings(check_saliency=True)
|
||||
except (ValueError, TypeError):
|
||||
self._explainers = None
|
||||
self._benchmarkers = None
|
||||
raise
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run the explain job and save the result as a summary in summary_dir.
|
||||
|
||||
Note:
|
||||
Users should call register_saliency() once before running this function.
|
||||
|
||||
Raises:
|
||||
ValueError: Raised for any problem with the values of the data or settings.
TypeError: Raised for any problem with the types of the data or settings.
RuntimeError: Raised for any runtime problem.
|
||||
"""
|
||||
self._verify_data_n_settings(check_all=True)
|
||||
|
||||
with SummaryRecord(self._summary_dir) as summary:
|
||||
print("Start running and writing......")
|
||||
begin = time()
|
||||
|
||||
self._summary_timestamp = self._extract_timestamp(summary.event_file_name)
|
||||
if self._summary_timestamp is None:
|
||||
raise RuntimeError("Cannot extract timestamp from summary filename!"
|
||||
" It should contains a timestamp after 'summary.' .")
|
||||
|
||||
self._save_metadata(summary)
|
||||
|
||||
imageid_labels = self._run_inference(summary)
|
||||
if self._is_saliency_registered:
|
||||
self._run_saliency(summary, imageid_labels)
|
||||
|
||||
print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))
|
||||
|
||||
@property
|
||||
def _is_saliency_registered(self):
|
||||
"""Check if saliency module is registered."""
|
||||
return bool(self._explainers)
|
||||
|
||||
def _save_metadata(self, summary):
|
||||
"""Save metadata of the explain job to summary."""
|
||||
print("Start writing metadata......")
|
||||
|
||||
explain = Explain()
|
||||
explain.metadata.label.extend(self._labels)
|
||||
|
||||
if self._is_saliency_registered:
|
||||
exp_names = [exp.__class__.__name__ for exp in self._explainers]
|
||||
explain.metadata.explain_method.extend(exp_names)
|
||||
if self._benchmarkers is not None:
|
||||
bench_names = [bench.__class__.__name__ for bench in self._benchmarkers]
|
||||
explain.metadata.benchmark_method.extend(bench_names)
|
||||
|
||||
summary.add_value("explainer", "metadata", explain)
|
||||
summary.record(1)
|
||||
|
||||
print("Finish writing metadata.")
|
||||
|
||||
def _run_inference(self, summary, threshold=0.5):
|
||||
"""
|
||||
Run inference for the dataset and write the inference related data into summary.
|
||||
|
||||
Args:
|
||||
summary (SummaryRecord): The summary object to store the data
|
||||
threshold (float): The threshold for prediction.
|
||||
|
||||
Returns:
|
||||
dict, the map of sample id to the union of its ground truth and predicted labels.
|
||||
"""
|
||||
sample_id_labels = {}
|
||||
self._sample_index = 0
|
||||
ds.config.set_seed(self._DATASET_SEED)
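# The fixed seed keeps the dataset iteration order identical across the inference and saliency passes, so sample ids stay aligned.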
|
||||
for j, next_element in enumerate(self._dataset):
|
||||
now = time()
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
prob = self._full_network(inputs).asnumpy()
|
||||
|
||||
for idx, inp in enumerate(inputs):
|
||||
gt_labels = labels[idx]
|
||||
gt_probs = [float(prob[idx][i]) for i in gt_labels]
|
||||
|
||||
data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
|
||||
original_image = _np_to_image(_normalize(data_np), mode='RGB')
|
||||
original_image_path = self._save_original_image(self._sample_index, original_image)
|
||||
|
||||
predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
|
||||
predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
|
||||
|
||||
union_labs = list(set(gt_labels + predicted_labels))
|
||||
sample_id_labels[str(self._sample_index)] = union_labs
|
||||
|
||||
explain = Explain()
|
||||
explain.sample_id = self._sample_index
|
||||
explain.image_path = original_image_path
|
||||
summary.add_value("explainer", "sample", explain)
|
||||
|
||||
explain = Explain()
|
||||
explain.sample_id = self._sample_index
|
||||
explain.ground_truth_label.extend(gt_labels)
|
||||
explain.inference.ground_truth_prob.extend(gt_probs)
|
||||
explain.inference.predicted_label.extend(predicted_labels)
|
||||
explain.inference.predicted_prob.extend(predicted_probs)
|
||||
|
||||
summary.add_value("explainer", "inference", explain)
|
||||
|
||||
summary.record(1)
|
||||
|
||||
self._sample_index += 1
|
||||
self._spaced_print("Finish running and writing {}-th batch inference data."
|
||||
" Time elapsed: {:.3f} s".format(j, time() - now),
|
||||
end='')
|
||||
return sample_id_labels
|
||||
|
||||
def _run_saliency(self, summary, sample_id_labels):
|
||||
"""Run the saliency explanations."""
|
||||
if self._benchmarkers is None or not self._benchmarkers:
|
||||
for exp in self._explainers:
|
||||
start = time()
|
||||
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
|
||||
self._sample_index = 0
|
||||
ds.config.set_seed(self._DATASET_SEED)
|
||||
for idx, next_element in enumerate(self._dataset):
|
||||
now = time()
|
||||
self._run_exp_step(next_element, exp, sample_id_labels, summary)
|
||||
self._spaced_print("Finish writing {}-th explanation data for {}. Time elapsed: "
|
||||
"{:.3f} s".format(idx, exp.__class__.__name__, time() - now), end='')
|
||||
self._spaced_print(
|
||||
"Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
|
||||
exp.__class__.__name__, time() - start))
|
||||
else:
|
||||
for exp in self._explainers:
|
||||
explain = Explain()
|
||||
for bench in self._benchmarkers:
|
||||
bench.reset()
|
||||
print(f"Start running and writing explanation and "
|
||||
f"benchmark data for {exp.__class__.__name__}......")
|
||||
self._sample_index = 0
|
||||
start = time()
|
||||
ds.config.set_seed(self._DATASET_SEED)
|
||||
for idx, next_element in enumerate(self._dataset):
|
||||
now = time()
|
||||
saliency_dict_lst = self._run_exp_step(next_element, exp, sample_id_labels, summary)
|
||||
self._spaced_print(
|
||||
"Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
|
||||
idx, exp.__class__.__name__, time() - now), end='')
|
||||
for bench in self._benchmarkers:
|
||||
now = time()
|
||||
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
|
||||
self._spaced_print(
|
||||
"Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
|
||||
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now), end='')
|
||||
|
||||
for bench in self._benchmarkers:
|
||||
benchmark = explain.benchmark.add()
|
||||
benchmark.explain_method = exp.__class__.__name__
|
||||
benchmark.benchmark_method = bench.__class__.__name__
|
||||
|
||||
benchmark.total_score = bench.performance
|
||||
if isinstance(bench, LabelSensitiveMetric):
|
||||
benchmark.label_score.extend(bench.class_performances)
|
||||
|
||||
self._spaced_print("Finish running and writing explanation and benchmark data for {}. "
|
||||
"Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start))
|
||||
summary.add_value('explainer', 'benchmark', explain)
|
||||
summary.record(1)
|
||||
|
||||
def _run_exp_step(self, next_element, explainer, sample_id_labels, summary):
|
||||
"""
|
||||
Run the explanation for each step and write explanation results into summary.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): Data of one step
|
||||
explainer (_Attribution): An Attribution object to generate saliency maps.
|
||||
sample_id_labels (dict): A dict that maps each sample id to its union labels.
|
||||
summary (SummaryRecord): The summary object to store the data
|
||||
|
||||
Returns:
|
||||
list, list of dicts that map each label to its corresponding saliency map.
|
||||
"""
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
sample_index = self._sample_index
|
||||
unions = []
|
||||
for _ in range(len(labels)):
|
||||
unions_labels = sample_id_labels[str(sample_index)]
|
||||
unions.append(unions_labels)
|
||||
sample_index += 1
|
||||
|
||||
batch_unions = self._make_label_batch(unions)
|
||||
saliency_dict_lst = []
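# RISE generates saliency maps for all requested labels in a single call; other explainers are invoked once per label and the per-label maps are concatenated along axis 1.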
|
||||
|
||||
if isinstance(explainer, RISE):
|
||||
batch_saliency_full = explainer(inputs, batch_unions)
|
||||
else:
|
||||
batch_saliency_full = []
|
||||
for i in range(len(batch_unions[0])):
|
||||
batch_saliency = explainer(inputs, batch_unions[:, i])
|
||||
batch_saliency_full.append(batch_saliency)
|
||||
concat = ms.ops.operations.Concat(1)
|
||||
batch_saliency_full = concat(tuple(batch_saliency_full))
|
||||
|
||||
for idx, union in enumerate(unions):
|
||||
saliency_dict = {}
|
||||
explain = Explain()
|
||||
explain.sample_id = self._sample_index
|
||||
for k, lab in enumerate(union):
|
||||
saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
|
||||
saliency_dict[lab] = saliency
|
||||
|
||||
saliency_np = _normalize(saliency.asnumpy().squeeze())
|
||||
saliency_image = _np_to_image(saliency_np, mode='L')
|
||||
heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._sample_index, saliency_image)
|
||||
|
||||
explanation = explain.explanation.add()
|
||||
explanation.explain_method = explainer.__class__.__name__
|
||||
explanation.heatmap_path = heatmap_path
|
||||
explanation.label = lab
|
||||
|
||||
summary.add_value("explainer", "explanation", explain)
|
||||
summary.record(1)
|
||||
|
||||
self._sample_index += 1
|
||||
saliency_dict_lst.append(saliency_dict)
|
||||
return saliency_dict_lst
|
||||
|
||||
def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
|
||||
"""Run the explanation and evaluation for each step and write explanation results into summary."""
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
for idx, inp in enumerate(inputs):
|
||||
inp = _EXPAND_DIMS(inp, 0)
|
||||
saliency_dict = saliency_dict_lst[idx]
|
||||
for label, saliency in saliency_dict.items():
|
||||
if isinstance(benchmarker, Localization):
|
||||
_, _, bboxes = self._unpack_next_element(next_element, True)
|
||||
if label in labels[idx]:
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
|
||||
saliency=saliency)
|
||||
if np.any(np.isnan(res)):
|
||||
res = np.zeros_like(res)
|
||||
benchmarker.aggregate(res, label)
|
||||
elif isinstance(benchmarker, LabelSensitiveMetric):
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
|
||||
if np.any(np.isnan(res)):
|
||||
res = np.zeros_like(res)
|
||||
benchmarker.aggregate(res, label)
|
||||
elif isinstance(benchmarker, LabelAgnosticMetric):
|
||||
res = benchmarker.evaluate(explainer, inp)
|
||||
if np.any(np.isnan(res)):
|
||||
res = np.zeros_like(res)
|
||||
benchmarker.aggregate(res)
|
||||
else:
|
||||
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
' received {}'.format(type(benchmarker)))
|
||||
|
||||
def _verify_data(self):
|
||||
"""Verify dataset and labels."""
|
||||
next_element = next(self._dataset.create_tuple_iterator())
|
||||
|
||||
if len(next_element) not in [1, 2, 3]:
|
||||
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
|
||||
" as columns.")
|
||||
|
||||
if len(next_element) == 3:
|
||||
inputs, labels, bboxes = next_element
|
||||
if bboxes.shape[-1] != 4:
|
||||
raise ValueError("The third element of dataset should be bounding boxes with shape of "
|
||||
"[batch_size, num_ground_truth, 4].")
|
||||
else:
|
||||
if self._benchmarkers is not None:
|
||||
if any([isinstance(bench, Localization) for bench in self._benchmarkers]):
|
||||
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
|
||||
|
||||
if len(next_element) == 2:
|
||||
inputs, labels = next_element
|
||||
if len(next_element) == 1:
|
||||
inputs = next_element[0]
|
||||
|
||||
if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
|
||||
raise ValueError(
|
||||
"Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
|
||||
inputs.shape))
|
||||
if len(inputs.shape) == 3:
|
||||
log.warning(
|
||||
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
|
||||
" dimension as batch data.".format(inputs.shape))
|
||||
if len(next_element) > 1:
|
||||
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
|
||||
raise ValueError(
|
||||
"Labels shape {} is unrecognizable: outputs should not have more than two dimensions"
|
||||
" with length greater than 1.".format(labels.shape))
|
||||
|
||||
def _verify_network(self):
|
||||
"""Verify the network."""
|
||||
label_set = set()
|
||||
for i, label in enumerate(self._labels):
|
||||
if label.strip() == "":
|
||||
raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is "
|
||||
f"no empty label.")
|
||||
if label in label_set:
|
||||
raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.")
|
||||
label_set.add(label)
|
||||
|
||||
next_element = next(self._dataset.create_tuple_iterator())
|
||||
inputs, _, _ = self._unpack_next_element(next_element)
|
||||
prop_test = self._full_network(inputs)
|
||||
check_value_type("output of network in explainer", prop_test, ms.Tensor)
|
||||
if prop_test.shape[1] != len(self._labels):
|
||||
raise ValueError("The dimension of network output does not match the no. of classes. Please "
|
||||
"check labels or the network in the explainer again.")
|
||||
|
||||
def _verify_saliency(self):
|
||||
"""Verify the saliency settings."""
|
||||
if self._explainers:
|
||||
explainer_classes = []
|
||||
for explainer in self._explainers:
|
||||
if explainer.__class__ in explainer_classes:
|
||||
raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! "
|
||||
"Please make sure all explainers' class is distinct.")
|
||||
if explainer.model != self._network:
|
||||
raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different "
|
||||
"instance from network of runner. Please make sure they are the same "
|
||||
"instance.")
|
||||
explainer_classes.append(explainer.__class__)
|
||||
if self._benchmarkers:
|
||||
benchmarker_classes = []
|
||||
for benchmarker in self._benchmarkers:
|
||||
if benchmarker.__class__ in benchmarker_classes:
|
||||
raise ValueError(f"Repeated {benchmarker.__class__.__name__} benchmarker! "
|
||||
"Please make sure all benchmarkers' class is distinct.")
|
||||
if isinstance(benchmarker, LabelSensitiveMetric) and benchmarker.num_labels != len(self._labels):
|
||||
raise ValueError(f"The num_labels of {benchmarker.__class__.__name__} benchmarker is different "
|
||||
"from no. of labels of runner. Please make them are the same.")
|
||||
benchmarker_classes.append(benchmarker.__class__)
|
||||
|
||||
def _verify_data_n_settings(self,
|
||||
check_all=False,
|
||||
check_registration=False,
|
||||
check_data_n_network=False,
|
||||
check_saliency=False):
|
||||
"""
|
||||
Verify the validity of dataset and other settings.
|
||||
|
||||
Args:
|
||||
check_all (bool): Set it True for checking everything.
|
||||
check_registration (bool): Set it True to check whether the registrations are sufficient to invoke run().
|
||||
check_data_n_network (bool): Set it True for checking data and network.
|
||||
check_saliency (bool): Set it True for checking saliency related settings.
|
||||
|
||||
Raises:
|
||||
ValueError: Be raised for any data or settings' value problem.
|
||||
TypeError: Be raised for any data or settings' type problem.
|
||||
"""
|
||||
if check_all:
|
||||
check_registration = True
|
||||
check_data_n_network = True
|
||||
check_saliency = True
|
||||
|
||||
if check_registration:
|
||||
if not self._is_saliency_registered:
|
||||
raise ValueError("No explanation module was registered, user should at least call register_saliency()"
|
||||
" once with proper explanation instances")
|
||||
|
||||
if check_data_n_network or check_saliency:
|
||||
self._verify_data()
|
||||
|
||||
if check_data_n_network:
|
||||
self._verify_network()
|
||||
|
||||
if check_saliency:
|
||||
self._verify_saliency()
|
||||
|
||||
def _transform_data(self, inputs, labels, bboxes, ifbbox):
|
||||
"""
|
||||
Transform the data from one iteration of the dataset to a unified form for the follow-up operations.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): the image data
|
||||
labels (Tensor): the labels
|
||||
bboxes (Tensor): the bounding boxes data
ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t
label id will be returned. If False, the returned bboxes are the parsed bboxes.
|
||||
|
||||
Returns:
|
||||
inputs (Tensor): the image data, unified to a 4D Tensor.
|
||||
labels (list[list[int]]): the ground truth labels.
|
||||
bboxes (Union[list[dict], None, Tensor]): the bounding boxes
|
||||
"""
|
||||
inputs = ms.Tensor(inputs, ms.float32)
|
||||
if len(inputs.shape) == 3:
|
||||
inputs = _EXPAND_DIMS(inputs, 0)
|
||||
if isinstance(labels, ms.Tensor):
|
||||
labels = ms.Tensor(labels, ms.int32)
|
||||
labels = _EXPAND_DIMS(labels, 0)
|
||||
if isinstance(bboxes, ms.Tensor):
|
||||
bboxes = ms.Tensor(bboxes, ms.int32)
|
||||
bboxes = _EXPAND_DIMS(bboxes, 0)
|
||||
|
||||
input_len = len(inputs)
|
||||
if bboxes is not None and ifbbox:
|
||||
bboxes = ms.Tensor(bboxes, ms.int32)
|
||||
masks_lst = []
|
||||
labels = labels.asnumpy().reshape([input_len, -1])
|
||||
bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
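# Convert each ground-truth bbox into a binary mask (1 inside the box) keyed by its label id; the Localization benchmark consumes these masks as ground truth.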
|
||||
for idx, label in enumerate(labels):
|
||||
height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
|
||||
masks = {}
|
||||
for j, label_item in enumerate(label):
|
||||
target = int(label_item)
|
||||
if -1 < target < len(self._labels):
|
||||
if target not in masks:
|
||||
mask = np.zeros((1, 1, height, width))
|
||||
else:
|
||||
mask = masks[target]
|
||||
x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
|
||||
mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
|
||||
masks[target] = mask
|
||||
|
||||
masks_lst.append(masks)
|
||||
bboxes = masks_lst
|
||||
|
||||
labels = ms.Tensor(labels, ms.int32)
|
||||
if len(labels.shape) == 1:
|
||||
labels_lst = [[int(i)] for i in labels.asnumpy()]
|
||||
else:
|
||||
labels = labels.asnumpy().reshape([input_len, -1])
|
||||
labels_lst = []
|
||||
for item in labels:
|
||||
labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._labels))))
|
||||
labels = labels_lst
|
||||
return inputs, labels, bboxes
|
||||
|
||||
def _unpack_next_element(self, next_element, ifbbox=False):
|
||||
"""
|
||||
Unpack a single iteration of dataset.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): a single element iterated from dataset object.
|
||||
ifbbox (bool): whether to preprocess bboxes in self._transform_data.
|
||||
|
||||
Returns:
|
||||
tuple, a unified tuple containing image_data, labels, and bounding boxes.
|
||||
"""
|
||||
if len(next_element) == 3:
|
||||
inputs, labels, bboxes = next_element
|
||||
elif len(next_element) == 2:
|
||||
inputs, labels = next_element
|
||||
bboxes = None
|
||||
else:
|
||||
inputs = next_element[0]
|
||||
labels = [[] for _ in inputs]
|
||||
bboxes = None
|
||||
inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
|
||||
return inputs, labels, bboxes
|
||||
|
||||
@staticmethod
|
||||
def _make_label_batch(labels):
|
||||
"""
|
||||
Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
|
||||
length of all the rows in labels.
|
||||
|
||||
Args:
|
||||
labels (List[List]): the union labels of a data batch.
|
||||
|
||||
Returns:
|
||||
2D Tensor.
|
||||
"""
|
||||
max_len = max([len(label) for label in labels])
|
||||
batch_labels = np.zeros((len(labels), max_len))
|
||||
|
||||
for idx, _ in enumerate(batch_labels):
|
||||
length = len(labels[idx])
|
||||
batch_labels[idx, :length] = np.array(labels[idx])
|
||||
|
||||
return ms.Tensor(batch_labels, ms.int32)
|
||||
|
||||
def _save_original_image(self, sample_id, image):
|
||||
"""Save an image to summary directory."""
|
||||
id_dirname = self._get_sample_dirname(sample_id)
|
||||
relative_dir = os.path.join(self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
|
||||
self._ORIGINAL_IMAGE_DIRNAME,
|
||||
id_dirname)
|
||||
abs_dir_path = os.path.abspath(os.path.join(self._summary_dir, relative_dir))
|
||||
os.makedirs(abs_dir_path, mode=self._FILE_MODE, exist_ok=True)
|
||||
filename = f"{sample_id}.jpg"
|
||||
save_path = os.path.join(abs_dir_path, filename)
|
||||
image.save(save_path)
|
||||
os.chmod(save_path, self._FILE_MODE)
|
||||
return os.path.join(relative_dir, filename)
|
||||
|
||||
def _save_heatmap(self, explain_method, class_id, sample_id, image):
|
||||
"""Save heatmap image to summary directory."""
|
||||
id_dirname = self._get_sample_dirname(sample_id)
|
||||
relative_dir = os.path.join(self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
|
||||
self._HEATMAP_DIRNAME,
|
||||
explain_method,
|
||||
id_dirname)
|
||||
abs_dir_path = os.path.abspath(os.path.join(self._summary_dir, relative_dir))
|
||||
os.makedirs(abs_dir_path, mode=self._FILE_MODE, exist_ok=True)
|
||||
filename = f"{sample_id}_{class_id}.jpg"
|
||||
save_path = os.path.join(abs_dir_path, filename)
|
||||
image.save(save_path)
|
||||
os.chmod(save_path, self._FILE_MODE)
|
||||
return os.path.join(relative_dir, filename)
|
||||
|
||||
@classmethod
|
||||
def _get_sample_dirname(cls, sample_id):
|
||||
Get the name of the parent directory for the given sample id.
|
||||
return str(int(sample_id / cls._SAMPLE_PER_DIR) * cls._SAMPLE_PER_DIR)
|
||||
|
||||
@staticmethod
|
||||
def _extract_timestamp(filename):
|
||||
"""Extract timestamp from summary filename."""
|
||||
matched = re.search(r"summary\.(\d+)", filename)
|
||||
if matched:
|
||||
return int(matched.group(1))
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _spaced_print(cls, message, *args, **kwargs):
|
||||
"""Spaced message printing."""
|
||||
print(cls._SPACER.format(message), *args, **kwargs)
|
|
@@ -1,662 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Runner."""
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from time import time
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from scipy.stats import beta
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log
|
||||
from mindspore.nn import Softmax, Cell
|
||||
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
|
||||
from mindspore.ops.operations import ExpandDims
|
||||
from mindspore.train._utils import check_value_type
|
||||
from mindspore.train.summary._summary_adapter import _convert_image_format
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from mindspore.train.summary_pb2 import Explain
|
||||
from .benchmark import Localization
|
||||
from .explanation import RISE
|
||||
from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric
|
||||
from .explanation._attribution.attribution import Attribution
|
||||
|
||||
# datafile directory names
|
||||
_DATAFILE_DIRNAME_PREFIX = "_explain_"
|
||||
_ORIGINAL_IMAGE_DIRNAME = "origin_images"
|
||||
_HEATMAP_DIRNAME = "heatmap"
|
||||
# max. no. of sample per directory
|
||||
_SAMPLE_PER_DIR = 1000
|
||||
|
||||
_EXPAND_DIMS = ExpandDims()
|
||||
_SEED = 58 # set a seed to fix the iterating order of the dataset
|
||||
|
||||
|
||||
def _normalize(img_np):
|
||||
"""Normalize the numpy image to the range of [0, 1]. """
|
||||
max_ = img_np.max()
|
||||
min_ = img_np.min()
|
||||
normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
|
||||
return normed
|
||||
|
||||
|
||||
def _np_to_image(img_np, mode):
|
||||
"""Convert numpy array to PIL image."""
|
||||
return Image.fromarray(np.uint8(img_np * 255), mode=mode)
|
||||
|
||||
|
||||
def _calc_prob_interval(volume, probs, prob_vars):
|
||||
"""Compute the confidence interval of probability."""
|
||||
if not isinstance(probs, np.ndarray):
|
||||
probs = np.asarray(probs)
|
||||
if not isinstance(prob_vars, np.ndarray):
|
||||
prob_vars = np.asarray(prob_vars)
|
||||
one_minus_probs = 1 - probs
|
||||
alpha_coef = (np.square(probs) * one_minus_probs / prob_vars) - probs
|
||||
beta_coef = alpha_coef * one_minus_probs / probs
|
||||
intervals = beta.interval(volume, alpha_coef, beta_coef)
|
||||
|
||||
# avoid invalid result due to extreme small value of prob_vars
|
||||
lows = []
|
||||
highs = []
|
||||
for i, low in enumerate(intervals[0]):
|
||||
high = intervals[1][i]
|
||||
if prob_vars[i] <= 0 or \
|
||||
not np.isfinite(low) or low > probs[i] or \
|
||||
not np.isfinite(high) or high < probs[i]:
|
||||
low = probs[i]
|
||||
high = probs[i]
|
||||
lows.append(low)
|
||||
highs.append(high)
|
||||
|
||||
return lows, highs
|
||||
|
||||
|
||||
def _get_id_dirname(sample_id: int):
|
||||
"""Get the name of parent directory of the image id."""
|
||||
return str(int(sample_id / _SAMPLE_PER_DIR) * _SAMPLE_PER_DIR)
|
||||
|
||||
|
||||
def _extract_timestamp(filename: str):
|
||||
"""Extract timestamp from summary filename."""
|
||||
matched = re.search(r"summary\.(\d+)", filename)
|
||||
if matched:
|
||||
return int(matched.group(1))
|
||||
return None
|
||||
|
||||
|
||||
class ExplainRunner:
|
||||
"""
|
||||
A high-level API for users to generate and store results of the explanation methods and the evaluation methods.
|
||||
|
||||
After generating results with the explanation methods and the evaluation methods, the results will be written into
|
||||
a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight.
|
||||
|
||||
Update in 2020.11: Adjust the storage structure and format of the data. Summary files generated by previous version
|
||||
will be deprecated and will not be supported in MindInsight of current version.
|
||||
|
||||
Args:
|
||||
summary_dir (str, optional): The directory path to save the summary files which store the generated results.
|
||||
Default: "./"
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer import ExplainRunner
|
||||
>>> # init a runner with a specified directory
|
||||
>>> summary_dir = "summary_dir"
|
||||
>>> runner = ExplainRunner(summary_dir)
|
||||
"""
|
||||
|
||||
def __init__(self, summary_dir: Optional[str] = "./"):
|
||||
check_value_type("summary_dir", summary_dir, str)
|
||||
self._summary_dir = summary_dir
|
||||
self._count = 0
|
||||
self._classes = None
|
||||
self._model = None
|
||||
self._uncertainty = None
|
||||
self._summary_timestamp = None
|
||||
|
||||
def run(self,
|
||||
dataset: Tuple,
|
||||
explainers: List,
|
||||
benchmarkers: Optional[List] = None,
|
||||
uncertainty: Optional[UncertaintyEvaluation] = None,
|
||||
activation_fn: Optional[Cell] = Softmax()):
|
||||
"""
|
||||
Generates results and writes them into the summary files in `summary_dir` specified during the object
initialization.
|
||||
|
||||
Args:
|
||||
dataset (tuple): A tuple that contains `mindspore.dataset` object for iteration and its labels.
|
||||
|
||||
- dataset[0]: A `mindspore.dataset` object to provide data to explain.
|
||||
- dataset[1]: A list of string that specifies the label names of the dataset.
|
||||
|
||||
explainers (list[Explanation]): A list of explanation objects to generate attribution results. Explanation
|
||||
object is an instance initialized with the explanation methods in module
|
||||
`mindspore.explainer.explanation`.
|
||||
benchmarkers (list[Benchmark], optional): A list of benchmark objects to generate evaluation results.
|
||||
Default: None
|
||||
uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
|
||||
uncertainty of samples.
|
||||
activation_fn (Cell, optional): The activation layer that transforms the output of the network to
|
||||
label probability distribution :math:`P(y|x)`. Default: Softmax().
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.explainer import ExplainRunner
|
||||
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
|
||||
>>> from mindspore.nn import Softmax
|
||||
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
>>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10
|
||||
>>> dataset = get_dataset('/path/to/Cifar10_dataset')
|
||||
>>> classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'turck']
|
||||
>>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10
|
||||
>>> param_dict = load_checkpoint("checkpoint.ckpt")
|
||||
>>> net = resnet50(len(classes))
|
||||
>>> load_param_into_net(net, param_dict)
|
||||
>>> gbp = GuidedBackprop(net)
|
||||
>>> gradient = Gradient(net)
|
||||
>>> explainers = [gbp, gradient]
|
||||
>>> # runner is an ExplainRunner object
|
||||
>>> runner.run((dataset, classes), explainers, activation_fn=Softmax())
|
||||
"""
|
||||
|
||||
check_value_type("dataset", dataset, tuple)
|
||||
if len(dataset) != 2:
|
||||
raise ValueError("Argument `dataset` should be a tuple with length = 2.")
|
||||
|
||||
dataset, classes = dataset
|
||||
if benchmarkers is None:
|
||||
benchmarkers = []
|
||||
|
||||
self._verify_data_form(dataset, benchmarkers)
|
||||
self._classes = classes
|
||||
|
||||
check_value_type("explainers", explainers, list)
|
||||
if not explainers:
|
||||
raise ValueError("Argument `explainers` must be a non-empty list")
|
||||
|
||||
for exp in explainers:
|
||||
if not isinstance(exp, Attribution):
|
||||
raise TypeError("Argument `explainers` should be a list of objects of classes in "
|
||||
"`mindspore.explainer.explanation`.")
|
||||
if benchmarkers:
|
||||
check_value_type("benchmarkers", benchmarkers, list)
|
||||
for bench in benchmarkers:
|
||||
if not isinstance(bench, AttributionMetric):
|
||||
raise TypeError("Argument `benchmarkers` should be a list of objects of classes in explanation"
|
||||
"`mindspore.explainer.benchmark`.")
|
||||
check_value_type("activation_fn", activation_fn, Cell)
|
||||
|
||||
self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn])
|
||||
next_element = next(dataset.create_tuple_iterator())
|
||||
inputs, _, _ = self._unpack_next_element(next_element)
|
||||
prop_test = self._model(inputs)
|
||||
check_value_type("output of model im explainer", prop_test, ms.Tensor)
|
||||
if prop_test.shape[1] != len(self._classes):
|
||||
raise ValueError("The dimension of model output does not match the length of dataset classes. Please "
|
||||
"check dataset classes or the black-box model in the explainer again.")
|
||||
|
||||
if uncertainty is not None:
|
||||
check_value_type("uncertainty", uncertainty, UncertaintyEvaluation)
|
||||
prop_var_test = uncertainty.eval_epistemic_uncertainty(inputs)
|
||||
check_value_type("output of uncertainty", prop_var_test, np.ndarray)
|
||||
if prop_var_test.shape[1] != len(self._classes):
|
||||
raise ValueError("The dimension of uncertainty output does not match the length of dataset classes"
|
||||
"classes. Please check dataset classes or the black-box model in the explainer again.")
|
||||
self._uncertainty = uncertainty
|
||||
else:
|
||||
self._uncertainty = None
|
||||
|
||||
with SummaryRecord(self._summary_dir) as summary:
|
||||
spacer = '{:120}\r'
|
||||
print("Start running and writing......")
|
||||
begin = time()
|
||||
print("Start writing metadata......")
|
||||
|
||||
self._summary_timestamp = _extract_timestamp(summary.event_file_name)
|
||||
if self._summary_timestamp is None:
|
||||
raise RuntimeError("Cannot extract timestamp from summary filename!"
|
||||
" It should contains a timestamp of 10 digits.")
|
||||
|
||||
explain = Explain()
|
||||
explain.metadata.label.extend(classes)
|
||||
exp_names = [exp.__class__.__name__ for exp in explainers]
|
||||
explain.metadata.explain_method.extend(exp_names)
|
||||
if benchmarkers:
|
||||
bench_names = [bench.__class__.__name__ for bench in benchmarkers]
|
||||
explain.metadata.benchmark_method.extend(bench_names)
|
||||
|
||||
summary.add_value("explainer", "metadata", explain)
|
||||
summary.record(1)
|
||||
|
||||
print("Finish writing metadata.")
|
||||
|
||||
now = time()
|
||||
print("Start running and writing inference data.....")
|
||||
imageid_labels = self._run_inference(dataset, summary)
|
||||
print(spacer.format("Finish running and writing inference data. "
|
||||
"Time elapsed: {:.3f} s".format(time() - now)))
|
||||
|
||||
if not benchmarkers:
|
||||
for exp in explainers:
|
||||
start = time()
|
||||
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
|
||||
self._count = 0
|
||||
ds.config.set_seed(_SEED)
|
||||
for idx, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
self._run_exp_step(next_element, exp, imageid_labels, summary)
|
||||
print(spacer.format("Finish writing {}-th explanation data for {}. Time elapsed: "
|
||||
"{:.3f} s".format(idx, exp.__class__.__name__, time() - now)), end='')
|
||||
print(spacer.format(
|
||||
"Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
|
||||
exp.__class__.__name__, time() - start)))
|
||||
else:
|
||||
for exp in explainers:
|
||||
explain = Explain()
|
||||
for bench in benchmarkers:
|
||||
bench.reset()
|
||||
print(f"Start running and writing explanation and "
|
||||
f"benchmark data for {exp.__class__.__name__}......")
|
||||
self._count = 0
|
||||
start = time()
|
||||
ds.config.set_seed(_SEED)
|
||||
for idx, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
|
||||
print(spacer.format(
|
||||
"Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
|
||||
idx, exp.__class__.__name__, time() - now)), end='')
|
||||
for bench in benchmarkers:
|
||||
now = time()
|
||||
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
|
||||
print(spacer.format(
|
||||
"Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
|
||||
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now)), end='')
|
||||
|
||||
for bench in benchmarkers:
|
||||
benchmark = explain.benchmark.add()
|
||||
benchmark.explain_method = exp.__class__.__name__
|
||||
benchmark.benchmark_method = bench.__class__.__name__
|
||||
|
||||
benchmark.total_score = bench.performance
|
||||
if isinstance(bench, LabelSensitiveMetric):
|
||||
benchmark.label_score.extend(bench.class_performances)
|
||||
|
||||
print(spacer.format("Finish running and writing explanation and benchmark data for {}. "
|
||||
"Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start)))
|
||||
summary.add_value('explainer', 'benchmark', explain)
|
||||
summary.record(1)
|
||||
print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))
|
||||
|
||||
@staticmethod
|
||||
def _verify_data_form(dataset, benchmarkers):
|
||||
"""
|
||||
Verify the validity of dataset.
|
||||
|
||||
Args:
|
||||
dataset (`ds`): the user parsed dataset.
|
||||
benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers.
|
||||
"""
|
||||
next_element = next(dataset.create_tuple_iterator())
|
||||
|
||||
if len(next_element) not in [1, 2, 3]:
|
||||
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
|
||||
" as columns.")
|
||||
|
||||
if len(next_element) == 3:
|
||||
inputs, labels, bboxes = next_element
|
||||
if bboxes.shape[-1] != 4:
|
||||
raise ValueError("The third element of dataset should be bounding boxes with shape of "
|
||||
"[batch_size, num_ground_truth, 4].")
|
||||
else:
|
||||
if any(map(lambda benchmarker: isinstance(benchmarker, Localization), benchmarkers)):
|
||||
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
|
||||
|
||||
if len(next_element) == 2:
|
||||
inputs, labels = next_element
|
||||
if len(next_element) == 1:
|
||||
inputs = next_element[0]
|
||||
|
||||
if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
|
||||
raise ValueError(
|
||||
"Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
|
||||
inputs.shape))
|
||||
if len(inputs.shape) == 3:
|
||||
log.warning(
|
||||
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
|
||||
" dimension as batch data.".format(inputs.shape))
|
||||
|
||||
if len(next_element) > 1:
|
||||
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
|
||||
raise ValueError(
|
||||
"Labels shape {} is unrecognizable: labels should not have more than two dimensions"
|
||||
" with length greater than 1.".format(labels.shape))
|
||||
|
||||
def _transform_data(self, inputs, labels, bboxes, ifbbox):
|
||||
"""
|
||||
Transform the data from one iteration of dataset to a unifying form for the follow-up operations.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): the image data
|
||||
labels (Tensor): the labels
|
||||
bboxes (Tensor): the boudnding boxes data
|
||||
ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t label
|
||||
id will be returned. If False, the returned bboxes is the the parsed bboxes.
|
||||
|
||||
Returns:
|
||||
inputs (Tensor): the image data, unified to a 4D Tensor.
|
||||
labels (List[List[int]]): the ground truth labels.
|
||||
bboxes (Union[List[Dict], None, Tensor]): the bounding boxes
|
||||
"""
|
||||
inputs = ms.Tensor(inputs, ms.float32)
|
||||
if len(inputs.shape) == 3:
|
||||
inputs = _EXPAND_DIMS(inputs, 0)
|
||||
if isinstance(labels, ms.Tensor):
|
||||
labels = ms.Tensor(labels, ms.int32)
|
||||
labels = _EXPAND_DIMS(labels, 0)
|
||||
if isinstance(bboxes, ms.Tensor):
|
||||
bboxes = ms.Tensor(bboxes, ms.int32)
|
||||
bboxes = _EXPAND_DIMS(bboxes, 0)
|
||||
|
||||
input_len = len(inputs)
|
||||
if bboxes is not None and ifbbox:
|
||||
bboxes = ms.Tensor(bboxes, ms.int32)
|
||||
masks_lst = []
|
||||
labels = labels.asnumpy().reshape([input_len, -1])
|
||||
bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
|
||||
for idx, label in enumerate(labels):
|
||||
height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
|
||||
masks = {}
|
||||
for j, label_item in enumerate(label):
|
||||
target = int(label_item)
|
||||
if -1 < target < len(self._classes):
|
||||
if target not in masks:
|
||||
mask = np.zeros((1, 1, height, width))
|
||||
else:
|
||||
mask = masks[target]
|
||||
x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
|
||||
mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
|
||||
masks[target] = mask
|
||||
|
||||
masks_lst.append(masks)
|
||||
bboxes = masks_lst
|
||||
|
||||
labels = ms.Tensor(labels, ms.int32)
|
||||
if len(labels.shape) == 1:
|
||||
labels_lst = [[int(i)] for i in labels.asnumpy()]
|
||||
else:
|
||||
labels = labels.asnumpy().reshape([input_len, -1])
|
||||
labels_lst = []
|
||||
for item in labels:
|
||||
labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._classes))))
|
||||
labels = labels_lst
|
||||
return inputs, labels, bboxes
|
||||
|
||||
def _unpack_next_element(self, next_element, ifbbox=False):
|
||||
"""
|
||||
Unpack a single iteration of dataset.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): a single element iterated from dataset object.
|
||||
ifbbox (bool): whether to preprocess bboxes in self._transform_data.
|
||||
|
||||
Returns:
|
||||
Tuple, a unified Tuple contains image_data, labels, and bounding boxes.
|
||||
"""
|
||||
if len(next_element) == 3:
|
||||
inputs, labels, bboxes = next_element
|
||||
elif len(next_element) == 2:
|
||||
inputs, labels = next_element
|
||||
bboxes = None
|
||||
else:
|
||||
inputs = next_element[0]
|
||||
labels = [[] for _ in inputs]
|
||||
bboxes = None
|
||||
inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
|
||||
return inputs, labels, bboxes
|
||||
|
||||
@staticmethod
|
||||
def _make_label_batch(labels):
|
||||
"""
|
||||
Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
|
||||
length of all the rows in labels.
|
||||
|
||||
Args:
|
||||
labels (List[List]): the union labels of a data batch.
|
||||
|
||||
Returns:
|
||||
2D Tensor.
|
||||
"""
|
||||
|
||||
max_len = max([len(label) for label in labels])
|
||||
batch_labels = np.zeros((len(labels), max_len))
|
||||
|
||||
for idx, _ in enumerate(batch_labels):
|
||||
length = len(labels[idx])
|
||||
batch_labels[idx, :length] = np.array(labels[idx])
|
||||
|
||||
return ms.Tensor(batch_labels, ms.int32)
|
||||
|
||||
def _run_inference(self, dataset, summary, threshold=0.5):
|
||||
"""
|
||||
Run inference for the dataset and write the inference related data into summary.
|
||||
|
||||
Args:
|
||||
dataset (`ds`): the parsed dataset
|
||||
summary (`SummaryRecord`): the summary object to store the data
|
||||
threshold (float): the threshold for prediction.
|
||||
|
||||
Returns:
|
||||
imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels.
|
||||
"""
|
||||
spacer = '{:120}\r'
|
||||
imageid_labels = {}
|
||||
ds.config.set_seed(_SEED)
|
||||
self._count = 0
|
||||
for j, next_element in enumerate(dataset):
|
||||
now = time()
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
prob = self._model(inputs).asnumpy()
|
||||
if self._uncertainty is not None:
|
||||
prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
|
||||
prob_sd = np.sqrt(prob_var)
|
||||
else:
|
||||
prob_var = prob_sd = None
|
||||
|
||||
for idx, inp in enumerate(inputs):
|
||||
gt_labels = labels[idx]
|
||||
gt_probs = [float(prob[idx][i]) for i in gt_labels]
|
||||
|
||||
data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
|
||||
original_image = _np_to_image(_normalize(data_np), mode='RGB')
|
||||
original_image_path = self._save_original_image(self._count, original_image)
|
||||
|
||||
predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
|
||||
predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
|
||||
|
||||
has_uncertainty = False
|
||||
gt_prob_sds = gt_prob_itl95_lows = gt_prob_itl95_his = None
|
||||
predicted_prob_sds = predicted_prob_itl95_lows = predicted_prob_itl95_his = None
|
||||
if prob_var is not None:
|
||||
gt_prob_sds = [float(prob_sd[idx][i]) for i in gt_labels]
|
||||
predicted_prob_sds = [float(prob_sd[idx][i]) for i in predicted_labels]
|
||||
try:
|
||||
gt_prob_itl95_lows, gt_prob_itl95_his = \
|
||||
_calc_prob_interval(0.95, gt_probs, [float(prob_var[idx][i]) for i in gt_labels])
|
||||
predicted_prob_itl95_lows, predicted_prob_itl95_his = \
|
||||
_calc_prob_interval(0.95, predicted_probs, [float(prob_var[idx][i])
|
||||
for i in predicted_labels])
|
||||
has_uncertainty = True
|
||||
except ValueError:
|
||||
log.error(traceback.format_exc())
|
||||
log.error("Error on calculating uncertainty")
|
||||
|
||||
union_labs = list(set(gt_labels + predicted_labels))
|
||||
imageid_labels[str(self._count)] = union_labs
|
||||
|
||||
explain = Explain()
|
||||
explain.sample_id = self._count
|
||||
explain.image_path = original_image_path
|
||||
summary.add_value("explainer", "sample", explain)
|
||||
|
||||
explain = Explain()
|
||||
explain.sample_id = self._count
|
||||
explain.ground_truth_label.extend(gt_labels)
|
||||
explain.inference.ground_truth_prob.extend(gt_probs)
|
||||
explain.inference.predicted_label.extend(predicted_labels)
|
||||
explain.inference.predicted_prob.extend(predicted_probs)
|
||||
|
||||
if has_uncertainty:
|
||||
explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
|
||||
explain.inference.ground_truth_prob_itl95_low.extend(gt_prob_itl95_lows)
|
||||
explain.inference.ground_truth_prob_itl95_hi.extend(gt_prob_itl95_his)
|
||||
|
||||
explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
|
||||
explain.inference.predicted_prob_itl95_low.extend(predicted_prob_itl95_lows)
|
||||
explain.inference.predicted_prob_itl95_hi.extend(predicted_prob_itl95_his)
|
||||
|
||||
summary.add_value("explainer", "inference", explain)
|
||||
|
||||
summary.record(1)
|
||||
|
||||
self._count += 1
|
||||
print(spacer.format("Finish running and writing {}-th batch inference data."
|
||||
" Time elapsed: {:.3f} s".format(j, time() - now)),
|
||||
end='')
|
||||
return imageid_labels
|
||||
|
||||
def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
|
||||
"""
|
||||
Run the explanation for each step and write explanation results into summary.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): data of one step
|
||||
explainer (_Attribution): an Attribution object to generate saliency maps.
|
||||
imageid_labels (dict): a dict that maps the image_id and its union labels.
|
||||
summary (SummaryRecord): the summary object to store the data
|
||||
|
||||
Returns:
|
||||
List of dict that maps label to its corresponding saliency map.
|
||||
"""
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
count = self._count
|
||||
unions = []
|
||||
for _ in range(len(labels)):
|
||||
unions_labels = imageid_labels[str(count)]
|
||||
unions.append(unions_labels)
|
||||
count += 1
|
||||
|
||||
batch_unions = self._make_label_batch(unions)
|
||||
saliency_dict_lst = []
|
||||
|
||||
if isinstance(explainer, RISE):
|
||||
batch_saliency_full = explainer(inputs, batch_unions)
|
||||
else:
|
||||
batch_saliency_full = []
|
||||
for i in range(len(batch_unions[0])):
|
||||
batch_saliency = explainer(inputs, batch_unions[:, i])
|
||||
batch_saliency_full.append(batch_saliency)
|
||||
concat = ms.ops.operations.Concat(1)
|
||||
batch_saliency_full = concat(tuple(batch_saliency_full))
|
||||
|
||||
for idx, union in enumerate(unions):
|
||||
saliency_dict = {}
|
||||
explain = Explain()
|
||||
explain.sample_id = self._count
|
||||
for k, lab in enumerate(union):
|
||||
saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
|
||||
saliency_dict[lab] = saliency
|
||||
|
||||
saliency_np = _normalize(saliency.asnumpy().squeeze())
|
||||
saliency_image = _np_to_image(saliency_np, mode='L')
|
||||
heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._count, saliency_image)
|
||||
|
||||
explanation = explain.explanation.add()
|
||||
explanation.explain_method = explainer.__class__.__name__
|
||||
explanation.heatmap_path = heatmap_path
|
||||
explanation.label = lab
|
||||
|
||||
summary.add_value("explainer", "explanation", explain)
|
||||
summary.record(1)
|
||||
|
||||
self._count += 1
|
||||
saliency_dict_lst.append(saliency_dict)
|
||||
return saliency_dict_lst
|
||||
|
||||
def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
|
||||
"""
|
||||
Run the explanation and evaluation for each step and write explanation results into summary.
|
||||
|
||||
Args:
|
||||
next_element (Tuple): Data of one step
|
||||
explainer (`_Attribution`): An Attribution object to generate saliency maps.
|
||||
"""
|
||||
inputs, labels, _ = self._unpack_next_element(next_element)
|
||||
for idx, inp in enumerate(inputs):
|
||||
inp = _EXPAND_DIMS(inp, 0)
|
||||
if isinstance(benchmarker, LabelAgnosticMetric):
|
||||
res = benchmarker.evaluate(explainer, inp)
|
||||
res[np.isnan(res)] = 0.0
|
||||
benchmarker.aggregate(res)
|
||||
else:
|
||||
saliency_dict = saliency_dict_lst[idx]
|
||||
for label, saliency in saliency_dict.items():
|
||||
if isinstance(benchmarker, Localization):
|
||||
_, _, bboxes = self._unpack_next_element(next_element, True)
|
||||
if label in labels[idx]:
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
|
||||
saliency=saliency)
|
||||
res[np.isnan(res)] = 0.0
|
||||
benchmarker.aggregate(res, label)
|
||||
elif isinstance(benchmarker, LabelSensitiveMetric):
|
||||
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
|
||||
res[np.isnan(res)] = 0.0
|
||||
benchmarker.aggregate(res, label)
|
||||
else:
|
||||
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
|
||||
'receive {}'.format(type(benchmarker)))
|
||||
|
||||
def _save_original_image(self, sample_id: int, image):
|
||||
"""Save an image to summary directory."""
|
||||
id_dirname = _get_id_dirname(sample_id)
|
||||
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
|
||||
_ORIGINAL_IMAGE_DIRNAME,
|
||||
id_dirname)
|
||||
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
|
||||
relative_path = os.path.join(relative_dir, f"{sample_id}.jpg")
|
||||
save_path = os.path.join(self._summary_dir, relative_path)
|
||||
with open(save_path, "wb") as file:
|
||||
image.save(file)
|
||||
return relative_path
|
||||
|
||||
def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
|
||||
"""Save heatmap image to summary directory."""
|
||||
id_dirname = _get_id_dirname(sample_id)
|
||||
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
|
||||
_HEATMAP_DIRNAME,
|
||||
explain_method,
|
||||
id_dirname)
|
||||
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
|
||||
relative_path = os.path.join(relative_dir, f"{sample_id}_{class_id}.jpg")
|
||||
save_path = os.path.join(self._summary_dir, relative_path)
|
||||
with open(save_path, "wb") as file:
|
||||
image.save(file, optimize=True)
|
||||
return relative_path
|
|
@@ -128,6 +128,10 @@ class LabelSensitiveMetric(AttributionMetric):
         self._num_labels = num_labels
         self._global_results = {i: [] for i in range(num_labels)}
 
+    @property
+    def num_labels(self):
+        return self._num_labels
+
     @staticmethod
     def _verify_params(num_labels):
         check_value_type("num_labels", num_labels, int)