forked from mindspore-Ecosystem/mindspore
enhance the SummaryRecord with set_mode and add_value
This commit is contained in:
parent
19e66f06e2
commit
0921c1e538
|
@ -79,6 +79,7 @@ if (ENABLE_DUMP_PROTO)
|
|||
file(GLOB_RECURSE PROTO_PY RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"utils/anf_ir.proto"
|
||||
"utils/summary.proto"
|
||||
"utils/lineage.proto"
|
||||
"utils/checkpoint.proto"
|
||||
)
|
||||
ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY})
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright 2020 Huawei Technologies Co., Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mindspore.irpb;
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
|
||||
// Event Protocol buffer, Top define
|
||||
message LineageEvent {
|
||||
// Timestamp
|
||||
required double wall_time = 1;
|
||||
|
||||
// The step of train.
|
||||
optional int64 step = 2;
|
||||
|
||||
oneof what {
|
||||
// An event file was started, with the specified version.
|
||||
// Now version is "Mindspore.Event:1"
|
||||
string version = 3;
|
||||
|
||||
// Train lineage
|
||||
TrainLineage train_lineage = 6;
|
||||
|
||||
// Evaluation lineage
|
||||
EvaluationLineage evaluation_lineage = 7;
|
||||
|
||||
// Dataset graph
|
||||
DatasetGraph dataset_graph = 9;
|
||||
|
||||
// User defined info
|
||||
UserDefinedInfo user_defined_info = 10;
|
||||
}
|
||||
}
|
||||
|
||||
// User defined info
|
||||
message UserDefinedInfo{
|
||||
// repeated user defined info
|
||||
repeated UserDefinedInfo user_info = 1;
|
||||
|
||||
// key/value which contains both scalar and dict
|
||||
map<string, UserDefinedInfo> map_dict = 2;
|
||||
map<string, int32> map_int32 = 3;
|
||||
map<string, string> map_str = 4;
|
||||
map<string, double> map_double = 5;
|
||||
}
|
||||
|
||||
// TrainLineage records infos of a train.
|
||||
message TrainLineage{
|
||||
message HyperParameters{
|
||||
optional string optimizer = 1;
|
||||
optional float learning_rate = 2;
|
||||
optional string loss_function = 3;
|
||||
optional int32 epoch = 4;
|
||||
optional string parallel_mode = 5;
|
||||
optional int32 device_num = 6;
|
||||
optional int32 batch_size = 8;
|
||||
}
|
||||
|
||||
message TrainDataset{
|
||||
optional string train_dataset_path = 1;
|
||||
optional int32 train_dataset_size = 2;
|
||||
}
|
||||
|
||||
message Algorithm{
|
||||
optional string network = 1;
|
||||
optional float loss = 2;
|
||||
}
|
||||
|
||||
message Model{
|
||||
optional string path = 3;
|
||||
optional int64 size = 4;
|
||||
}
|
||||
|
||||
optional HyperParameters hyper_parameters = 1;
|
||||
optional TrainDataset train_dataset = 2;
|
||||
optional Algorithm algorithm = 3;
|
||||
optional Model model = 4;
|
||||
}
|
||||
|
||||
//EvalLineage records infos of evaluation.
|
||||
message EvaluationLineage{
|
||||
message ValidDataset{
|
||||
optional string valid_dataset_path = 1;
|
||||
optional int32 valid_dataset_size = 2;
|
||||
}
|
||||
|
||||
optional string metric = 2;
|
||||
optional ValidDataset valid_dataset = 3;
|
||||
}
|
||||
|
||||
|
||||
// DatasetGraph
|
||||
message DatasetGraph {
|
||||
repeated DatasetGraph children = 1;
|
||||
optional OperationParameter parameter = 2;
|
||||
repeated Operation operations = 3;
|
||||
optional Operation sampler = 4;
|
||||
}
|
||||
|
||||
message Operation {
|
||||
optional OperationParameter operationParam = 1;
|
||||
repeated int32 size = 2;
|
||||
repeated float weights = 3;
|
||||
}
|
||||
|
||||
message OperationParameter{
|
||||
map<string, string> mapStr = 1;
|
||||
map<string, StrList> mapStrList = 2;
|
||||
map<string, bool> mapBool = 3;
|
||||
map<string, int32> mapInt = 4;
|
||||
map<string, double> mapDouble = 5;
|
||||
}
|
||||
|
||||
message StrList {
|
||||
repeated string strValue = 1;
|
||||
}
|
|
@ -22,6 +22,7 @@ from mindspore import log as logger
|
|||
from mindspore.common.api import _executor
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
|
||||
from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
|
||||
|
||||
def _convert_type(types):
|
||||
"""
|
||||
|
@ -196,3 +197,38 @@ def _to_full_shapes(shapes, device_num):
|
|||
new_shape += (item,)
|
||||
new_shapes.append(new_shape)
|
||||
return new_shapes
|
||||
|
||||
|
||||
def _check_to_numpy(plugin, tensor):
|
||||
"""Check the tensor and return a numpy.ndarray."""
|
||||
np_value = tensor.asnumpy()
|
||||
if plugin == 'scalar':
|
||||
if np_value.size == 1:
|
||||
return np_value
|
||||
raise ValueError('The tensor holds more than one value, but the scalar plugin expects on value.')
|
||||
if plugin == 'image':
|
||||
if np_value.ndim == 4:
|
||||
return np_value
|
||||
raise ValueError('The tensor seems not to hold a valid image.')
|
||||
if plugin in ('tensor', 'histogram'):
|
||||
if np_value.ndim > 0:
|
||||
return np_value
|
||||
raise ValueError('The tensor should not be empty.')
|
||||
return np_value
|
||||
|
||||
def _check_lineage_value(plugin, value):
|
||||
"""Check the lineage value."""
|
||||
def raises(plugin, prototype):
|
||||
raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.')
|
||||
|
||||
if plugin == 'dataset_graph' and not isinstance(value, DatasetGraph):
|
||||
raises(plugin, DatasetGraph)
|
||||
|
||||
if plugin == 'eval_lineage' and not isinstance(value, EvaluationLineage):
|
||||
raises(plugin, EvaluationLineage)
|
||||
|
||||
if plugin == 'train_lineage' and not isinstance(value, TrainLineage):
|
||||
raises(plugin, TrainLineage)
|
||||
|
||||
if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo):
|
||||
raises(plugin, UserDefinedInfo)
|
||||
|
|
|
@ -1,88 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Writes events to disk in a logdir."""
|
||||
import os
|
||||
import stat
|
||||
from collections import deque
|
||||
from multiprocessing import Pool, Process, Queue, cpu_count
|
||||
|
||||
from ..._c_expression import EventWriter_
|
||||
from ._summary_adapter import package_summary_event
|
||||
|
||||
|
||||
def _pack(result, step):
|
||||
summary_event = package_summary_event(result, step)
|
||||
return summary_event.SerializeToString()
|
||||
|
||||
|
||||
class EventWriter(Process):
|
||||
"""
|
||||
Creates a `EventWriter` and write event to file.
|
||||
|
||||
Args:
|
||||
filepath (str): Summary event file path and file name.
|
||||
flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120.
|
||||
"""
|
||||
|
||||
def __init__(self, filepath: str, flush_interval: int) -> None:
|
||||
super().__init__()
|
||||
_ = flush_interval
|
||||
with open(filepath, 'w'):
|
||||
os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
|
||||
self._writer = EventWriter_(filepath)
|
||||
self._queue = Queue(cpu_count() * 2)
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
|
||||
with Pool(min(cpu_count(), 32)) as pool:
|
||||
deq = deque()
|
||||
while True:
|
||||
while deq and deq[0].ready():
|
||||
self._writer.Write(deq.popleft().get())
|
||||
|
||||
if not self._queue.empty():
|
||||
action, data = self._queue.get()
|
||||
if action == 'WRITE':
|
||||
if not isinstance(data, (str, bytes)):
|
||||
deq.append(pool.apply_async(_pack, data))
|
||||
else:
|
||||
self._writer.Write(data)
|
||||
elif action == 'FLUSH':
|
||||
self._writer.Flush()
|
||||
elif action == 'END':
|
||||
break
|
||||
for res in deq:
|
||||
self._writer.Write(res.get())
|
||||
|
||||
self._writer.Shut()
|
||||
|
||||
def write(self, data) -> None:
|
||||
"""
|
||||
Write the event to file.
|
||||
|
||||
Args:
|
||||
data (Optional[str, Tuple[list, int]]): The data to write.
|
||||
"""
|
||||
self._queue.put(('WRITE', data))
|
||||
|
||||
def flush(self):
|
||||
"""Flush the writer."""
|
||||
self._queue.put(('FLUSH', None))
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the writer."""
|
||||
self._queue.put(('END', None))
|
||||
self.join()
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Generate the lineage event which conform to proto format."""
|
||||
import time
|
||||
|
||||
from ..lineage_pb2 import LineageEvent
|
||||
|
||||
|
||||
def serialize_to_lineage_event(name, value):
|
||||
"""Serialize value to lineage event."""
|
||||
event = LineageEvent()
|
||||
event.wall_time = time.time()
|
||||
content = _get_lineage_content(name, event)
|
||||
content.ParseFromString(value)
|
||||
return event.SerializeToString()
|
||||
|
||||
|
||||
def _get_lineage_content(name, event):
|
||||
if name == 'dataset_graph':
|
||||
return event.dataset_graph
|
||||
if name == 'eval_lineage':
|
||||
return event.evaluation_lineage
|
||||
if name == 'train_lineage':
|
||||
return event.train_lineage
|
||||
if name == 'custom_lineage_data':
|
||||
return event.user_defined_info
|
||||
raise KeyError(f'No such field in LineageEvent')
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Generate the summary event which conform to proto format."""
|
||||
import socket
|
||||
import platform
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
@ -51,7 +51,7 @@ def get_event_file_name(prefix, suffix):
|
|||
_check_str_by_regular(suffix)
|
||||
file_name = ""
|
||||
time_second = str(int(time.time()))
|
||||
hostname = socket.gethostname()
|
||||
hostname = platform.node()
|
||||
|
||||
if prefix is not None:
|
||||
file_name = file_name + prefix
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Writes events to disk in a logdir."""
|
||||
import os
|
||||
import stat
|
||||
|
||||
from ..._c_expression import EventWriter_
|
||||
from ._summary_adapter import package_init_event
|
||||
|
||||
|
||||
class BaseWriter:
|
||||
"""BaseWriter to be subclass."""
|
||||
|
||||
def __init__(self, filepath) -> None:
|
||||
self._filepath = filepath
|
||||
self._writer: EventWriter_ = None
|
||||
|
||||
def init_writer(self):
|
||||
"""Write some metadata etc."""
|
||||
|
||||
@property
|
||||
def writer(self) -> EventWriter_:
|
||||
"""Get the writer."""
|
||||
if self._writer is not None:
|
||||
return self._writer
|
||||
|
||||
with open(self._filepath, 'w'):
|
||||
os.chmod(self._filepath, stat.S_IWUSR | stat.S_IRUSR)
|
||||
self._writer = EventWriter_(self._filepath)
|
||||
self.init_writer()
|
||||
return self._writer
|
||||
|
||||
def write(self, plugin, mode, data):
|
||||
"""Write data to file."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def flush(self):
|
||||
"""Flush the writer."""
|
||||
if self._writer is not None:
|
||||
self._writer.Flush()
|
||||
|
||||
def close(self):
|
||||
"""Close the writer."""
|
||||
if self._writer is not None:
|
||||
self._writer.Shut()
|
||||
|
||||
|
||||
class SummaryWriter(BaseWriter):
|
||||
"""SummaryWriter for write summaries."""
|
||||
|
||||
def init_writer(self):
|
||||
"""Write some metadata etc."""
|
||||
self.writer.Write(package_init_event().SerializeToString())
|
||||
|
||||
def write(self, plugin, mode, data):
|
||||
"""Write data to file."""
|
||||
if plugin in ('summary', 'graph'):
|
||||
self.writer.Write(data)
|
||||
|
||||
|
||||
class LineageWriter(BaseWriter):
|
||||
"""LineageWriter for write lineage."""
|
||||
|
||||
def write(self, plugin, mode, data):
|
||||
"""Write data to file."""
|
||||
if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'):
|
||||
self.writer.Write(data)
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Write events to disk in a base directory."""
|
||||
import os
|
||||
from collections import deque
|
||||
from multiprocessing import Pool, Process, Queue, cpu_count
|
||||
|
||||
from ._lineage_adapter import serialize_to_lineage_event
|
||||
from ._summary_adapter import package_graph_event, package_summary_event
|
||||
from ._summary_writer import SummaryWriter, LineageWriter
|
||||
|
||||
|
||||
def _pack_data(datadict):
|
||||
"""Pack data according to which plugin."""
|
||||
result = []
|
||||
summaries, step, mode = [], None, None
|
||||
for plugin, datalist in datadict.items():
|
||||
for data in datalist:
|
||||
if plugin == 'graph':
|
||||
result.append([plugin, data.get('mode'), package_graph_event(data.get('value')).SerializeToString()])
|
||||
elif plugin in ('train_lineage', 'eval_lineage', 'custom_lineage_data', 'dataset_graph'):
|
||||
result.append([plugin, data.get('mode'), serialize_to_lineage_event(plugin, data.get('value'))])
|
||||
elif plugin in ('scalar', 'tensor', 'histogram', 'image'):
|
||||
summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')})
|
||||
step = data.get('step')
|
||||
mode = data.get('mode')
|
||||
if summaries:
|
||||
result.append(['summary', mode, package_summary_event(summaries, step).SerializeToString()])
|
||||
return result
|
||||
|
||||
|
||||
class WriterPool(Process):
|
||||
"""
|
||||
Use a set of pooled resident processes for writing a list of file.
|
||||
|
||||
Args:
|
||||
base_dir (str): The base directory to hold all the files.
|
||||
filelist (str): The mapping from short name to long filename.
|
||||
"""
|
||||
|
||||
def __init__(self, base_dir, **filedict) -> None:
|
||||
super().__init__()
|
||||
self._base_dir, self._filedict = base_dir, filedict
|
||||
self._queue = Queue(cpu_count() * 2)
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
writers = self._get_writers()
|
||||
|
||||
with Pool() as pool:
|
||||
deq = deque()
|
||||
while True:
|
||||
while deq and deq[0].ready():
|
||||
for plugin, mode, data in deq.popleft().get():
|
||||
for writer in writers:
|
||||
writer.write(plugin, mode, data)
|
||||
|
||||
if not self._queue.empty():
|
||||
action, data = self._queue.get()
|
||||
if action == 'WRITE':
|
||||
deq.append(pool.apply_async(_pack_data, (data,)))
|
||||
elif action == 'FLUSH':
|
||||
for writer in writers:
|
||||
writer.flush()
|
||||
elif action == 'END':
|
||||
break
|
||||
for result in deq:
|
||||
for plugin, mode, data in result.get():
|
||||
for writer in writers:
|
||||
writer.write(plugin, mode, data)
|
||||
|
||||
for writer in writers:
|
||||
writer.close()
|
||||
|
||||
def _get_writers(self):
|
||||
writers = []
|
||||
for plugin, filename in self._filedict.items():
|
||||
filepath = os.path.join(self._base_dir, filename)
|
||||
if plugin == 'summary':
|
||||
writers.append(SummaryWriter(filepath))
|
||||
elif plugin == 'lineage':
|
||||
writers.append(LineageWriter(filepath))
|
||||
return writers
|
||||
|
||||
def write(self, data) -> None:
|
||||
"""
|
||||
Write the event to file.
|
||||
|
||||
Args:
|
||||
name (str): The key of a specified file.
|
||||
data (Optional[str, Tuple[list, int]]): The data to write.
|
||||
"""
|
||||
self._queue.put(('WRITE', data))
|
||||
|
||||
def flush(self):
|
||||
"""Flush the writer and sync data to disk."""
|
||||
self._queue.put(('FLUSH', None))
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the writer."""
|
||||
self._queue.put(('END', None))
|
||||
self.join()
|
|
@ -21,9 +21,9 @@ from mindspore import log as logger
|
|||
|
||||
from ..._c_expression import Tensor
|
||||
from ..._checkparam import _check_str_by_regular
|
||||
from .._utils import _make_directory
|
||||
from ._event_writer import EventWriter
|
||||
from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event
|
||||
from .._utils import _make_directory, _check_to_numpy, _check_lineage_value
|
||||
from ._summary_adapter import get_event_file_name, package_graph_event
|
||||
from ._writer_pool import WriterPool
|
||||
|
||||
# for the moment, this lock is for caution's sake,
|
||||
# there are actually no any concurrencies happening.
|
||||
|
@ -53,16 +53,20 @@ def _get_summary_tensor_data():
|
|||
return data
|
||||
|
||||
|
||||
def _dictlist():
|
||||
from collections import defaultdict
|
||||
return defaultdict(list)
|
||||
|
||||
|
||||
class SummaryRecord:
|
||||
"""
|
||||
SummaryRecord is used to record the summary value.
|
||||
SummaryRecord is used to record the summary data and lineage data.
|
||||
|
||||
Note:
|
||||
The API will create an event file in a given directory and add summaries and events to it.
|
||||
It writes the event log to a file by executing the record method. In addition,
|
||||
if the SummaryRecord object is created and the summary operator is used in the network,
|
||||
even if the record method is not called, the event in the cache will be written to the
|
||||
file at the end of execution. Make sure to close the SummaryRecord object at the end.
|
||||
The API will create a summary file and a lineage file lazily in a given directory and writes data to them.
|
||||
It writes the data to files by executing the record method. In addition to record the data bubbled up from
|
||||
the network by defining the summary operators, SummaryRecord also supports to record extra data which
|
||||
can be added by calling add_value. Finally, make sure to close the SummaryRecord object at the end.
|
||||
|
||||
Args:
|
||||
log_dir (str): The log_dir is a directory location to save the summary.
|
||||
|
@ -89,10 +93,12 @@ class SummaryRecord:
|
|||
file_suffix="_MS",
|
||||
network=None):
|
||||
|
||||
self._event_writer, self._closed = None, False
|
||||
self._closed, self._mode = False, 'train'
|
||||
self._data_pool = _dictlist()
|
||||
|
||||
_check_str_by_regular(file_prefix)
|
||||
_check_str_by_regular(file_suffix)
|
||||
|
||||
self.log_path = _make_directory(log_dir)
|
||||
|
||||
if not isinstance(queue_max_size, int) or not isinstance(flush_time, int):
|
||||
|
@ -123,16 +129,12 @@ class SummaryRecord:
|
|||
except Exception as ex:
|
||||
raise RuntimeError(ex)
|
||||
|
||||
def _init_event_writer(self):
|
||||
"""Init event writer and write metadata."""
|
||||
event_writer = EventWriter(self.full_file_name, self.flush_time)
|
||||
event_writer.write(package_init_event().SerializeToString())
|
||||
return event_writer
|
||||
self._event_writer = WriterPool(log_dir,
|
||||
summary=self.full_file_name,
|
||||
lineage=get_event_file_name('events', '_lineage'))
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter the context manager."""
|
||||
if not self._event_writer:
|
||||
self._event_writer = self._init_event_writer()
|
||||
if self._closed:
|
||||
raise ValueError('SummaryRecord has been closed.')
|
||||
return self
|
||||
|
@ -141,6 +143,76 @@ class SummaryRecord:
|
|||
"""Exit the context manager."""
|
||||
self.close()
|
||||
|
||||
def set_mode(self, mode):
|
||||
"""
|
||||
Set the mode for the recorder to be aware. The mode is set 'train' by default.
|
||||
|
||||
Args:
|
||||
mode (str): The mode to set, which should be 'train' or 'eval'.
|
||||
|
||||
Raises:
|
||||
ValueError: When the mode is not recognized.
|
||||
|
||||
Examples:
|
||||
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
|
||||
>>> summary_record.set_mode('eval')
|
||||
"""
|
||||
mode_spec = 'train', 'eval'
|
||||
if mode not in mode_spec:
|
||||
raise ValueError(f'{repr(mode)} is not a recognized mode.')
|
||||
self._mode = mode
|
||||
|
||||
def add_value(self, plugin, name, value):
|
||||
"""
|
||||
Add value to be record later on.
|
||||
|
||||
When the plugin is 'tensor', 'scalar', 'image' or 'histogram',
|
||||
the name should be the tag name, and the value should be a Tensor.
|
||||
|
||||
When the plugin plugin is 'graph', the value should be a GraphProto.
|
||||
|
||||
When the plugin 'dataset_graph', 'train_lineage', 'eval_lineage',
|
||||
or 'custom_lineage_data', the value should be a proto message.
|
||||
|
||||
|
||||
Args:
|
||||
plugin (str): The plugin for the value.
|
||||
name (str): The name for the value.
|
||||
value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
|
||||
The value to store.
|
||||
|
||||
- GraphProto: The 'value' should be a serialized string this type when the plugin is 'graph'.
|
||||
- Tensor: The 'value' should be this type when the plugin is 'scalar', 'image', 'tensor' or 'histogram'.
|
||||
- TrainLineage: The 'value' should be this type when the plugin is 'train_lineage'.
|
||||
- EvaluationLineage: The 'value' should be this type when the plugin is 'eval_lineage'.
|
||||
- DatasetGraph: The 'value' should be this type when the plugin is 'dataset_graph'.
|
||||
- UserDefinedInfo: The 'value' should be this type when the plugin is 'custom_lineage_data'.
|
||||
|
||||
Raises:
|
||||
ValueError: When the name is not valid.
|
||||
TypeError: When the value is not a Tensor.
|
||||
|
||||
Examples:
|
||||
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
|
||||
>>> summary_record.add_value('scalar', 'loss', Tensor(0.1))
|
||||
"""
|
||||
if plugin in ('tensor', 'scalar', 'image', 'histogram'):
|
||||
if not name or not isinstance(name, str):
|
||||
raise ValueError(f'{repr(name)} is not a valid tag name.')
|
||||
if not isinstance(value, Tensor):
|
||||
raise TypeError(f'Expect the value to be Tensor, but got {type(value).__name__}')
|
||||
np_value = _check_to_numpy(plugin, value)
|
||||
self._data_pool[plugin].append(dict(tag=name, mode=self._mode, value=np_value))
|
||||
|
||||
elif plugin in ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data'):
|
||||
_check_lineage_value(plugin, value)
|
||||
self._data_pool[plugin].append(dict(mode=self._mode, value=value.SerializeToString()))
|
||||
elif plugin == 'graph':
|
||||
package_graph_event(value)
|
||||
self._data_pool[plugin].append(dict(mode=self._mode, value=value))
|
||||
else:
|
||||
raise ValueError(f'No such plugin of {repr(plugin)}')
|
||||
|
||||
def record(self, step, train_network=None):
|
||||
"""
|
||||
Record the summary.
|
||||
|
@ -149,12 +221,12 @@ class SummaryRecord:
|
|||
step (int): Represents training step number.
|
||||
train_network (Cell): The network that called the callback.
|
||||
|
||||
Returns:
|
||||
bool, whether the record process is successful or not.
|
||||
|
||||
Examples:
|
||||
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
|
||||
>>> summary_record.record(step=2)
|
||||
|
||||
Returns:
|
||||
bool, whether the record process is successful or not.
|
||||
"""
|
||||
logger.info("SummaryRecord step is %r.", step)
|
||||
if self._closed:
|
||||
|
@ -163,10 +235,6 @@ class SummaryRecord:
|
|||
if not isinstance(step, int) or isinstance(step, bool):
|
||||
raise ValueError("`step` should be int")
|
||||
# Set the current summary of train step
|
||||
if not self._event_writer:
|
||||
self._event_writer = self._init_event_writer()
|
||||
logger.warning('SummaryRecord should be used as context manager for a with statement.')
|
||||
|
||||
if self.network is not None and not self.has_graph:
|
||||
graph_proto = self.network.get_func_graph_proto()
|
||||
if graph_proto is None and train_network is not None:
|
||||
|
@ -174,39 +242,48 @@ class SummaryRecord:
|
|||
if graph_proto is None:
|
||||
logger.error("Failed to get proto for graph")
|
||||
else:
|
||||
self._event_writer.write(package_graph_event(graph_proto).SerializeToString())
|
||||
self._event_writer.write({'graph': [{'step': step, 'value': graph_proto}]})
|
||||
self.has_graph = True
|
||||
if not _summary_tensor_cache:
|
||||
return True
|
||||
|
||||
data = _get_summary_tensor_data()
|
||||
if not data:
|
||||
logger.info("The step(%r) does not have record data.", step)
|
||||
return False
|
||||
if self.queue_max_size > 0 and len(data) > self.queue_max_size:
|
||||
logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data),
|
||||
self.queue_max_size)
|
||||
if self._mode == 'train':
|
||||
self._add_summary_tensor_data()
|
||||
|
||||
# process the data
|
||||
result = self._data_convert(data)
|
||||
if not result:
|
||||
logger.error("The step(%r) summary data is invalid.", step)
|
||||
return False
|
||||
self._event_writer.write((result, step))
|
||||
logger.debug("Send the summary data to scheduler for saving, step = %d", step)
|
||||
self._event_writer.write(self._consume_data_pool(step))
|
||||
return True
|
||||
|
||||
def _add_summary_tensor_data(self):
|
||||
summary_data = _get_summary_tensor_data()
|
||||
if not summary_data:
|
||||
logger.debug(f'No summary data bubbled from the network.')
|
||||
for name, tensor in summary_data.items():
|
||||
tag, plugin = SummaryRecord._parse_from(name)
|
||||
if (tag, plugin) == (None, None):
|
||||
logger.warning("The name(%r) is invalid, expected 'TAG[:TYPE]'.", name)
|
||||
else:
|
||||
self.add_value(plugin.lower(), tag, tensor)
|
||||
|
||||
def _consume_data_pool(self, step):
|
||||
try:
|
||||
for values in self._data_pool.values():
|
||||
for value in values:
|
||||
value['step'] = step
|
||||
return self._data_pool
|
||||
finally:
|
||||
self._data_pool = _dictlist()
|
||||
|
||||
@property
|
||||
def log_dir(self):
|
||||
"""
|
||||
Get the full path of the log file.
|
||||
|
||||
Returns:
|
||||
str, the full path of log file.
|
||||
|
||||
Examples:
|
||||
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
|
||||
>>> print(summary_record.log_dir)
|
||||
|
||||
Returns:
|
||||
String, the full path of log file.
|
||||
"""
|
||||
return self.full_file_name
|
||||
|
||||
|
@ -235,46 +312,19 @@ class SummaryRecord:
|
|||
"""
|
||||
if not self._closed and self._event_writer:
|
||||
# event writer flush and close
|
||||
logger.info('Please wait it may take quite some time to finish writing and closing.')
|
||||
self._event_writer.close()
|
||||
self._closed = True
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.close()
|
||||
|
||||
def _data_convert(self, summary):
|
||||
"""Convert the data."""
|
||||
# convert the summary to numpy
|
||||
result = []
|
||||
for name, data in summary.items():
|
||||
# confirm the data is valid
|
||||
summary_tag, summary_type = SummaryRecord._parse_from(name)
|
||||
if summary_tag is None:
|
||||
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
|
||||
return None
|
||||
if isinstance(data, Tensor):
|
||||
result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type})
|
||||
else:
|
||||
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _parse_from(name: str = None):
|
||||
"""
|
||||
Parse the tag and type from name.
|
||||
|
||||
Args:
|
||||
name (str): Format: TAG[:TYPE].
|
||||
|
||||
Returns:
|
||||
Tuple, (summary_tag, summary_type).
|
||||
"""
|
||||
if name is None:
|
||||
logger.error("The name is None")
|
||||
"""Parse the tag and type from name."""
|
||||
if not isinstance(name, str):
|
||||
return None, None
|
||||
match = re.match(r'(.+)\[:(.+)\]', name)
|
||||
if match:
|
||||
return match.groups()
|
||||
logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name)
|
||||
return None, None
|
||||
|
|
|
@ -84,21 +84,6 @@ def test_histogram_multi_summary():
|
|||
event = reader.read_event()
|
||||
assert event.summary.value[0].histogram.count == size
|
||||
|
||||
|
||||
def test_histogram_summary_scalar_tensor():
|
||||
"""Test histogram summary, input is a scalar tensor."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
|
||||
test_data = _wrap_test_data(Tensor(1))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
|
||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||
with SummaryReader(file_name) as reader:
|
||||
event = reader.read_event()
|
||||
assert event.summary.value[0].histogram.count == 1
|
||||
|
||||
|
||||
def test_histogram_summary_empty_tensor():
|
||||
"""Test histogram summary, input is an empty tensor."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
|
Loading…
Reference in New Issue