!7534 SummaryRecord support to record mindexplain data

Merge pull request !7534 from ougongchang/feature_mindexplain
This commit is contained in:
mindspore-ci-bot 2020-10-23 14:38:33 +08:00 committed by Gitee
commit 2af6313f53
8 changed files with 137 additions and 24 deletions

View File

@ -40,6 +40,8 @@ message Event {
// Summary data
Summary summary = 5;
Explain explain = 6;
}
}
@ -101,3 +103,50 @@ message Summary {
// Set of values for the summary.
repeated Value value = 1;
}
// Explain data recorded by the 'explainer' plugin; carried in Event.explain (field 6).
message Explain {
// Inference results (probabilities and predicted labels).
message Inference{
repeated float ground_truth_prob = 1;
repeated int32 predicted_label = 2;
repeated float predicted_prob = 3;
}
// A single explanation: the explain method, the label it refers to, and a heatmap.
message Explanation{
optional string explain_method = 1;
optional int32 label = 2;
// Raw heatmap bytes; the encoding is not specified in this file.
optional bytes heatmap = 3;
}
// Benchmark scores reported for one explain method.
message Benchmark{
// A single score produced by one benchmark method.
message TotalScore{
optional string benchmark_method = 1;
optional float score = 2;
}
// A list of scores produced by one benchmark method
// (presumably one score per label -- TODO confirm with the producer).
message LabelScore{
repeated float score = 1;
optional string benchmark_method = 2;
}
optional string explain_method = 1;
repeated TotalScore total_score = 2;
repeated LabelScore label_score = 3;
}
// Lists of the label names, explain method names and benchmark method names.
message Metadata{
repeated string label = 1;
repeated string explain_method = 2;
repeated string benchmark_method = 3;
}
optional string image_id = 1; // At least one of image_id and metadata must be filled in.
optional bytes image_data = 2;
repeated int32 ground_truth_label = 3;
optional Inference inference = 4;
repeated Explanation explanation = 5;
repeated Benchmark benchmark = 6;
optional Metadata metadata = 7;
optional string status = 8; // Expected values: run, end.
}

View File

@ -26,7 +26,7 @@ from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.train.summary.enum import PluginEnum, ModeEnum
from mindspore.train.summary.enums import PluginEnum, ModeEnum
from mindspore.train.callback import Callback, ModelCheckpoint
from mindspore.train import lineage_pb2
from mindspore.train.callback._dataset_graph import DatasetGraph

View File

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate the explain event which conform to proto format."""
import time
from ..summary_pb2 import Event, Explain
def check_explain_proto(explain):
    """
    Validate an explain proto before it is recorded.

    Args:
        explain (Explain): The object of summary_pb2.Explain.

    Raises:
        TypeError: When `explain` is not an instance of summary_pb2.Explain.
        ValueError: When `image_id`, `metadata.label` and `benchmark` are all empty.
    """
    if not isinstance(explain, Explain):
        raise TypeError(f'Plugin explainer expects a {Explain.__name__} value.')

    # Reject a message in which image_id, metadata.label and benchmark are all unset.
    if not explain.image_id and not explain.metadata.label and not explain.benchmark:
        raise ValueError('The Metadata and image id and benchmark must have one fill in.')
def package_explain_event(explain_str):
    """
    Wrap serialized explain data in a timestamped event.

    Args:
        explain_str (bytes): The serialized summary_pb2.Explain message
            (the output of `SerializeToString`).

    Returns:
        bytes, the serialized Event whose `explain` field holds the parsed data
        and whose `wall_time` is the current time.
    """
    explain_event = Event(wall_time=time.time())
    explain_event.explain.ParseFromString(explain_str)
    return explain_event.SerializeToString()

View File

@ -21,7 +21,8 @@ import mindspore.log as logger
from ._lineage_adapter import serialize_to_lineage_event
from ._summary_adapter import package_graph_event, package_summary_event
from ._summary_writer import LineageWriter, SummaryWriter
from ._explain_adapter import package_explain_event
from .writer import LineageWriter, SummaryWriter, ExplainWriter
try:
from multiprocessing import get_context
@ -42,6 +43,8 @@ def _pack_data(datadict, wall_time):
elif plugin in ('scalar', 'tensor', 'histogram', 'image'):
summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')})
step = data.get('step')
elif plugin == 'explainer':
result.append([plugin, package_explain_event(data.get('value'))])
if summaries:
result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()])
return result
@ -98,6 +101,8 @@ class WriterPool(ctx.Process):
self._writers_.append(SummaryWriter(filepath, self._max_file_size))
elif plugin == 'lineage':
self._writers_.append(LineageWriter(filepath, self._max_file_size))
elif plugin == 'explainer':
self._writers_.append(ExplainWriter(filepath, self._max_file_size))
return self._writers_
def _write(self, plugin, data):
@ -125,7 +130,6 @@ class WriterPool(ctx.Process):
Write the event to file.
Args:
name (str): The key of a specified file.
data (Optional[str, Tuple[list, int]]): The data to write.
"""
self._queue.put(('WRITE', data))

View File

@ -17,6 +17,7 @@ import atexit
import os
import re
import threading
from collections import defaultdict
from mindspore import log as logger
@ -24,6 +25,7 @@ from ..._c_expression import Tensor
from ..._checkparam import Validator
from .._utils import _check_lineage_value, _check_to_numpy, _make_directory
from ._summary_adapter import get_event_file_name, package_graph_event
from ._explain_adapter import check_explain_proto
from ._writer_pool import WriterPool
# for the moment, this lock is for caution's sake,
@ -55,7 +57,6 @@ def _get_summary_tensor_data():
def _dictlist():
from collections import defaultdict
return defaultdict(list)
@ -133,7 +134,8 @@ class SummaryRecord:
self._event_writer = WriterPool(log_dir,
max_file_size,
summary=self.full_file_name,
lineage=get_event_file_name(self.prefix, '_lineage'))
lineage=get_event_file_name(self.prefix, '_lineage'),
explainer=get_event_file_name(self.prefix, '_explain'))
_get_summary_tensor_data()
atexit.register(self.close)
@ -149,10 +151,11 @@ class SummaryRecord:
def set_mode(self, mode):
"""
Set the mode for the recorder to be aware. The mode is set to 'train' by default.
Sets the training phase. Different training phases affect data recording.
Args:
mode (str): The mode to be set, which should be 'train' or 'eval'.
mode (str): The mode to be set, which should be 'train' or 'eval'. When the mode is 'eval',
summary_record will not record the data of summary operators.
Raises:
ValueError: When the mode is not recognized.
@ -170,29 +173,26 @@ class SummaryRecord:
"""
Add value to be recorded later.
When the plugin is 'tensor', 'scalar', 'image' or 'histogram',
the name should be the tag name, and the value should be a Tensor.
When the plugin is 'graph', the value should be a GraphProto.
When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage',
or 'custom_lineage_data', the value should be a proto message.
Args:
plugin (str): The value of the plugin.
name (str): The value of the name.
value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
The value to store.
- The data type of value should be 'GraphProto' when the plugin is 'graph'.
- The data type of value should be 'Tensor' when the plugin is 'scalar', 'image', 'tensor'
- The data type of value should be 'GraphProto' (see mindspore/ccsrc/anf_ir.proto) object
when the plugin is 'graph'.
- The data type of value should be 'Tensor' object when the plugin is 'scalar', 'image', 'tensor'
or 'histogram'.
- The data type of value should be 'TrainLineage' when the plugin is 'train_lineage'.
- The data type of value should be 'EvaluationLineage' when the plugin is 'eval_lineage'.
- The data type of value should be 'DatasetGraph' when the plugin is 'dataset_graph'.
- The data type of value should be 'UserDefinedInfo' when the plugin is 'custom_lineage_data'.
- The data type of value should be a 'TrainLineage' object when the plugin is 'train_lineage',
see mindspore/ccsrc/lineage.proto.
- The data type of value should be an 'EvaluationLineage' object when the plugin is 'eval_lineage',
see mindspore/ccsrc/lineage.proto.
- The data type of value should be a 'DatasetGraph' object when the plugin is 'dataset_graph',
see mindspore/ccsrc/lineage.proto.
- The data type of value should be a 'UserDefinedInfo' object when the plugin is 'custom_lineage_data',
see mindspore/ccsrc/lineage.proto.
- The data type of value should be an 'Explain' object when the plugin is 'explainer',
see mindspore/ccsrc/summary.proto.
Raises:
ValueError: When the name is not valid.
TypeError: When the value is not a Tensor.
@ -218,6 +218,9 @@ class SummaryRecord:
elif plugin == 'graph':
package_graph_event(value)
self._data_pool[plugin].append(dict(value=value))
elif plugin == 'explainer':
check_explain_proto(value)
self._data_pool[plugin].append(dict(value=value.SerializeToString()))
else:
raise ValueError(f'No such plugin of {repr(plugin)}')

View File

@ -94,3 +94,12 @@ class LineageWriter(BaseWriter):
"""Write data to file."""
if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'):
super().write(plugin, data)
class ExplainWriter(BaseWriter):
    """Writer that handles only the data of the 'explainer' plugin."""

    def write(self, plugin, data):
        """
        Write explain data to file.

        Data is forwarded to the base writer only when the plugin is
        'explainer'; data of any other plugin is ignored.

        Args:
            plugin (str): The plugin the data belongs to.
            data: The data to write.
        """
        if plugin != 'explainer':
            return
        super().write(plugin, data)

View File

@ -26,7 +26,7 @@ from mindspore import Tensor
from mindspore import Parameter
from mindspore.train.callback import SummaryCollector
from mindspore.train.callback import _InternalCallbackParam
from mindspore.train.summary.enum import ModeEnum, PluginEnum
from mindspore.train.summary.enums import ModeEnum, PluginEnum
from mindspore.train.summary import SummaryRecord
from mindspore.nn import Cell
from mindspore.nn.optim.optimizer import Optimizer