forked from mindspore-Ecosystem/mindspore
!7534 SummaryRecord support to record mindexplain data
Merge pull request !7534 from ougongchang/feature_mindexplain
This commit is contained in:
commit
2af6313f53
|
@ -40,6 +40,8 @@ message Event {
|
|||
|
||||
// Summary data
|
||||
Summary summary = 5;
|
||||
|
||||
Explain explain = 6;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,3 +103,50 @@ message Summary {
|
|||
// Set of values for the summary.
|
||||
repeated Value value = 1;
|
||||
}
|
||||
|
||||
// Explain data carried by an Event (see the `explain` field of message Event).
message Explain {
    // Inference results for one sample.
    // NOTE(review): entries presumably align index-by-index with the label lists — confirm with the writer.
    message Inference{
        repeated float ground_truth_prob = 1;
        repeated int32 predicted_label = 2;
        repeated float predicted_prob = 3;
    }

    // One explanation result: the method that produced it, the label it refers to, and a heatmap payload.
    // NOTE(review): the heatmap byte encoding is not specified here — confirm with the producer.
    message Explanation{
        optional string explain_method = 1;
        optional int32 label = 2;
        optional bytes heatmap = 3;
    }

    // Benchmark scores reported for one explain method.
    message Benchmark{
        // A single score attributed to one benchmark method.
        message TotalScore{
            optional string benchmark_method = 1;
            optional float score = 2;
        }
        // A list of scores attributed to one benchmark method.
        // NOTE(review): presumably one score per label — confirm.
        message LabelScore{
            repeated float score = 1;
            optional string benchmark_method = 2;
        }

        optional string explain_method = 1;
        repeated TotalScore total_score = 2;
        repeated LabelScore label_score = 3;
    }

    // Names of the labels, explain methods and benchmark methods.
    message Metadata{
        repeated string label = 1;
        repeated string explain_method = 2;
        repeated string benchmark_method = 3;
    }

    optional string image_id = 1;  // The Metadata and image id must have one fill in
    optional bytes image_data = 2;
    repeated int32 ground_truth_label = 3;


    optional Inference inference = 4;
    repeated Explanation explanation = 5;
    repeated Benchmark benchmark = 6;

    optional Metadata metadata = 7;
    optional string status = 8;  // enum value: run, end
}
|
|
@ -26,7 +26,7 @@ from mindspore import log as logger
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from mindspore.train.summary.enum import PluginEnum, ModeEnum
|
||||
from mindspore.train.summary.enums import PluginEnum, ModeEnum
|
||||
from mindspore.train.callback import Callback, ModelCheckpoint
|
||||
from mindspore.train import lineage_pb2
|
||||
from mindspore.train.callback._dataset_graph import DatasetGraph
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Generate the explain event which conform to proto format."""
|
||||
import time
|
||||
|
||||
from ..summary_pb2 import Event, Explain
|
||||
|
||||
|
||||
def check_explain_proto(explain):
    """
    Validate an explain proto before it is recorded.

    Args:
        explain (Explain): The object of summary_pb2.Explain.

    Raises:
        TypeError: If `explain` is not an instance of Explain.
        ValueError: If none of `image_id`, `metadata.label` and `benchmark` is filled in.
    """
    if not isinstance(explain, Explain):
        raise TypeError(f'Plugin explainer expects a {Explain.__name__} value.')

    # The message must carry at least one of: an image id, label metadata, or benchmark results.
    has_payload = explain.image_id or explain.metadata.label or explain.benchmark
    if not has_payload:
        raise ValueError('The Metadata and image id and benchmark must have one fill in.')
|
||||
|
||||
|
||||
def package_explain_event(explain_str):
    """
    Wrap serialized explain data in an Event and serialize the Event.

    Args:
        explain_str (bytes): The serialized summary_pb2.Explain message.

    Returns:
        bytes, the serialized Event whose `explain` field holds the parsed data.
    """
    event = Event()
    # Stamp the event with the time of packaging.
    event.wall_time = time.time()
    event.explain.ParseFromString(explain_str)
    return event.SerializeToString()
|
|
@ -21,7 +21,8 @@ import mindspore.log as logger
|
|||
|
||||
from ._lineage_adapter import serialize_to_lineage_event
|
||||
from ._summary_adapter import package_graph_event, package_summary_event
|
||||
from ._summary_writer import LineageWriter, SummaryWriter
|
||||
from ._explain_adapter import package_explain_event
|
||||
from .writer import LineageWriter, SummaryWriter, ExplainWriter
|
||||
|
||||
try:
|
||||
from multiprocessing import get_context
|
||||
|
@ -42,6 +43,8 @@ def _pack_data(datadict, wall_time):
|
|||
elif plugin in ('scalar', 'tensor', 'histogram', 'image'):
|
||||
summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')})
|
||||
step = data.get('step')
|
||||
elif plugin == 'explainer':
|
||||
result.append([plugin, package_explain_event(data.get('value'))])
|
||||
if summaries:
|
||||
result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()])
|
||||
return result
|
||||
|
@ -98,6 +101,8 @@ class WriterPool(ctx.Process):
|
|||
self._writers_.append(SummaryWriter(filepath, self._max_file_size))
|
||||
elif plugin == 'lineage':
|
||||
self._writers_.append(LineageWriter(filepath, self._max_file_size))
|
||||
elif plugin == 'explainer':
|
||||
self._writers_.append(ExplainWriter(filepath, self._max_file_size))
|
||||
return self._writers_
|
||||
|
||||
def _write(self, plugin, data):
|
||||
|
@ -125,7 +130,6 @@ class WriterPool(ctx.Process):
|
|||
Write the event to file.
|
||||
|
||||
Args:
|
||||
name (str): The key of a specified file.
|
||||
data (Optional[str, Tuple[list, int]]): The data to write.
|
||||
"""
|
||||
self._queue.put(('WRITE', data))
|
||||
|
|
|
@ -17,6 +17,7 @@ import atexit
|
|||
import os
|
||||
import re
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
from mindspore import log as logger
|
||||
|
||||
|
@ -24,6 +25,7 @@ from ..._c_expression import Tensor
|
|||
from ..._checkparam import Validator
|
||||
from .._utils import _check_lineage_value, _check_to_numpy, _make_directory
|
||||
from ._summary_adapter import get_event_file_name, package_graph_event
|
||||
from ._explain_adapter import check_explain_proto
|
||||
from ._writer_pool import WriterPool
|
||||
|
||||
# for the moment, this lock is for caution's sake,
|
||||
|
@ -55,7 +57,6 @@ def _get_summary_tensor_data():
|
|||
|
||||
|
||||
def _dictlist():
|
||||
from collections import defaultdict
|
||||
return defaultdict(list)
|
||||
|
||||
|
||||
|
@ -133,7 +134,8 @@ class SummaryRecord:
|
|||
self._event_writer = WriterPool(log_dir,
|
||||
max_file_size,
|
||||
summary=self.full_file_name,
|
||||
lineage=get_event_file_name(self.prefix, '_lineage'))
|
||||
lineage=get_event_file_name(self.prefix, '_lineage'),
|
||||
explainer=get_event_file_name(self.prefix, '_explain'))
|
||||
_get_summary_tensor_data()
|
||||
atexit.register(self.close)
|
||||
|
||||
|
@ -149,10 +151,11 @@ class SummaryRecord:
|
|||
|
||||
def set_mode(self, mode):
|
||||
"""
|
||||
Set the mode for the recorder to be aware. The mode is set to 'train' by default.
|
||||
Sets the training phase. Different training phases affect data recording.
|
||||
|
||||
Args:
|
||||
mode (str): The mode to be set, which should be 'train' or 'eval'.
|
||||
mode (str): The mode to be set, which should be 'train' or 'eval'. When the mode is 'eval',
|
||||
summary_record will not record the data of summary operators.
|
||||
|
||||
Raises:
|
||||
ValueError: When the mode is not recognized.
|
||||
|
@ -170,29 +173,26 @@ class SummaryRecord:
|
|||
"""
|
||||
Add value to be recorded later.
|
||||
|
||||
When the plugin is 'tensor', 'scalar', 'image' or 'histogram',
|
||||
the name should be the tag name, and the value should be a Tensor.
|
||||
|
||||
When the plugin is 'graph', the value should be a GraphProto.
|
||||
|
||||
When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage',
|
||||
or 'custom_lineage_data', the value should be a proto message.
|
||||
|
||||
|
||||
Args:
|
||||
plugin (str): The value of the plugin.
|
||||
name (str): The value of the name.
|
||||
value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
|
||||
The value to store.
|
||||
|
||||
- The data type of value should be 'GraphProto' when the plugin is 'graph'.
|
||||
- The data type of value should be 'Tensor' when the plugin is 'scalar', 'image', 'tensor'
|
||||
- The data type of value should be 'GraphProto' (see mindspore/ccsrc/anf_ir.proto) object
|
||||
when the plugin is 'graph'.
|
||||
- The data type of value should be 'Tensor' object when the plugin is 'scalar', 'image', 'tensor'
|
||||
or 'histogram'.
|
||||
- The data type of value should be 'TrainLineage' when the plugin is 'train_lineage'.
|
||||
- The data type of value should be 'EvaluationLineage' when the plugin is 'eval_lineage'.
|
||||
- The data type of value should be 'DatasetGraph' when the plugin is 'dataset_graph'.
|
||||
- The data type of value should be 'UserDefinedInfo' when the plugin is 'custom_lineage_data'.
|
||||
|
||||
- The data type of value should be a 'TrainLineage' object when the plugin is 'train_lineage',
|
||||
see mindspore/ccsrc/lineage.proto.
|
||||
- The data type of value should be an 'EvaluationLineage' object when the plugin is 'eval_lineage',
|
||||
see mindspore/ccsrc/lineage.proto.
|
||||
- The data type of value should be a 'DatasetGraph' object when the plugin is 'dataset_graph',
|
||||
see mindspore/ccsrc/lineage.proto.
|
||||
- The data type of value should be a 'UserDefinedInfo' object when the plugin is 'custom_lineage_data',
|
||||
see mindspore/ccsrc/lineage.proto.
|
||||
- The data type of value should be an 'Explain' object when the plugin is 'explainer',
|
||||
see mindspore/ccsrc/summary.proto.
|
||||
Raises:
|
||||
ValueError: When the name is not valid.
|
||||
TypeError: When the value is not a Tensor.
|
||||
|
@ -218,6 +218,9 @@ class SummaryRecord:
|
|||
elif plugin == 'graph':
|
||||
package_graph_event(value)
|
||||
self._data_pool[plugin].append(dict(value=value))
|
||||
elif plugin == 'explainer':
|
||||
check_explain_proto(value)
|
||||
self._data_pool[plugin].append(dict(value=value.SerializeToString()))
|
||||
else:
|
||||
raise ValueError(f'No such plugin of {repr(plugin)}')
|
||||
|
||||
|
|
|
@ -94,3 +94,12 @@ class LineageWriter(BaseWriter):
|
|||
"""Write data to file."""
|
||||
if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'):
|
||||
super().write(plugin, data)
|
||||
|
||||
|
||||
class ExplainWriter(BaseWriter):
    """Writer that only accepts data from the 'explainer' plugin."""

    def write(self, plugin, data):
        """Pass `data` on to the base writer when `plugin` is 'explainer'; ignore anything else."""
        if plugin != 'explainer':
            return
        super().write(plugin, data)
|
|
@ -26,7 +26,7 @@ from mindspore import Tensor
|
|||
from mindspore import Parameter
|
||||
from mindspore.train.callback import SummaryCollector
|
||||
from mindspore.train.callback import _InternalCallbackParam
|
||||
from mindspore.train.summary.enum import ModeEnum, PluginEnum
|
||||
from mindspore.train.summary.enums import ModeEnum, PluginEnum
|
||||
from mindspore.train.summary import SummaryRecord
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
|
|
Loading…
Reference in New Issue