diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 88f88d49e93..1d6dca494b2 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -79,6 +79,7 @@ if (ENABLE_DUMP_PROTO) file(GLOB_RECURSE PROTO_PY RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "utils/anf_ir.proto" "utils/summary.proto" + "utils/lineage.proto" "utils/checkpoint.proto" ) ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY}) diff --git a/mindspore/ccsrc/utils/lineage.proto b/mindspore/ccsrc/utils/lineage.proto new file mode 100644 index 00000000000..510e58fc553 --- /dev/null +++ b/mindspore/ccsrc/utils/lineage.proto @@ -0,0 +1,129 @@ +// Copyright 2020 Huawei Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mindspore.irpb; +option cc_enable_arenas = true; + + +// Event Protocol buffer, Top define +message LineageEvent { + // Timestamp + required double wall_time = 1; + + // The step of train. + optional int64 step = 2; + + oneof what { + // An event file was started, with the specified version. + // Now version is "Mindspore.Event:1" + string version = 3; + + // Train lineage + TrainLineage train_lineage = 6; + + // Evaluation lineage + EvaluationLineage evaluation_lineage = 7; + + // Dataset graph + DatasetGraph dataset_graph = 9; + + // User defined info + UserDefinedInfo user_defined_info = 10; + } +} + +// User defined info +message UserDefinedInfo{ + // repeated user defined info + repeated UserDefinedInfo user_info = 1; + + // key/value which contains both scalar and dict + map map_dict = 2; + map map_int32 = 3; + map map_str = 4; + map map_double = 5; +} + +// TrainLineage records infos of a train. +message TrainLineage{ + message HyperParameters{ + optional string optimizer = 1; + optional float learning_rate = 2; + optional string loss_function = 3; + optional int32 epoch = 4; + optional string parallel_mode = 5; + optional int32 device_num = 6; + optional int32 batch_size = 8; + } + + message TrainDataset{ + optional string train_dataset_path = 1; + optional int32 train_dataset_size = 2; + } + + message Algorithm{ + optional string network = 1; + optional float loss = 2; + } + + message Model{ + optional string path = 3; + optional int64 size = 4; + } + + optional HyperParameters hyper_parameters = 1; + optional TrainDataset train_dataset = 2; + optional Algorithm algorithm = 3; + optional Model model = 4; +} + +//EvalLineage records infos of evaluation. +message EvaluationLineage{ + message ValidDataset{ + optional string valid_dataset_path = 1; + optional int32 valid_dataset_size = 2; + } + + optional string metric = 2; + optional ValidDataset valid_dataset = 3; +} + + +// DatasetGraph +message DatasetGraph { + repeated DatasetGraph children = 1; + optional OperationParameter parameter = 2; + repeated Operation operations = 3; + optional Operation sampler = 4; +} + +message Operation { + optional OperationParameter operationParam = 1; + repeated int32 size = 2; + repeated float weights = 3; +} + +message OperationParameter{ + map mapStr = 1; + map mapStrList = 2; + map mapBool = 3; + map mapInt = 4; + map mapDouble = 5; +} + +message StrList { + repeated string strValue = 1; +} diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 958ea7e2c2e..2e2a87758d5 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -22,6 +22,7 @@ from mindspore import log as logger from mindspore.common.api import _executor from mindspore.common.dtype import pytype_to_dtype +from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo def _convert_type(types): """ @@ -196,3 +197,38 @@ def _to_full_shapes(shapes, device_num): new_shape += (item,) new_shapes.append(new_shape) return new_shapes + + +def _check_to_numpy(plugin, tensor): + """Check the tensor and return a numpy.ndarray.""" + np_value = tensor.asnumpy() + if plugin == 'scalar': + if np_value.size == 1: + return np_value + raise ValueError('The tensor holds more than one value, but the scalar plugin expects on value.') + if plugin == 'image': + if np_value.ndim == 4: + return np_value + raise ValueError('The tensor seems not to hold a valid image.') + if plugin in ('tensor', 'histogram'): + if np_value.ndim > 0: + return np_value + raise ValueError('The tensor should not be empty.') + return np_value + +def _check_lineage_value(plugin, value): + """Check the lineage value.""" + def raises(plugin, prototype): + raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.') + + if plugin == 'dataset_graph' and not isinstance(value, DatasetGraph): + raises(plugin, DatasetGraph) + + if plugin == 'eval_lineage' and not isinstance(value, EvaluationLineage): + raises(plugin, EvaluationLineage) + + if plugin == 'train_lineage' and not isinstance(value, TrainLineage): + raises(plugin, TrainLineage) + + if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo): + raises(plugin, UserDefinedInfo) diff --git a/mindspore/train/summary/_event_writer.py b/mindspore/train/summary/_event_writer.py deleted file mode 100644 index 0a21dea04ec..00000000000 --- a/mindspore/train/summary/_event_writer.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Writes events to disk in a logdir.""" -import os -import stat -from collections import deque -from multiprocessing import Pool, Process, Queue, cpu_count - -from ..._c_expression import EventWriter_ -from ._summary_adapter import package_summary_event - - -def _pack(result, step): - summary_event = package_summary_event(result, step) - return summary_event.SerializeToString() - - -class EventWriter(Process): - """ - Creates a `EventWriter` and write event to file. - - Args: - filepath (str): Summary event file path and file name. - flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120. - """ - - def __init__(self, filepath: str, flush_interval: int) -> None: - super().__init__() - _ = flush_interval - with open(filepath, 'w'): - os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR) - self._writer = EventWriter_(filepath) - self._queue = Queue(cpu_count() * 2) - self.start() - - def run(self): - - with Pool(min(cpu_count(), 32)) as pool: - deq = deque() - while True: - while deq and deq[0].ready(): - self._writer.Write(deq.popleft().get()) - - if not self._queue.empty(): - action, data = self._queue.get() - if action == 'WRITE': - if not isinstance(data, (str, bytes)): - deq.append(pool.apply_async(_pack, data)) - else: - self._writer.Write(data) - elif action == 'FLUSH': - self._writer.Flush() - elif action == 'END': - break - for res in deq: - self._writer.Write(res.get()) - - self._writer.Shut() - - def write(self, data) -> None: - """ - Write the event to file. - - Args: - data (Optional[str, Tuple[list, int]]): The data to write. - """ - self._queue.put(('WRITE', data)) - - def flush(self): - """Flush the writer.""" - self._queue.put(('FLUSH', None)) - - def close(self) -> None: - """Close the writer.""" - self._queue.put(('END', None)) - self.join() diff --git a/mindspore/train/summary/_lineage_adapter.py b/mindspore/train/summary/_lineage_adapter.py new file mode 100644 index 00000000000..d85d16b49db --- /dev/null +++ b/mindspore/train/summary/_lineage_adapter.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Generate the lineage event which conform to proto format.""" +import time + +from ..lineage_pb2 import LineageEvent + + +def serialize_to_lineage_event(name, value): + """Serialize value to lineage event.""" + event = LineageEvent() + event.wall_time = time.time() + content = _get_lineage_content(name, event) + content.ParseFromString(value) + return event.SerializeToString() + + +def _get_lineage_content(name, event): + if name == 'dataset_graph': + return event.dataset_graph + if name == 'eval_lineage': + return event.evaluation_lineage + if name == 'train_lineage': + return event.train_lineage + if name == 'custom_lineage_data': + return event.user_defined_info + raise KeyError(f'No such field in LineageEvent') diff --git a/mindspore/train/summary/_summary_adapter.py b/mindspore/train/summary/_summary_adapter.py index fc4a2302bd7..40e32b1c6ad 100644 --- a/mindspore/train/summary/_summary_adapter.py +++ b/mindspore/train/summary/_summary_adapter.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """Generate the summary event which conform to proto format.""" -import socket +import platform import time import numpy as np @@ -51,7 +51,7 @@ def get_event_file_name(prefix, suffix): _check_str_by_regular(suffix) file_name = "" time_second = str(int(time.time())) - hostname = socket.gethostname() + hostname = platform.node() if prefix is not None: file_name = file_name + prefix diff --git a/mindspore/train/summary/_summary_writer.py b/mindspore/train/summary/_summary_writer.py new file mode 100644 index 00000000000..36d020819a5 --- /dev/null +++ b/mindspore/train/summary/_summary_writer.py @@ -0,0 +1,79 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Writes events to disk in a logdir.""" +import os +import stat + +from ..._c_expression import EventWriter_ +from ._summary_adapter import package_init_event + + +class BaseWriter: + """BaseWriter to be subclass.""" + + def __init__(self, filepath) -> None: + self._filepath = filepath + self._writer: EventWriter_ = None + + def init_writer(self): + """Write some metadata etc.""" + + @property + def writer(self) -> EventWriter_: + """Get the writer.""" + if self._writer is not None: + return self._writer + + with open(self._filepath, 'w'): + os.chmod(self._filepath, stat.S_IWUSR | stat.S_IRUSR) + self._writer = EventWriter_(self._filepath) + self.init_writer() + return self._writer + + def write(self, plugin, mode, data): + """Write data to file.""" + raise NotImplementedError() + + def flush(self): + """Flush the writer.""" + if self._writer is not None: + self._writer.Flush() + + def close(self): + """Close the writer.""" + if self._writer is not None: + self._writer.Shut() + + +class SummaryWriter(BaseWriter): + """SummaryWriter for write summaries.""" + + def init_writer(self): + """Write some metadata etc.""" + self.writer.Write(package_init_event().SerializeToString()) + + def write(self, plugin, mode, data): + """Write data to file.""" + if plugin in ('summary', 'graph'): + self.writer.Write(data) + + +class LineageWriter(BaseWriter): + """LineageWriter for write lineage.""" + + def write(self, plugin, mode, data): + """Write data to file.""" + if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'): + self.writer.Write(data) diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py new file mode 100644 index 00000000000..2d219743dea --- /dev/null +++ b/mindspore/train/summary/_writer_pool.py @@ -0,0 +1,114 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Write events to disk in a base directory.""" +import os +from collections import deque +from multiprocessing import Pool, Process, Queue, cpu_count + +from ._lineage_adapter import serialize_to_lineage_event +from ._summary_adapter import package_graph_event, package_summary_event +from ._summary_writer import SummaryWriter, LineageWriter + + +def _pack_data(datadict): + """Pack data according to which plugin.""" + result = [] + summaries, step, mode = [], None, None + for plugin, datalist in datadict.items(): + for data in datalist: + if plugin == 'graph': + result.append([plugin, data.get('mode'), package_graph_event(data.get('value')).SerializeToString()]) + elif plugin in ('train_lineage', 'eval_lineage', 'custom_lineage_data', 'dataset_graph'): + result.append([plugin, data.get('mode'), serialize_to_lineage_event(plugin, data.get('value'))]) + elif plugin in ('scalar', 'tensor', 'histogram', 'image'): + summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')}) + step = data.get('step') + mode = data.get('mode') + if summaries: + result.append(['summary', mode, package_summary_event(summaries, step).SerializeToString()]) + return result + + +class WriterPool(Process): + """ + Use a set of pooled resident processes for writing a list of file. + + Args: + base_dir (str): The base directory to hold all the files. + filelist (str): The mapping from short name to long filename. + """ + + def __init__(self, base_dir, **filedict) -> None: + super().__init__() + self._base_dir, self._filedict = base_dir, filedict + self._queue = Queue(cpu_count() * 2) + self.start() + + def run(self): + writers = self._get_writers() + + with Pool() as pool: + deq = deque() + while True: + while deq and deq[0].ready(): + for plugin, mode, data in deq.popleft().get(): + for writer in writers: + writer.write(plugin, mode, data) + + if not self._queue.empty(): + action, data = self._queue.get() + if action == 'WRITE': + deq.append(pool.apply_async(_pack_data, (data,))) + elif action == 'FLUSH': + for writer in writers: + writer.flush() + elif action == 'END': + break + for result in deq: + for plugin, mode, data in result.get(): + for writer in writers: + writer.write(plugin, mode, data) + + for writer in writers: + writer.close() + + def _get_writers(self): + writers = [] + for plugin, filename in self._filedict.items(): + filepath = os.path.join(self._base_dir, filename) + if plugin == 'summary': + writers.append(SummaryWriter(filepath)) + elif plugin == 'lineage': + writers.append(LineageWriter(filepath)) + return writers + + def write(self, data) -> None: + """ + Write the event to file. + + Args: + name (str): The key of a specified file. + data (Optional[str, Tuple[list, int]]): The data to write. + """ + self._queue.put(('WRITE', data)) + + def flush(self): + """Flush the writer and sync data to disk.""" + self._queue.put(('FLUSH', None)) + + def close(self) -> None: + """Close the writer.""" + self._queue.put(('END', None)) + self.join() diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index 661eb1d810d..61c2c8adebe 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -21,9 +21,9 @@ from mindspore import log as logger from ..._c_expression import Tensor from ..._checkparam import _check_str_by_regular -from .._utils import _make_directory -from ._event_writer import EventWriter -from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event +from .._utils import _make_directory, _check_to_numpy, _check_lineage_value +from ._summary_adapter import get_event_file_name, package_graph_event +from ._writer_pool import WriterPool # for the moment, this lock is for caution's sake, # there are actually no any concurrencies happening. @@ -53,16 +53,20 @@ def _get_summary_tensor_data(): return data +def _dictlist(): + from collections import defaultdict + return defaultdict(list) + + class SummaryRecord: """ - SummaryRecord is used to record the summary value. + SummaryRecord is used to record the summary data and lineage data. Note: - The API will create an event file in a given directory and add summaries and events to it. - It writes the event log to a file by executing the record method. In addition, - if the SummaryRecord object is created and the summary operator is used in the network, - even if the record method is not called, the event in the cache will be written to the - file at the end of execution. Make sure to close the SummaryRecord object at the end. + The API will create a summary file and a lineage file lazily in a given directory and writes data to them. + It writes the data to files by executing the record method. In addition to record the data bubbled up from + the network by defining the summary operators, SummaryRecord also supports to record extra data which + can be added by calling add_value. Finally, make sure to close the SummaryRecord object at the end. Args: log_dir (str): The log_dir is a directory location to save the summary. @@ -89,10 +93,12 @@ class SummaryRecord: file_suffix="_MS", network=None): - self._event_writer, self._closed = None, False + self._closed, self._mode = False, 'train' + self._data_pool = _dictlist() _check_str_by_regular(file_prefix) _check_str_by_regular(file_suffix) + self.log_path = _make_directory(log_dir) if not isinstance(queue_max_size, int) or not isinstance(flush_time, int): @@ -123,16 +129,12 @@ class SummaryRecord: except Exception as ex: raise RuntimeError(ex) - def _init_event_writer(self): - """Init event writer and write metadata.""" - event_writer = EventWriter(self.full_file_name, self.flush_time) - event_writer.write(package_init_event().SerializeToString()) - return event_writer + self._event_writer = WriterPool(log_dir, + summary=self.full_file_name, + lineage=get_event_file_name('events', '_lineage')) def __enter__(self): """Enter the context manager.""" - if not self._event_writer: - self._event_writer = self._init_event_writer() if self._closed: raise ValueError('SummaryRecord has been closed.') return self @@ -141,6 +143,76 @@ class SummaryRecord: """Exit the context manager.""" self.close() + def set_mode(self, mode): + """ + Set the mode for the recorder to be aware. The mode is set 'train' by default. + + Args: + mode (str): The mode to set, which should be 'train' or 'eval'. + + Raises: + ValueError: When the mode is not recognized. + + Examples: + >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: + >>> summary_record.set_mode('eval') + """ + mode_spec = 'train', 'eval' + if mode not in mode_spec: + raise ValueError(f'{repr(mode)} is not a recognized mode.') + self._mode = mode + + def add_value(self, plugin, name, value): + """ + Add value to be record later on. + + When the plugin is 'tensor', 'scalar', 'image' or 'histogram', + the name should be the tag name, and the value should be a Tensor. + + When the plugin plugin is 'graph', the value should be a GraphProto. + + When the plugin 'dataset_graph', 'train_lineage', 'eval_lineage', + or 'custom_lineage_data', the value should be a proto message. + + + Args: + plugin (str): The plugin for the value. + name (str): The name for the value. + value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \ + The value to store. + + - GraphProto: The 'value' should be a serialized string this type when the plugin is 'graph'. + - Tensor: The 'value' should be this type when the plugin is 'scalar', 'image', 'tensor' or 'histogram'. + - TrainLineage: The 'value' should be this type when the plugin is 'train_lineage'. + - EvaluationLineage: The 'value' should be this type when the plugin is 'eval_lineage'. + - DatasetGraph: The 'value' should be this type when the plugin is 'dataset_graph'. + - UserDefinedInfo: The 'value' should be this type when the plugin is 'custom_lineage_data'. + + Raises: + ValueError: When the name is not valid. + TypeError: When the value is not a Tensor. + + Examples: + >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: + >>> summary_record.add_value('scalar', 'loss', Tensor(0.1)) + """ + if plugin in ('tensor', 'scalar', 'image', 'histogram'): + if not name or not isinstance(name, str): + raise ValueError(f'{repr(name)} is not a valid tag name.') + if not isinstance(value, Tensor): + raise TypeError(f'Expect the value to be Tensor, but got {type(value).__name__}') + np_value = _check_to_numpy(plugin, value) + self._data_pool[plugin].append(dict(tag=name, mode=self._mode, value=np_value)) + + elif plugin in ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data'): + _check_lineage_value(plugin, value) + self._data_pool[plugin].append(dict(mode=self._mode, value=value.SerializeToString())) + elif plugin == 'graph': + package_graph_event(value) + self._data_pool[plugin].append(dict(mode=self._mode, value=value)) + else: + raise ValueError(f'No such plugin of {repr(plugin)}') + def record(self, step, train_network=None): """ Record the summary. @@ -149,12 +221,12 @@ class SummaryRecord: step (int): Represents training step number. train_network (Cell): The network that called the callback. + Returns: + bool, whether the record process is successful or not. + Examples: >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: >>> summary_record.record(step=2) - - Returns: - bool, whether the record process is successful or not. """ logger.info("SummaryRecord step is %r.", step) if self._closed: @@ -163,10 +235,6 @@ class SummaryRecord: if not isinstance(step, int) or isinstance(step, bool): raise ValueError("`step` should be int") # Set the current summary of train step - if not self._event_writer: - self._event_writer = self._init_event_writer() - logger.warning('SummaryRecord should be used as context manager for a with statement.') - if self.network is not None and not self.has_graph: graph_proto = self.network.get_func_graph_proto() if graph_proto is None and train_network is not None: @@ -174,39 +242,48 @@ class SummaryRecord: if graph_proto is None: logger.error("Failed to get proto for graph") else: - self._event_writer.write(package_graph_event(graph_proto).SerializeToString()) + self._event_writer.write({'graph': [{'step': step, 'value': graph_proto}]}) self.has_graph = True if not _summary_tensor_cache: return True - data = _get_summary_tensor_data() - if not data: - logger.info("The step(%r) does not have record data.", step) - return False - if self.queue_max_size > 0 and len(data) > self.queue_max_size: - logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data), - self.queue_max_size) + if self._mode == 'train': + self._add_summary_tensor_data() - # process the data - result = self._data_convert(data) - if not result: - logger.error("The step(%r) summary data is invalid.", step) - return False - self._event_writer.write((result, step)) - logger.debug("Send the summary data to scheduler for saving, step = %d", step) + self._event_writer.write(self._consume_data_pool(step)) return True + def _add_summary_tensor_data(self): + summary_data = _get_summary_tensor_data() + if not summary_data: + logger.debug(f'No summary data bubbled from the network.') + for name, tensor in summary_data.items(): + tag, plugin = SummaryRecord._parse_from(name) + if (tag, plugin) == (None, None): + logger.warning("The name(%r) is invalid, expected 'TAG[:TYPE]'.", name) + else: + self.add_value(plugin.lower(), tag, tensor) + + def _consume_data_pool(self, step): + try: + for values in self._data_pool.values(): + for value in values: + value['step'] = step + return self._data_pool + finally: + self._data_pool = _dictlist() + @property def log_dir(self): """ Get the full path of the log file. + Returns: + str, the full path of log file. + Examples: >>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record: >>> print(summary_record.log_dir) - - Returns: - String, the full path of log file. """ return self.full_file_name @@ -235,46 +312,19 @@ class SummaryRecord: """ if not self._closed and self._event_writer: # event writer flush and close + logger.info('Please wait it may take quite some time to finish writing and closing.') self._event_writer.close() self._closed = True def __del__(self) -> None: self.close() - def _data_convert(self, summary): - """Convert the data.""" - # convert the summary to numpy - result = [] - for name, data in summary.items(): - # confirm the data is valid - summary_tag, summary_type = SummaryRecord._parse_from(name) - if summary_tag is None: - logger.error("The data type is invalid, name = %r, tensor = %r", name, data) - return None - if isinstance(data, Tensor): - result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type}) - else: - logger.error("The data type is invalid, name = %r, tensor = %r", name, data) - return None - - return result - @staticmethod def _parse_from(name: str = None): - """ - Parse the tag and type from name. - - Args: - name (str): Format: TAG[:TYPE]. - - Returns: - Tuple, (summary_tag, summary_type). - """ - if name is None: - logger.error("The name is None") + """Parse the tag and type from name.""" + if not isinstance(name, str): return None, None match = re.match(r'(.+)\[:(.+)\]', name) if match: return match.groups() - logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name) return None, None diff --git a/tests/ut/python/train/summary/test_histogram_summary.py b/tests/ut/python/train/summary/test_histogram_summary.py index e304146a2ea..2d5a175f70a 100644 --- a/tests/ut/python/train/summary/test_histogram_summary.py +++ b/tests/ut/python/train/summary/test_histogram_summary.py @@ -84,21 +84,6 @@ def test_histogram_multi_summary(): event = reader.read_event() assert event.summary.value[0].histogram.count == size - -def test_histogram_summary_scalar_tensor(): - """Test histogram summary, input is a scalar tensor.""" - with tempfile.TemporaryDirectory() as tmp_dir: - with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer: - test_data = _wrap_test_data(Tensor(1)) - _cache_summary_tensor_data(test_data) - test_writer.record(step=1) - - file_name = os.path.join(tmp_dir, test_writer.event_file_name) - with SummaryReader(file_name) as reader: - event = reader.read_event() - assert event.summary.value[0].histogram.count == 1 - - def test_histogram_summary_empty_tensor(): """Test histogram summary, input is an empty tensor.""" with tempfile.TemporaryDirectory() as tmp_dir: