!9275 Reform explain runner api

From: @ngtony
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2020-12-03 09:57:02 +08:00 committed by Gitee
commit 4b539980bf
4 changed files with 706 additions and 665 deletions


@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Provide ExplainRunner High-level API."""
"""Provides explanation runner high-level APIs."""
from ._runner import ExplainRunner
from ._image_classification_runner import ImageClassificationRunner
__all__ = ['ExplainRunner']
__all__ = ['ImageClassificationRunner']
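With this rename, user code moves from the removed ExplainRunner to the new ImageClassificationRunner listed in full below. A minimal migration sketch, assuming `dataset`, `labels`, `net`, `explainers` and `activation_fn` are already prepared as in the docstring examples further down (they are placeholders here, not part of this commit):

# Old API (removed by this commit):
#   from mindspore.explainer import ExplainRunner
#   runner = ExplainRunner("./summary_dir")
#   runner.run((dataset, labels), explainers, activation_fn=activation_fn)

# New API:
from mindspore.explainer import ImageClassificationRunner
runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn)
runner.register_saliency(explainers=explainers)
runner.run()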


@@ -0,0 +1,699 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Image Classification Runner."""
import os
import re
from time import time
import numpy as np
from PIL import Image
import mindspore as ms
import mindspore.dataset as ds
from mindspore import log
from mindspore.dataset.engine.datasets import Dataset
from mindspore.nn import Cell, SequentialCell
from mindspore.ops.operations import ExpandDims
from mindspore.train._utils import check_value_type
from mindspore.train.summary._summary_adapter import _convert_image_format
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.train.summary_pb2 import Explain
from .benchmark import Localization
from .explanation import RISE
from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric
from .explanation._attribution.attribution import Attribution
_EXPAND_DIMS = ExpandDims()
def _normalize(img_np):
"""Normalize the numpy image to the range of [0, 1]. """
max_ = img_np.max()
min_ = img_np.min()
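# clip the denominator to at least 1e-10 so a constant image does not cause a division by zero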
normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
return normed
def _np_to_image(img_np, mode):
"""Convert numpy array to PIL image."""
return Image.fromarray(np.uint8(img_np * 255), mode=mode)
class ImageClassificationRunner:
"""
A high-level API for users to generate and store results of the explanation methods and the evaluation methods.
Update in 2020.11: The storage structure and format of the data were adjusted. Summary files generated by
previous versions are deprecated and are not supported by the current version of MindInsight.
Args:
summary_dir (str): The directory path to save the summary files which store the generated results.
data (tuple[Dataset, list[str]]): Tuple of the dataset and the corresponding class label list. The dataset
should provide [images], [images, labels] or [images, labels, bboxes] as columns. The label list must
match the length and order of the network outputs.
network (Cell): The network (with logit outputs) to be explained.
activation_fn (Cell): The activation function for converting the network's output to probabilities.
Examples:
>>> from mindspore.explainer import ImageClassificationRunner
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
>>> from mindspore.explainer.benchmark import Faithfulness
>>> from mindspore.nn import Softmax
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10
>>> dataset = get_dataset('/path/to/Cifar10_dataset')
>>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
>>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10
>>> param_dict = load_checkpoint("checkpoint.ckpt")
>>> net = resnet50(len(labels))
>>> activation_fn = Softmax()
>>> load_param_into_net(net, param_dict)
>>> gbp = GuidedBackprop(net)
>>> gradient = Gradient(net)
>>> explainers = [gbp, gradient]
>>> faithfulness = Faithfulness(len(labels), "NaiveFaithfulness", activation_fn)
>>> benchmarkers = [faithfulness]
>>> runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn)
>>> runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers)
>>> runner.run()
"""
# datafile directory names
_DATAFILE_DIRNAME_PREFIX = "_explain_"
_ORIGINAL_IMAGE_DIRNAME = "origin_images"
_HEATMAP_DIRNAME = "heatmap"
# max. number of samples per directory
_SAMPLE_PER_DIR = 1000
# seed for fixing the iterating order of the dataset
_DATASET_SEED = 58
# printing spacer
_SPACER = "{:120}\r"
# file permission for writing files
_FILE_MODE = 0o600
def __init__(self,
summary_dir,
data,
network,
activation_fn):
check_value_type("data", data, tuple)
if len(data) != 2:
raise ValueError("Argument data is not a tuple with 2 elements")
check_value_type("data[0]", data[0], Dataset)
check_value_type("data[1]", data[1], list)
if not all(isinstance(ele, str) for ele in data[1]):
raise ValueError("Argument data[1] is not list of str.")
check_value_type("summary_dir", summary_dir, str)
check_value_type("network", network, Cell)
check_value_type("activation_fn", activation_fn, Cell)
self._summary_dir = summary_dir
self._dataset = data[0]
self._labels = data[1]
self._network = network
self._explainers = None
self._benchmarkers = None
self._summary_timestamp = None
self._sample_index = -1
self._full_network = SequentialCell([self._network, activation_fn])
self._verify_data_n_settings(check_data_n_network=True)
def register_saliency(self,
explainers,
benchmarkers=None):
"""
Register saliency explanation instances.
Note:
This function can not be invoked more than once on each runner.
Args:
explainers (list[Attribution]): The explainers to be evaluated,
see `mindspore.explainer.explanation`. All explainers' classes must be distinct and their networks
must be the exact same instance as the runner's network.
benchmarkers (list[AttributionMetric], optional): The benchmarkers for scoring the explainers,
see `mindspore.explainer.benchmark`. All benchmarkers' classes must be distinct.
Raises:
ValueError: Raised for any problem with the values of the data or settings.
TypeError: Raised for any problem with the types of the data or settings.
RuntimeError: Raised if this function was invoked before.
"""
check_value_type("explainers", explainers, list)
if not all(isinstance(ele, Attribution) for ele in explainers):
raise TypeError("Argument explainers is not list of mindspore.explainer.explanation .")
if not explainers:
raise ValueError("Argument explainers is empty.")
if benchmarkers:
check_value_type("benchmarkers", benchmarkers, list)
if not all(isinstance(ele, AttributionMetric) for ele in benchmarkers):
raise TypeError("Argument benchmarkers is not list of mindspore.explainer.benchmark .")
if self._explainers is not None:
raise RuntimeError("Function register_saliency() was invoked already.")
self._explainers = explainers
self._benchmarkers = benchmarkers
try:
self._verify_data_n_settings(check_saliency=True)
except (ValueError, TypeError):
self._explainers = None
self._benchmarkers = None
raise
def run(self):
"""
Run the explain job and save the result as a summary in summary_dir.
Note:
User should call register_saliency() once before running this function.
Raises:
ValueError: Raised for any problem with the values of the data or settings.
TypeError: Raised for any problem with the types of the data or settings.
RuntimeError: Raised for any other runtime problem.
"""
self._verify_data_n_settings(check_all=True)
with SummaryRecord(self._summary_dir) as summary:
print("Start running and writing......")
begin = time()
self._summary_timestamp = self._extract_timestamp(summary.event_file_name)
if self._summary_timestamp is None:
raise RuntimeError("Cannot extract timestamp from summary filename!"
" It should contains a timestamp after 'summary.' .")
self._save_metadata(summary)
imageid_labels = self._run_inference(summary)
if self._is_saliency_registered:
self._run_saliency(summary, imageid_labels)
print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))
@property
def _is_saliency_registered(self):
"""Check if saliency module is registered."""
return bool(self._explainers)
def _save_metadata(self, summary):
"""Save metadata of the explain job to summary."""
print("Start writing metadata......")
explain = Explain()
explain.metadata.label.extend(self._labels)
if self._is_saliency_registered:
exp_names = [exp.__class__.__name__ for exp in self._explainers]
explain.metadata.explain_method.extend(exp_names)
if self._benchmarkers is not None:
bench_names = [bench.__class__.__name__ for bench in self._benchmarkers]
explain.metadata.benchmark_method.extend(bench_names)
summary.add_value("explainer", "metadata", explain)
summary.record(1)
print("Finish writing metadata.")
def _run_inference(self, summary, threshold=0.5):
"""
Run inference for the dataset and write the inference related data into summary.
Args:
summary (SummaryRecord): The summary object to store the data
threshold (float): The threshold for prediction.
Returns:
dict, the map of sample id to the union of its ground truth and predicted labels.
"""
sample_id_labels = {}
self._sample_index = 0
ds.config.set_seed(self._DATASET_SEED)
for j, next_element in enumerate(self._dataset):
now = time()
inputs, labels, _ = self._unpack_next_element(next_element)
prob = self._full_network(inputs).asnumpy()
for idx, inp in enumerate(inputs):
gt_labels = labels[idx]
gt_probs = [float(prob[idx][i]) for i in gt_labels]
data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
original_image = _np_to_image(_normalize(data_np), mode='RGB')
original_image_path = self._save_original_image(self._sample_index, original_image)
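# every class whose probability exceeds the threshold is recorded as a predicted label (multi-label style)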
predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
union_labs = list(set(gt_labels + predicted_labels))
sample_id_labels[str(self._sample_index)] = union_labs
explain = Explain()
explain.sample_id = self._sample_index
explain.image_path = original_image_path
summary.add_value("explainer", "sample", explain)
explain = Explain()
explain.sample_id = self._sample_index
explain.ground_truth_label.extend(gt_labels)
explain.inference.ground_truth_prob.extend(gt_probs)
explain.inference.predicted_label.extend(predicted_labels)
explain.inference.predicted_prob.extend(predicted_probs)
summary.add_value("explainer", "inference", explain)
summary.record(1)
self._sample_index += 1
self._spaced_print("Finish running and writing {}-th batch inference data."
" Time elapsed: {:.3f} s".format(j, time() - now),
end='')
return sample_id_labels
def _run_saliency(self, summary, sample_id_labels):
"""Run the saliency explanations."""
if self._benchmarkers is None or not self._benchmarkers:
for exp in self._explainers:
start = time()
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
self._sample_index = 0
ds.config.set_seed(self._DATASET_SEED)
for idx, next_element in enumerate(self._dataset):
now = time()
self._run_exp_step(next_element, exp, sample_id_labels, summary)
self._spaced_print("Finish writing {}-th explanation data for {}. Time elapsed: "
"{:.3f} s".format(idx, exp.__class__.__name__, time() - now), end='')
self._spaced_print(
"Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
exp.__class__.__name__, time() - start))
else:
for exp in self._explainers:
explain = Explain()
for bench in self._benchmarkers:
bench.reset()
print(f"Start running and writing explanation and "
f"benchmark data for {exp.__class__.__name__}......")
self._sample_index = 0
start = time()
ds.config.set_seed(self._DATASET_SEED)
for idx, next_element in enumerate(self._dataset):
now = time()
saliency_dict_lst = self._run_exp_step(next_element, exp, sample_id_labels, summary)
self._spaced_print(
"Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
idx, exp.__class__.__name__, time() - now), end='')
for bench in self._benchmarkers:
now = time()
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
self._spaced_print(
"Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now), end='')
for bench in self._benchmarkers:
benchmark = explain.benchmark.add()
benchmark.explain_method = exp.__class__.__name__
benchmark.benchmark_method = bench.__class__.__name__
benchmark.total_score = bench.performance
if isinstance(bench, LabelSensitiveMetric):
benchmark.label_score.extend(bench.class_performances)
self._spaced_print("Finish running and writing explanation and benchmark data for {}. "
"Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start))
summary.add_value('explainer', 'benchmark', explain)
summary.record(1)
def _run_exp_step(self, next_element, explainer, sample_id_labels, summary):
"""
Run the explanation for each step and write explanation results into summary.
Args:
next_element (Tuple): Data of one step
explainer (_Attribution): An Attribution object to generate saliency maps.
sample_id_labels (dict): A dict that maps the sample id and its union labels.
summary (SummaryRecord): The summary object to store the data
Returns:
list, list of dicts that map each label to its corresponding saliency map.
"""
inputs, labels, _ = self._unpack_next_element(next_element)
sample_index = self._sample_index
unions = []
for _ in range(len(labels)):
unions_labels = sample_id_labels[str(sample_index)]
unions.append(unions_labels)
sample_index += 1
batch_unions = self._make_label_batch(unions)
saliency_dict_lst = []
if isinstance(explainer, RISE):
batch_saliency_full = explainer(inputs, batch_unions)
else:
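# explainers other than RISE take one target per call, so generate a map per union label and concatenate them along the class axis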
batch_saliency_full = []
for i in range(len(batch_unions[0])):
batch_saliency = explainer(inputs, batch_unions[:, i])
batch_saliency_full.append(batch_saliency)
concat = ms.ops.operations.Concat(1)
batch_saliency_full = concat(tuple(batch_saliency_full))
for idx, union in enumerate(unions):
saliency_dict = {}
explain = Explain()
explain.sample_id = self._sample_index
for k, lab in enumerate(union):
saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
saliency_dict[lab] = saliency
saliency_np = _normalize(saliency.asnumpy().squeeze())
saliency_image = _np_to_image(saliency_np, mode='L')
heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._sample_index, saliency_image)
explanation = explain.explanation.add()
explanation.explain_method = explainer.__class__.__name__
explanation.heatmap_path = heatmap_path
explanation.label = lab
summary.add_value("explainer", "explanation", explain)
summary.record(1)
self._sample_index += 1
saliency_dict_lst.append(saliency_dict)
return saliency_dict_lst
def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
"""Run the explanation and evaluation for each step and write explanation results into summary."""
inputs, labels, _ = self._unpack_next_element(next_element)
for idx, inp in enumerate(inputs):
inp = _EXPAND_DIMS(inp, 0)
saliency_dict = saliency_dict_lst[idx]
for label, saliency in saliency_dict.items():
if isinstance(benchmarker, Localization):
_, _, bboxes = self._unpack_next_element(next_element, True)
if label in labels[idx]:
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
saliency=saliency)
if np.any(np.isnan(res)):
res = np.zeros_like(res)
benchmarker.aggregate(res, label)
elif isinstance(benchmarker, LabelSensitiveMetric):
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
if np.any(np.isnan(res)):
res = np.zeros_like(res)
benchmarker.aggregate(res, label)
elif isinstance(benchmarker, LabelAgnosticMetric):
res = benchmarker.evaluate(explainer, inp)
if np.any(np.isnan(res)):
res = np.zeros_like(res)
benchmarker.aggregate(res)
else:
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, '
'but received {}.'.format(type(benchmarker)))
def _verify_data(self):
"""Verify dataset and labels."""
next_element = next(self._dataset.create_tuple_iterator())
if len(next_element) not in [1, 2, 3]:
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
" as columns.")
if len(next_element) == 3:
inputs, labels, bboxes = next_element
if bboxes.shape[-1] != 4:
raise ValueError("The third element of dataset should be bounding boxes with shape of "
"[batch_size, num_ground_truth, 4].")
else:
if self._benchmarkers is not None:
if any([isinstance(bench, Localization) for bench in self._benchmarkers]):
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
if len(next_element) == 2:
inputs, labels = next_element
if len(next_element) == 1:
inputs = next_element[0]
if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
raise ValueError(
"Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
inputs.shape))
if len(inputs.shape) == 3:
log.warning(
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
" dimension as batch data.".format(inputs.shape))
if len(next_element) > 1:
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
raise ValueError(
"Labels shape {} is unrecognizable: outputs should not have more than two dimensions"
" with length greater than 1.".format(labels.shape))
def _verify_network(self):
"""Verify the network."""
label_set = set()
for i, label in enumerate(self._labels):
if label.strip() == "":
raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is "
f"no empty label.")
if label in label_set:
raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.")
label_set.add(label)
next_element = next(self._dataset.create_tuple_iterator())
inputs, _, _ = self._unpack_next_element(next_element)
prop_test = self._full_network(inputs)
check_value_type("output of network in explainer", prop_test, ms.Tensor)
if prop_test.shape[1] != len(self._labels):
raise ValueError("The dimension of network output does not match the no. of classes. Please "
"check labels or the network in the explainer again.")
def _verify_saliency(self):
"""Verify the saliency settings."""
if self._explainers:
explainer_classes = []
for explainer in self._explainers:
if explainer.__class__ in explainer_classes:
raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! "
"Please make sure all explainers' class is distinct.")
if explainer.model != self._network:
raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different "
"instance from network of runner. Please make sure they are the same "
"instance.")
explainer_classes.append(explainer.__class__)
if self._benchmarkers:
benchmarker_classes = []
for benchmarker in self._benchmarkers:
if benchmarker.__class__ in benchmarker_classes:
raise ValueError(f"Repeated {benchmarker.__class__.__name__} benchmarker! "
"Please make sure all benchmarkers' class is distinct.")
if isinstance(benchmarker, LabelSensitiveMetric) and benchmarker.num_labels != len(self._labels):
raise ValueError(f"The num_labels of {benchmarker.__class__.__name__} benchmarker is different "
"from no. of labels of runner. Please make them are the same.")
benchmarker_classes.append(benchmarker.__class__)
def _verify_data_n_settings(self,
check_all=False,
check_registration=False,
check_data_n_network=False,
check_saliency=False):
"""
Verify the validity of dataset and other settings.
Args:
check_all (bool): Set it True for checking everything.
check_registration (bool): Set it True for checking registrations, i.e. whether enough has been registered to invoke run().
check_data_n_network (bool): Set it True for checking data and network.
check_saliency (bool): Set it True for checking saliency related settings.
Raises:
ValueError: Raised for any problem with the values of the data or settings.
TypeError: Raised for any problem with the types of the data or settings.
"""
if check_all:
check_registration = True
check_data_n_network = True
check_saliency = True
if check_registration:
if not self._is_saliency_registered:
raise ValueError("No explanation module was registered, user should at least call register_saliency()"
" once with proper explanation instances")
if check_data_n_network or check_saliency:
self._verify_data()
if check_data_n_network:
self._verify_network()
if check_saliency:
self._verify_saliency()
def _transform_data(self, inputs, labels, bboxes, ifbbox):
"""
Transform the data from one iteration of the dataset to a unified form for the follow-up operations.
Args:
inputs (Tensor): the image data
labels (Tensor): the labels
bboxes (Tensor): the bounding boxes data
ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t
label id will be returned. If False, the returned bboxes is the parsed bboxes.
Returns:
inputs (Tensor): the image data, unified to a 4D Tensor.
labels (list[list[int]]): the ground truth labels.
bboxes (Union[list[dict], None, Tensor]): the bounding boxes
"""
inputs = ms.Tensor(inputs, ms.float32)
if len(inputs.shape) == 3:
inputs = _EXPAND_DIMS(inputs, 0)
if isinstance(labels, ms.Tensor):
labels = ms.Tensor(labels, ms.int32)
labels = _EXPAND_DIMS(labels, 0)
if isinstance(bboxes, ms.Tensor):
bboxes = ms.Tensor(bboxes, ms.int32)
bboxes = _EXPAND_DIMS(bboxes, 0)
input_len = len(inputs)
if bboxes is not None and ifbbox:
bboxes = ms.Tensor(bboxes, ms.int32)
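# rasterize every ground-truth box of each sample into a binary mask keyed by its label id; Localization consumes these masks in _run_exp_benchmark_step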
masks_lst = []
labels = labels.asnumpy().reshape([input_len, -1])
bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
for idx, label in enumerate(labels):
height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
masks = {}
for j, label_item in enumerate(label):
target = int(label_item)
if -1 < target < len(self._labels):
if target not in masks:
mask = np.zeros((1, 1, height, width))
else:
mask = masks[target]
x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
masks[target] = mask
masks_lst.append(masks)
bboxes = masks_lst
labels = ms.Tensor(labels, ms.int32)
if len(labels.shape) == 1:
labels_lst = [[int(i)] for i in labels.asnumpy()]
else:
labels = labels.asnumpy().reshape([input_len, -1])
labels_lst = []
for item in labels:
labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._labels))))
labels = labels_lst
return inputs, labels, bboxes
def _unpack_next_element(self, next_element, ifbbox=False):
"""
Unpack a single iteration of dataset.
Args:
next_element (Tuple): a single element iterated from dataset object.
ifbbox (bool): whether to preprocess bboxes in self._transform_data.
Returns:
tuple, a unified tuple containing image_data, labels, and bounding boxes.
"""
if len(next_element) == 3:
inputs, labels, bboxes = next_element
elif len(next_element) == 2:
inputs, labels = next_element
bboxes = None
else:
inputs = next_element[0]
labels = [[] for _ in inputs]
bboxes = None
inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
return inputs, labels, bboxes
@staticmethod
def _make_label_batch(labels):
"""
Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
length of all the rows in labels.
Args:
labels (List[List]): the union labels of a data batch.
Returns:
2D Tensor.
"""
max_len = max([len(label) for label in labels])
batch_labels = np.zeros((len(labels), max_len))
for idx, _ in enumerate(batch_labels):
length = len(labels[idx])
batch_labels[idx, :length] = np.array(labels[idx])
return ms.Tensor(batch_labels, ms.int32)
def _save_original_image(self, sample_id, image):
"""Save an image to summary directory."""
id_dirname = self._get_sample_dirname(sample_id)
relative_dir = os.path.join(self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
self._ORIGINAL_IMAGE_DIRNAME,
id_dirname)
abs_dir_path = os.path.abspath(os.path.join(self._summary_dir, relative_dir))
os.makedirs(abs_dir_path, mode=self._FILE_MODE, exist_ok=True)
filename = f"{sample_id}.jpg"
save_path = os.path.join(abs_dir_path, filename)
image.save(save_path)
os.chmod(save_path, self._FILE_MODE)
return os.path.join(relative_dir, filename)
def _save_heatmap(self, explain_method, class_id, sample_id, image):
"""Save heatmap image to summary directory."""
id_dirname = self._get_sample_dirname(sample_id)
relative_dir = os.path.join(self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
self._HEATMAP_DIRNAME,
explain_method,
id_dirname)
abs_dir_path = os.path.abspath(os.path.join(self._summary_dir, relative_dir))
os.makedirs(abs_dir_path, mode=self._FILE_MODE, exist_ok=True)
filename = f"{sample_id}_{class_id}.jpg"
save_path = os.path.join(abs_dir_path, filename)
image.save(save_path)
os.chmod(save_path, self._FILE_MODE)
return os.path.join(relative_dir, filename)
@classmethod
def _get_sample_dirname(cls, sample_id):
"""Get the name of parent directory of the image id."""
return str(int(sample_id / cls._SAMPLE_PER_DIR) * cls._SAMPLE_PER_DIR)
@staticmethod
def _extract_timestamp(filename):
"""Extract timestamp from summary filename."""
matched = re.search(r"summary\.(\d+)", filename)
if matched:
return int(matched.group(1))
return None
@classmethod
def _spaced_print(cls, message, *args, **kwargs):
"""Spaced message printing."""
print(cls._SPACER.format(message), *args, **kwargs)
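A side note on the NaN guard in _run_exp_benchmark_step above: NaN never compares equal to anything, even itself, so the check has to go through np.isnan rather than an == comparison. A standalone NumPy sketch with made-up scores:

import numpy as np

res = np.array([0.3, np.nan, 0.7])
print(np.any(res == np.nan))   # False: == never matches NaN, even against itself
print(np.any(np.isnan(res)))   # True: np.isnan detects the invalid score
res[np.isnan(res)] = 0.0       # zero out invalid entries before aggregation
print(res)                     # [0.3 0.  0.7]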


@@ -1,662 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Runner."""
import os
import re
import traceback
from time import time
from typing import Tuple, List, Optional
import numpy as np
from PIL import Image
from scipy.stats import beta
import mindspore as ms
import mindspore.dataset as ds
from mindspore import log
from mindspore.nn import Softmax, Cell
from mindspore.nn.probability.toolbox import UncertaintyEvaluation
from mindspore.ops.operations import ExpandDims
from mindspore.train._utils import check_value_type
from mindspore.train.summary._summary_adapter import _convert_image_format
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.train.summary_pb2 import Explain
from .benchmark import Localization
from .explanation import RISE
from .benchmark._attribution.metric import AttributionMetric, LabelSensitiveMetric, LabelAgnosticMetric
from .explanation._attribution.attribution import Attribution
# datafile directory names
_DATAFILE_DIRNAME_PREFIX = "_explain_"
_ORIGINAL_IMAGE_DIRNAME = "origin_images"
_HEATMAP_DIRNAME = "heatmap"
# max. no. of sample per directory
_SAMPLE_PER_DIR = 1000
_EXPAND_DIMS = ExpandDims()
_SEED = 58 # set a seed to fix the iterating order of the dataset
def _normalize(img_np):
"""Normalize the numpy image to the range of [0, 1]. """
max_ = img_np.max()
min_ = img_np.min()
normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
return normed
def _np_to_image(img_np, mode):
"""Convert numpy array to PIL image."""
return Image.fromarray(np.uint8(img_np * 255), mode=mode)
def _calc_prob_interval(volume, probs, prob_vars):
"""Compute the confidence interval of probability."""
if not isinstance(probs, np.ndarray):
probs = np.asarray(probs)
if not isinstance(prob_vars, np.ndarray):
prob_vars = np.asarray(prob_vars)
one_minus_probs = 1 - probs
alpha_coef = (np.square(probs) * one_minus_probs / prob_vars) - probs
beta_coef = alpha_coef * one_minus_probs / probs
intervals = beta.interval(volume, alpha_coef, beta_coef)
# avoid invalid result due to extreme small value of prob_vars
lows = []
highs = []
for i, low in enumerate(intervals[0]):
high = intervals[1][i]
if prob_vars[i] <= 0 or \
not np.isfinite(low) or low > probs[i] or \
not np.isfinite(high) or high < probs[i]:
low = probs[i]
high = probs[i]
lows.append(low)
highs.append(high)
return lows, highs
def _get_id_dirname(sample_id: int):
"""Get the name of parent directory of the image id."""
return str(int(sample_id / _SAMPLE_PER_DIR) * _SAMPLE_PER_DIR)
def _extract_timestamp(filename: str):
"""Extract timestamp from summary filename."""
matched = re.search(r"summary\.(\d+)", filename)
if matched:
return int(matched.group(1))
return None
class ExplainRunner:
"""
A high-level API for users to generate and store results of the explanation methods and the evaluation methods.
After generating results with the explanation methods and the evaluation methods, the results will be written into
a specified file with `mindspore.summary.SummaryRecord`. The stored content can be viewed using MindInsight.
Update in 2020.11: Adjust the storage structure and format of the data. Summary files generated by previous version
will be deprecated and will not be supported in MindInsight of current version.
Args:
summary_dir (str, optional): The directory path to save the summary files which store the generated results.
Default: "./"
Examples:
>>> from mindspore.explainer import ExplainRunner
>>> # init a runner with a specified directory
>>> summary_dir = "summary_dir"
>>> runner = ExplainRunner(summary_dir)
"""
def __init__(self, summary_dir: Optional[str] = "./"):
check_value_type("summary_dir", summary_dir, str)
self._summary_dir = summary_dir
self._count = 0
self._classes = None
self._model = None
self._uncertainty = None
self._summary_timestamp = None
def run(self,
dataset: Tuple,
explainers: List,
benchmarkers: Optional[List] = None,
uncertainty: Optional[UncertaintyEvaluation] = None,
activation_fn: Optional[Cell] = Softmax()):
"""
Generates results and writes them into the summary files in `summary_dir` specified during the object
initialization.
Args:
dataset (tuple): A tuple that contains `mindspore.dataset` object for iteration and its labels.
- dataset[0]: A `mindspore.dataset` object to provide data to explain.
- dataset[1]: A list of string that specifies the label names of the dataset.
explainers (list[Explanation]): A list of explanation objects to generate attribution results. Explanation
object is an instance initialized with the explanation methods in module
`mindspore.explainer.explanation`.
benchmarkers (list[Benchmark], optional): A list of benchmark objects to generate evaluation results.
Default: None
uncertainty (UncertaintyEvaluation, optional): An uncertainty evaluation object to evaluate the inference
uncertainty of samples.
activation_fn (Cell, optional): The activation layer that transforms the output of the network to
label probability distribution :math:`P(y|x)`. Default: Softmax().
Examples:
>>> from mindspore.explainer import ExplainRunner
>>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
>>> from mindspore.nn import Softmax
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> # Prepare the dataset for explaining and evaluation, e.g., Cifar10
>>> dataset = get_dataset('/path/to/Cifar10_dataset')
>>> classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
>>> # load checkpoint to a network, e.g. checkpoint of resnet50 trained on Cifar10
>>> param_dict = load_checkpoint("checkpoint.ckpt")
>>> net = resnet50(len(classes))
>>> load_param_into_net(net, param_dict)
>>> gbp = GuidedBackprop(net)
>>> gradient = Gradient(net)
>>> explainers = [gbp, gradient]
>>> # runner is an ExplainRunner object
>>> runner.run((dataset, classes), explainers, activation_fn=Softmax())
"""
check_value_type("dataset", dataset, tuple)
if len(dataset) != 2:
raise ValueError("Argument `dataset` should be a tuple with length = 2.")
dataset, classes = dataset
if benchmarkers is None:
benchmarkers = []
self._verify_data_form(dataset, benchmarkers)
self._classes = classes
check_value_type("explainers", explainers, list)
if not explainers:
raise ValueError("Argument `explainers` must be a non-empty list")
for exp in explainers:
if not isinstance(exp, Attribution):
raise TypeError("Argument `explainers` should be a list of objects of classes in "
"`mindspore.explainer.explanation`.")
if benchmarkers:
check_value_type("benchmarkers", benchmarkers, list)
for bench in benchmarkers:
if not isinstance(bench, AttributionMetric):
raise TypeError("Argument `benchmarkers` should be a list of objects of classes in explanation"
"`mindspore.explainer.benchmark`.")
check_value_type("activation_fn", activation_fn, Cell)
self._model = ms.nn.SequentialCell([explainers[0].model, activation_fn])
next_element = next(dataset.create_tuple_iterator())
inputs, _, _ = self._unpack_next_element(next_element)
prop_test = self._model(inputs)
check_value_type("output of model im explainer", prop_test, ms.Tensor)
if prop_test.shape[1] != len(self._classes):
raise ValueError("The dimension of model output does not match the length of dataset classes. Please "
"check dataset classes or the black-box model in the explainer again.")
if uncertainty is not None:
check_value_type("uncertainty", uncertainty, UncertaintyEvaluation)
prop_var_test = uncertainty.eval_epistemic_uncertainty(inputs)
check_value_type("output of uncertainty", prop_var_test, np.ndarray)
if prop_var_test.shape[1] != len(self._classes):
raise ValueError("The dimension of uncertainty output does not match the length of dataset classes"
"classes. Please check dataset classes or the black-box model in the explainer again.")
self._uncertainty = uncertainty
else:
self._uncertainty = None
with SummaryRecord(self._summary_dir) as summary:
spacer = '{:120}\r'
print("Start running and writing......")
begin = time()
print("Start writing metadata......")
self._summary_timestamp = _extract_timestamp(summary.event_file_name)
if self._summary_timestamp is None:
raise RuntimeError("Cannot extract timestamp from summary filename!"
" It should contains a timestamp of 10 digits.")
explain = Explain()
explain.metadata.label.extend(classes)
exp_names = [exp.__class__.__name__ for exp in explainers]
explain.metadata.explain_method.extend(exp_names)
if benchmarkers:
bench_names = [bench.__class__.__name__ for bench in benchmarkers]
explain.metadata.benchmark_method.extend(bench_names)
summary.add_value("explainer", "metadata", explain)
summary.record(1)
print("Finish writing metadata.")
now = time()
print("Start running and writing inference data.....")
imageid_labels = self._run_inference(dataset, summary)
print(spacer.format("Finish running and writing inference data. "
"Time elapsed: {:.3f} s".format(time() - now)))
if not benchmarkers:
for exp in explainers:
start = time()
print("Start running and writing explanation data for {}......".format(exp.__class__.__name__))
self._count = 0
ds.config.set_seed(_SEED)
for idx, next_element in enumerate(dataset):
now = time()
self._run_exp_step(next_element, exp, imageid_labels, summary)
print(spacer.format("Finish writing {}-th explanation data for {}. Time elapsed: "
"{:.3f} s".format(idx, exp.__class__.__name__, time() - now)), end='')
print(spacer.format(
"Finish running and writing explanation data for {}. Time elapsed: {:.3f} s".format(
exp.__class__.__name__, time() - start)))
else:
for exp in explainers:
explain = Explain()
for bench in benchmarkers:
bench.reset()
print(f"Start running and writing explanation and "
f"benchmark data for {exp.__class__.__name__}......")
self._count = 0
start = time()
ds.config.set_seed(_SEED)
for idx, next_element in enumerate(dataset):
now = time()
saliency_dict_lst = self._run_exp_step(next_element, exp, imageid_labels, summary)
print(spacer.format(
"Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
idx, exp.__class__.__name__, time() - now)), end='')
for bench in benchmarkers:
now = time()
self._run_exp_benchmark_step(next_element, exp, bench, saliency_dict_lst)
print(spacer.format(
"Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
idx, bench.__class__.__name__, exp.__class__.__name__, time() - now)), end='')
for bench in benchmarkers:
benchmark = explain.benchmark.add()
benchmark.explain_method = exp.__class__.__name__
benchmark.benchmark_method = bench.__class__.__name__
benchmark.total_score = bench.performance
if isinstance(bench, LabelSensitiveMetric):
benchmark.label_score.extend(bench.class_performances)
print(spacer.format("Finish running and writing explanation and benchmark data for {}. "
"Time elapsed: {:.3f} s".format(exp.__class__.__name__, time() - start)))
summary.add_value('explainer', 'benchmark', explain)
summary.record(1)
print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))
@staticmethod
def _verify_data_form(dataset, benchmarkers):
"""
Verify the validity of dataset.
Args:
dataset (`ds`): the user parsed dataset.
benchmarkers (list[`AttributionMetric`]): the user parsed benchmarkers.
"""
next_element = next(dataset.create_tuple_iterator())
if len(next_element) not in [1, 2, 3]:
raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
" as columns.")
if len(next_element) == 3:
inputs, labels, bboxes = next_element
if bboxes.shape[-1] != 4:
raise ValueError("The third element of dataset should be bounding boxes with shape of "
"[batch_size, num_ground_truth, 4].")
else:
if any(map(lambda benchmarker: isinstance(benchmarker, Localization), benchmarkers)):
raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
if len(next_element) == 2:
inputs, labels = next_element
if len(next_element) == 1:
inputs = next_element[0]
if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
raise ValueError(
"Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
inputs.shape))
if len(inputs.shape) == 3:
log.warning(
"Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
" dimension as batch data.".format(inputs.shape))
if len(next_element) > 1:
if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
raise ValueError(
"Labels shape {} is unrecognizable: labels should not have more than two dimensions"
" with length greater than 1.".format(labels.shape))
def _transform_data(self, inputs, labels, bboxes, ifbbox):
"""
Transform the data from one iteration of dataset to a unifying form for the follow-up operations.
Args:
inputs (Tensor): the image data
labels (Tensor): the labels
bboxes (Tensor): the bounding boxes data
ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t label
id will be returned. If False, the returned bboxes is the parsed bboxes.
Returns:
inputs (Tensor): the image data, unified to a 4D Tensor.
labels (List[List[int]]): the ground truth labels.
bboxes (Union[List[Dict], None, Tensor]): the bounding boxes
"""
inputs = ms.Tensor(inputs, ms.float32)
if len(inputs.shape) == 3:
inputs = _EXPAND_DIMS(inputs, 0)
if isinstance(labels, ms.Tensor):
labels = ms.Tensor(labels, ms.int32)
labels = _EXPAND_DIMS(labels, 0)
if isinstance(bboxes, ms.Tensor):
bboxes = ms.Tensor(bboxes, ms.int32)
bboxes = _EXPAND_DIMS(bboxes, 0)
input_len = len(inputs)
if bboxes is not None and ifbbox:
bboxes = ms.Tensor(bboxes, ms.int32)
masks_lst = []
labels = labels.asnumpy().reshape([input_len, -1])
bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
for idx, label in enumerate(labels):
height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
masks = {}
for j, label_item in enumerate(label):
target = int(label_item)
if -1 < target < len(self._classes):
if target not in masks:
mask = np.zeros((1, 1, height, width))
else:
mask = masks[target]
x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
masks[target] = mask
masks_lst.append(masks)
bboxes = masks_lst
labels = ms.Tensor(labels, ms.int32)
if len(labels.shape) == 1:
labels_lst = [[int(i)] for i in labels.asnumpy()]
else:
labels = labels.asnumpy().reshape([input_len, -1])
labels_lst = []
for item in labels:
labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._classes))))
labels = labels_lst
return inputs, labels, bboxes
def _unpack_next_element(self, next_element, ifbbox=False):
"""
Unpack a single iteration of dataset.
Args:
next_element (Tuple): a single element iterated from dataset object.
ifbbox (bool): whether to preprocess bboxes in self._transform_data.
Returns:
Tuple, a unified Tuple contains image_data, labels, and bounding boxes.
"""
if len(next_element) == 3:
inputs, labels, bboxes = next_element
elif len(next_element) == 2:
inputs, labels = next_element
bboxes = None
else:
inputs = next_element[0]
labels = [[] for _ in inputs]
bboxes = None
inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
return inputs, labels, bboxes
@staticmethod
def _make_label_batch(labels):
"""
Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
length of all the rows in labels.
Args:
labels (List[List]): the union labels of a data batch.
Returns:
2D Tensor.
"""
max_len = max([len(label) for label in labels])
batch_labels = np.zeros((len(labels), max_len))
for idx, _ in enumerate(batch_labels):
length = len(labels[idx])
batch_labels[idx, :length] = np.array(labels[idx])
return ms.Tensor(batch_labels, ms.int32)
def _run_inference(self, dataset, summary, threshold=0.5):
"""
Run inference for the dataset and write the inference related data into summary.
Args:
dataset (`ds`): the parsed dataset
summary (`SummaryRecord`): the summary object to store the data
threshold (float): the threshold for prediction.
Returns:
imageid_labels (dict): a dict that maps image_id and the union of its ground truth and predicted labels.
"""
spacer = '{:120}\r'
imageid_labels = {}
ds.config.set_seed(_SEED)
self._count = 0
for j, next_element in enumerate(dataset):
now = time()
inputs, labels, _ = self._unpack_next_element(next_element)
prob = self._model(inputs).asnumpy()
if self._uncertainty is not None:
prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
prob_sd = np.sqrt(prob_var)
else:
prob_var = prob_sd = None
for idx, inp in enumerate(inputs):
gt_labels = labels[idx]
gt_probs = [float(prob[idx][i]) for i in gt_labels]
data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
original_image = _np_to_image(_normalize(data_np), mode='RGB')
original_image_path = self._save_original_image(self._count, original_image)
predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
has_uncertainty = False
gt_prob_sds = gt_prob_itl95_lows = gt_prob_itl95_his = None
predicted_prob_sds = predicted_prob_itl95_lows = predicted_prob_itl95_his = None
if prob_var is not None:
gt_prob_sds = [float(prob_sd[idx][i]) for i in gt_labels]
predicted_prob_sds = [float(prob_sd[idx][i]) for i in predicted_labels]
try:
gt_prob_itl95_lows, gt_prob_itl95_his = \
_calc_prob_interval(0.95, gt_probs, [float(prob_var[idx][i]) for i in gt_labels])
predicted_prob_itl95_lows, predicted_prob_itl95_his = \
_calc_prob_interval(0.95, predicted_probs, [float(prob_var[idx][i])
for i in predicted_labels])
has_uncertainty = True
except ValueError:
log.error(traceback.format_exc())
log.error("Error on calculating uncertainty")
union_labs = list(set(gt_labels + predicted_labels))
imageid_labels[str(self._count)] = union_labs
explain = Explain()
explain.sample_id = self._count
explain.image_path = original_image_path
summary.add_value("explainer", "sample", explain)
explain = Explain()
explain.sample_id = self._count
explain.ground_truth_label.extend(gt_labels)
explain.inference.ground_truth_prob.extend(gt_probs)
explain.inference.predicted_label.extend(predicted_labels)
explain.inference.predicted_prob.extend(predicted_probs)
if has_uncertainty:
explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
explain.inference.ground_truth_prob_itl95_low.extend(gt_prob_itl95_lows)
explain.inference.ground_truth_prob_itl95_hi.extend(gt_prob_itl95_his)
explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
explain.inference.predicted_prob_itl95_low.extend(predicted_prob_itl95_lows)
explain.inference.predicted_prob_itl95_hi.extend(predicted_prob_itl95_his)
summary.add_value("explainer", "inference", explain)
summary.record(1)
self._count += 1
print(spacer.format("Finish running and writing {}-th batch inference data."
" Time elapsed: {:.3f} s".format(j, time() - now)),
end='')
return imageid_labels
def _run_exp_step(self, next_element, explainer, imageid_labels, summary):
"""
Run the explanation for each step and write explanation results into summary.
Args:
next_element (Tuple): data of one step
explainer (_Attribution): an Attribution object to generate saliency maps.
imageid_labels (dict): a dict that maps the image_id and its union labels.
summary (SummaryRecord): the summary object to store the data
Returns:
List of dict that maps label to its corresponding saliency map.
"""
inputs, labels, _ = self._unpack_next_element(next_element)
count = self._count
unions = []
for _ in range(len(labels)):
unions_labels = imageid_labels[str(count)]
unions.append(unions_labels)
count += 1
batch_unions = self._make_label_batch(unions)
saliency_dict_lst = []
if isinstance(explainer, RISE):
batch_saliency_full = explainer(inputs, batch_unions)
else:
batch_saliency_full = []
for i in range(len(batch_unions[0])):
batch_saliency = explainer(inputs, batch_unions[:, i])
batch_saliency_full.append(batch_saliency)
concat = ms.ops.operations.Concat(1)
batch_saliency_full = concat(tuple(batch_saliency_full))
for idx, union in enumerate(unions):
saliency_dict = {}
explain = Explain()
explain.sample_id = self._count
for k, lab in enumerate(union):
saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
saliency_dict[lab] = saliency
saliency_np = _normalize(saliency.asnumpy().squeeze())
saliency_image = _np_to_image(saliency_np, mode='L')
heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, self._count, saliency_image)
explanation = explain.explanation.add()
explanation.explain_method = explainer.__class__.__name__
explanation.heatmap_path = heatmap_path
explanation.label = lab
summary.add_value("explainer", "explanation", explain)
summary.record(1)
self._count += 1
saliency_dict_lst.append(saliency_dict)
return saliency_dict_lst
def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
"""
Run the explanation and evaluation for each step and write explanation results into summary.
Args:
next_element (Tuple): Data of one step
explainer (`_Attribution`): An Attribution object to generate saliency maps.
"""
inputs, labels, _ = self._unpack_next_element(next_element)
for idx, inp in enumerate(inputs):
inp = _EXPAND_DIMS(inp, 0)
if isinstance(benchmarker, LabelAgnosticMetric):
res = benchmarker.evaluate(explainer, inp)
res[np.isnan(res)] = 0.0
benchmarker.aggregate(res)
else:
saliency_dict = saliency_dict_lst[idx]
for label, saliency in saliency_dict.items():
if isinstance(benchmarker, Localization):
_, _, bboxes = self._unpack_next_element(next_element, True)
if label in labels[idx]:
res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
saliency=saliency)
res[np.isnan(res)] = 0.0
benchmarker.aggregate(res, label)
elif isinstance(benchmarker, LabelSensitiveMetric):
res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
res[np.isnan(res)] = 0.0
benchmarker.aggregate(res, label)
else:
raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, '
'but received {}.'.format(type(benchmarker)))
def _save_original_image(self, sample_id: int, image):
"""Save an image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
_ORIGINAL_IMAGE_DIRNAME,
id_dirname)
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
relative_path = os.path.join(relative_dir, f"{sample_id}.jpg")
save_path = os.path.join(self._summary_dir, relative_path)
with open(save_path, "wb") as file:
image.save(file)
return relative_path
def _save_heatmap(self, explain_method: str, class_id: int, sample_id: int, image):
"""Save heatmap image to summary directory."""
id_dirname = _get_id_dirname(sample_id)
relative_dir = os.path.join(_DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
_HEATMAP_DIRNAME,
explain_method,
id_dirname)
os.makedirs(os.path.join(self._summary_dir, relative_dir), exist_ok=True)
relative_path = os.path.join(relative_dir, f"{sample_id}_{class_id}.jpg")
save_path = os.path.join(self._summary_dir, relative_path)
with open(save_path, "wb") as file:
image.save(file, optimize=True)
return relative_path
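For reference, the removed _calc_prob_interval above fits a Beta distribution to each predicted probability by matching its mean and variance (method of moments) before reading off a 95% interval. A standalone sketch of the same arithmetic with made-up numbers:

import numpy as np
from scipy.stats import beta

p, var = 0.8, 0.01                    # predicted probability and its epistemic variance
alpha = p * p * (1 - p) / var - p     # method-of-moments Beta parameters
b = alpha * (1 - p) / p
print(beta.stats(alpha, b, moments='mv'))   # ~(0.8, 0.01): the fit recovers both moments
print(beta.interval(0.95, alpha, b))        # 95% interval around the predicted probability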


@@ -128,6 +128,10 @@ class LabelSensitiveMetric(AttributionMetric):
self._num_labels = num_labels
self._global_results = {i: [] for i in range(num_labels)}
@property
def num_labels(self):
return self._num_labels
@staticmethod
def _verify_params(num_labels):
check_value_type("num_labels", num_labels, int)