enhance the SummaryRecord with set_mode and add_value

Li Hongzhang 2020-06-05 14:50:07 +08:00
parent 19e66f06e2
commit 0921c1e538
10 changed files with 522 additions and 177 deletions

View File

@@ -79,6 +79,7 @@ if (ENABLE_DUMP_PROTO)
file(GLOB_RECURSE PROTO_PY RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"utils/anf_ir.proto"
"utils/summary.proto"
"utils/lineage.proto"
"utils/checkpoint.proto"
)
ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY})

View File

@@ -0,0 +1,129 @@
// Copyright 2020 Huawei Technologies Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mindspore.irpb;
option cc_enable_arenas = true;
// Event Protocol buffer, Top define
message LineageEvent {
// Timestamp
required double wall_time = 1;
// The step of train.
optional int64 step = 2;
oneof what {
// An event file was started, with the specified version.
// Now version is "Mindspore.Event:1"
string version = 3;
// Train lineage
TrainLineage train_lineage = 6;
// Evaluation lineage
EvaluationLineage evaluation_lineage = 7;
// Dataset graph
DatasetGraph dataset_graph = 9;
// User defined info
UserDefinedInfo user_defined_info = 10;
}
}
// User defined info
message UserDefinedInfo{
// repeated user defined info
repeated UserDefinedInfo user_info = 1;
// key/value which contains both scalar and dict
map<string, UserDefinedInfo> map_dict = 2;
map<string, int32> map_int32 = 3;
map<string, string> map_str = 4;
map<string, double> map_double = 5;
}
// TrainLineage records the lineage info of a training run.
message TrainLineage{
message HyperParameters{
optional string optimizer = 1;
optional float learning_rate = 2;
optional string loss_function = 3;
optional int32 epoch = 4;
optional string parallel_mode = 5;
optional int32 device_num = 6;
optional int32 batch_size = 8;
}
message TrainDataset{
optional string train_dataset_path = 1;
optional int32 train_dataset_size = 2;
}
message Algorithm{
optional string network = 1;
optional float loss = 2;
}
message Model{
optional string path = 3;
optional int64 size = 4;
}
optional HyperParameters hyper_parameters = 1;
optional TrainDataset train_dataset = 2;
optional Algorithm algorithm = 3;
optional Model model = 4;
}
// EvaluationLineage records the lineage info of an evaluation.
message EvaluationLineage{
message ValidDataset{
optional string valid_dataset_path = 1;
optional int32 valid_dataset_size = 2;
}
optional string metric = 2;
optional ValidDataset valid_dataset = 3;
}
// DatasetGraph
message DatasetGraph {
repeated DatasetGraph children = 1;
optional OperationParameter parameter = 2;
repeated Operation operations = 3;
optional Operation sampler = 4;
}
message Operation {
optional OperationParameter operationParam = 1;
repeated int32 size = 2;
repeated float weights = 3;
}
message OperationParameter{
map<string, string> mapStr = 1;
map<string, StrList> mapStrList = 2;
map<string, bool> mapBool = 3;
map<string, int32> mapInt = 4;
map<string, double> mapDouble = 5;
}
message StrList {
repeated string strValue = 1;
}
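For reference, a minimal sketch of how these messages can be populated and serialized from Python, assuming a lineage_pb2 module has been generated from this proto by protoc (the module name and all field values below are illustrative):

# Sketch: building lineage messages (assumes a generated `lineage_pb2`).
import time
from lineage_pb2 import LineageEvent, TrainLineage, UserDefinedInfo

train = TrainLineage()
train.hyper_parameters.optimizer = 'Momentum'  # sub-messages are created on first access
train.hyper_parameters.learning_rate = 0.01
train.algorithm.network = 'LeNet5'

event = LineageEvent()
event.wall_time = time.time()        # the only required field
event.train_lineage.CopyFrom(train)  # selects the 'what' oneof branch
serialized = event.SerializeToString()

custom = UserDefinedInfo()
custom.map_str['version'] = 'v1'      # scalar string entry
custom.map_double['accuracy'] = 0.98  # scalar double entry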

View File

@@ -22,6 +22,7 @@ from mindspore import log as logger
from mindspore.common.api import _executor
from mindspore.common.dtype import pytype_to_dtype
from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
def _convert_type(types):
"""
@@ -196,3 +197,38 @@ def _to_full_shapes(shapes, device_num):
new_shape += (item,)
new_shapes.append(new_shape)
return new_shapes
def _check_to_numpy(plugin, tensor):
"""Check the tensor and return a numpy.ndarray."""
np_value = tensor.asnumpy()
if plugin == 'scalar':
if np_value.size == 1:
return np_value
raise ValueError('The tensor holds more than one value, but the scalar plugin expects one value.')
if plugin == 'image':
if np_value.ndim == 4:
return np_value
raise ValueError('The tensor does not hold a valid image, a 4-dimensional array is expected.')
if plugin in ('tensor', 'histogram'):
if np_value.ndim > 0:
return np_value
raise ValueError('The tensor should not be empty.')
return np_value
def _check_lineage_value(plugin, value):
"""Check the lineage value."""
def raises(plugin, prototype):
raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.')
if plugin == 'dataset_graph' and not isinstance(value, DatasetGraph):
raises(plugin, DatasetGraph)
if plugin == 'eval_lineage' and not isinstance(value, EvaluationLineage):
raises(plugin, EvaluationLineage)
if plugin == 'train_lineage' and not isinstance(value, TrainLineage):
raises(plugin, TrainLineage)
if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo):
raises(plugin, UserDefinedInfo)
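A quick sketch of the contracts these helpers enforce (the import path is an assumption and a MindSpore environment is required):

# Sketch: what _check_to_numpy accepts and rejects.
import numpy as np
from mindspore import Tensor
from mindspore.train._utils import _check_to_numpy  # path assumed

print(_check_to_numpy('scalar', Tensor(np.array(0.1, np.float32))))  # ok: exactly one value
print(_check_to_numpy('image', Tensor(np.zeros((1, 3, 8, 8), np.float32))).shape)  # ok: 4-D
try:
    _check_to_numpy('histogram', Tensor(np.array(1.0, np.float32)))  # 0-D is rejected
except ValueError as err:
    print(err)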

View File

@@ -1,88 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Writes events to disk in a logdir."""
import os
import stat
from collections import deque
from multiprocessing import Pool, Process, Queue, cpu_count
from ..._c_expression import EventWriter_
from ._summary_adapter import package_summary_event
def _pack(result, step):
summary_event = package_summary_event(result, step)
return summary_event.SerializeToString()
class EventWriter(Process):
"""
Creates an `EventWriter` and writes events to a file.
Args:
filepath (str): Summary event file path and file name.
flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120.
"""
def __init__(self, filepath: str, flush_interval: int) -> None:
super().__init__()
_ = flush_interval
with open(filepath, 'w'):
os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
self._writer = EventWriter_(filepath)
self._queue = Queue(cpu_count() * 2)
self.start()
def run(self):
with Pool(min(cpu_count(), 32)) as pool:
deq = deque()
while True:
while deq and deq[0].ready():
self._writer.Write(deq.popleft().get())
if not self._queue.empty():
action, data = self._queue.get()
if action == 'WRITE':
if not isinstance(data, (str, bytes)):
deq.append(pool.apply_async(_pack, data))
else:
self._writer.Write(data)
elif action == 'FLUSH':
self._writer.Flush()
elif action == 'END':
break
for res in deq:
self._writer.Write(res.get())
self._writer.Shut()
def write(self, data) -> None:
"""
Write the event to file.
Args:
data (Optional[str, Tuple[list, int]]): The data to write.
"""
self._queue.put(('WRITE', data))
def flush(self):
"""Flush the writer."""
self._queue.put(('FLUSH', None))
def close(self) -> None:
"""Close the writer."""
self._queue.put(('END', None))
self.join()

View File

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate the lineage event which conform to proto format."""
import time
from ..lineage_pb2 import LineageEvent
def serialize_to_lineage_event(name, value):
"""Serialize value to lineage event."""
event = LineageEvent()
event.wall_time = time.time()
content = _get_lineage_content(name, event)
content.ParseFromString(value)
return event.SerializeToString()
def _get_lineage_content(name, event):
if name == 'dataset_graph':
return event.dataset_graph
if name == 'eval_lineage':
return event.evaluation_lineage
if name == 'train_lineage':
return event.train_lineage
if name == 'custom_lineage_data':
return event.user_defined_info
raise KeyError(f'No such field {repr(name)} in LineageEvent.')
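A round-trip sketch of this adapter (import paths are assumptions inferred from the relative imports above):

# Sketch: wrapping a serialized TrainLineage into a LineageEvent.
from mindspore.train.lineage_pb2 import LineageEvent, TrainLineage  # paths assumed
from mindspore.train.summary._lineage_adapter import serialize_to_lineage_event

lineage = TrainLineage()
lineage.algorithm.network = 'LeNet5'

raw = serialize_to_lineage_event('train_lineage', lineage.SerializeToString())

event = LineageEvent()
event.ParseFromString(raw)
assert event.train_lineage.algorithm.network == 'LeNet5'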

View File

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Generate the summary event which conform to proto format."""
import socket
import platform
import time
import numpy as np
@@ -51,7 +51,7 @@ def get_event_file_name(prefix, suffix):
_check_str_by_regular(suffix)
file_name = ""
time_second = str(int(time.time()))
hostname = socket.gethostname()
hostname = platform.node()
if prefix is not None:
file_name = file_name + prefix

View File

@@ -0,0 +1,79 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Writes events to disk in a logdir."""
import os
import stat
from ..._c_expression import EventWriter_
from ._summary_adapter import package_init_event
class BaseWriter:
"""BaseWriter to be subclass."""
def __init__(self, filepath) -> None:
self._filepath = filepath
self._writer: EventWriter_ = None
def init_writer(self):
"""Write some metadata etc."""
@property
def writer(self) -> EventWriter_:
"""Get the writer."""
if self._writer is not None:
return self._writer
with open(self._filepath, 'w'):
os.chmod(self._filepath, stat.S_IWUSR | stat.S_IRUSR)
self._writer = EventWriter_(self._filepath)
self.init_writer()
return self._writer
def write(self, plugin, mode, data):
"""Write data to file."""
raise NotImplementedError()
def flush(self):
"""Flush the writer."""
if self._writer is not None:
self._writer.Flush()
def close(self):
"""Close the writer."""
if self._writer is not None:
self._writer.Shut()
class SummaryWriter(BaseWriter):
"""SummaryWriter for write summaries."""
def init_writer(self):
"""Write some metadata etc."""
self.writer.Write(package_init_event().SerializeToString())
def write(self, plugin, mode, data):
"""Write data to file."""
if plugin in ('summary', 'graph'):
self.writer.Write(data)
class LineageWriter(BaseWriter):
"""LineageWriter for write lineage."""
def write(self, plugin, mode, data):
"""Write data to file."""
if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'):
self.writer.Write(data)
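A minimal usage sketch of the two writers (import path, file paths, and payloads are placeholders; a MindSpore build providing EventWriter_ is required):

# Sketch: each writer accepts only its own plugins and ignores the rest,
# which lets a pool broadcast every record to all writers.
from mindspore.train.summary._summary_writer import SummaryWriter, LineageWriter  # path assumed

writers = [SummaryWriter('/tmp/log/events.summary'),
           LineageWriter('/tmp/log/events.lineage')]
for writer in writers:
    writer.write('summary', 'train', b'<serialized summary event>')        # only SummaryWriter acts
    writer.write('train_lineage', 'train', b'<serialized lineage event>')  # only LineageWriter acts
    writer.flush()
    writer.close()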

View File

@@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Write events to disk in a base directory."""
import os
from collections import deque
from multiprocessing import Pool, Process, Queue, cpu_count
from ._lineage_adapter import serialize_to_lineage_event
from ._summary_adapter import package_graph_event, package_summary_event
from ._summary_writer import SummaryWriter, LineageWriter
def _pack_data(datadict):
"""Pack data according to which plugin."""
result = []
summaries, step, mode = [], None, None
for plugin, datalist in datadict.items():
for data in datalist:
if plugin == 'graph':
result.append([plugin, data.get('mode'), package_graph_event(data.get('value')).SerializeToString()])
elif plugin in ('train_lineage', 'eval_lineage', 'custom_lineage_data', 'dataset_graph'):
result.append([plugin, data.get('mode'), serialize_to_lineage_event(plugin, data.get('value'))])
elif plugin in ('scalar', 'tensor', 'histogram', 'image'):
summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')})
step = data.get('step')
mode = data.get('mode')
if summaries:
result.append(['summary', mode, package_summary_event(summaries, step).SerializeToString()])
return result
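For orientation, a sketch of the datadict this function consumes, mirroring what SummaryRecord.add_value accumulates (all values are illustrative):

# Sketch: input and output shapes of _pack_data.
datadict = {
    'scalar': [{'tag': 'loss', 'mode': 'train', 'step': 1, 'value': 0.1}],
    'train_lineage': [{'mode': 'train', 'step': 1, 'value': b'<TrainLineage bytes>'}],
}
# _pack_data(datadict) -> [[plugin, mode, serialized_bytes], ...]; all
# scalar/tensor/histogram/image entries are folded into a single 'summary' event.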
class WriterPool(Process):
"""
Use a set of pooled resident processes for writing a list of files.
Args:
base_dir (str): The base directory to hold all the files.
filedict (dict): The mapping from plugin name to file name.
"""
def __init__(self, base_dir, **filedict) -> None:
super().__init__()
self._base_dir, self._filedict = base_dir, filedict
self._queue = Queue(cpu_count() * 2)
self.start()
def run(self):
writers = self._get_writers()
with Pool() as pool:
deq = deque()
while True:
while deq and deq[0].ready():
for plugin, mode, data in deq.popleft().get():
for writer in writers:
writer.write(plugin, mode, data)
if not self._queue.empty():
action, data = self._queue.get()
if action == 'WRITE':
deq.append(pool.apply_async(_pack_data, (data,)))
elif action == 'FLUSH':
for writer in writers:
writer.flush()
elif action == 'END':
break
for result in deq:
for plugin, mode, data in result.get():
for writer in writers:
writer.write(plugin, mode, data)
for writer in writers:
writer.close()
def _get_writers(self):
writers = []
for plugin, filename in self._filedict.items():
filepath = os.path.join(self._base_dir, filename)
if plugin == 'summary':
writers.append(SummaryWriter(filepath))
elif plugin == 'lineage':
writers.append(LineageWriter(filepath))
return writers
def write(self, data) -> None:
"""
Write the event to file.
Args:
data (dict): The mapping from plugin name to a list of data entries to write.
"""
self._queue.put(('WRITE', data))
def flush(self):
"""Flush the writer and sync data to disk."""
self._queue.put(('FLUSH', None))
def close(self) -> None:
"""Close the writer."""
self._queue.put(('END', None))
self.join()
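Putting the pool to work (import path, directory, and file names are assumptions; requires a MindSpore environment):

# Sketch: one background process fans records out to per-plugin files.
import numpy as np
from mindspore.train.summary._writer_pool import WriterPool  # path assumed

pool = WriterPool('/tmp/log', summary='events.summary', lineage='events.lineage')
pool.write({'scalar': [{'tag': 'loss', 'mode': 'train', 'step': 1,
                        'value': np.array(0.1, np.float32)}]})
pool.flush()
pool.close()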

View File

@@ -21,9 +21,9 @@ from mindspore import log as logger
from ..._c_expression import Tensor
from ..._checkparam import _check_str_by_regular
from .._utils import _make_directory
from ._event_writer import EventWriter
from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event
from .._utils import _make_directory, _check_to_numpy, _check_lineage_value
from ._summary_adapter import get_event_file_name, package_graph_event
from ._writer_pool import WriterPool
# for the moment, this lock is for caution's sake,
# there are actually no any concurrencies happening.
@@ -53,16 +53,20 @@ def _get_summary_tensor_data():
return data
def _dictlist():
from collections import defaultdict
return defaultdict(list)
class SummaryRecord:
"""
SummaryRecord is used to record the summary value.
SummaryRecord is used to record the summary data and lineage data.
Note:
The API will create an event file in a given directory and add summaries and events to it.
It writes the event log to a file by executing the record method. In addition,
if the SummaryRecord object is created and the summary operator is used in the network,
even if the record method is not called, the event in the cache will be written to the
file at the end of execution. Make sure to close the SummaryRecord object at the end.
The API will lazily create a summary file and a lineage file in the given directory and write data to them.
It writes data to the files when the record method is executed. In addition to recording the data bubbled up
from the network via summary operators, SummaryRecord also supports recording extra data, which can be
added by calling add_value. Finally, make sure to close the SummaryRecord object at the end.
Args:
log_dir (str): The log_dir is a directory location to save the summary.
@@ -89,10 +93,12 @@ class SummaryRecord:
file_suffix="_MS",
network=None):
self._event_writer, self._closed = None, False
self._closed, self._mode = False, 'train'
self._data_pool = _dictlist()
_check_str_by_regular(file_prefix)
_check_str_by_regular(file_suffix)
self.log_path = _make_directory(log_dir)
if not isinstance(queue_max_size, int) or not isinstance(flush_time, int):
@@ -123,16 +129,12 @@
except Exception as ex:
raise RuntimeError(ex)
def _init_event_writer(self):
"""Init event writer and write metadata."""
event_writer = EventWriter(self.full_file_name, self.flush_time)
event_writer.write(package_init_event().SerializeToString())
return event_writer
self._event_writer = WriterPool(log_dir,
summary=self.full_file_name,
lineage=get_event_file_name('events', '_lineage'))
def __enter__(self):
"""Enter the context manager."""
if not self._event_writer:
self._event_writer = self._init_event_writer()
if self._closed:
raise ValueError('SummaryRecord has been closed.')
return self
@@ -141,6 +143,76 @@
"""Exit the context manager."""
self.close()
def set_mode(self, mode):
"""
Set the mode for the recorder. The mode defaults to 'train'.
Args:
mode (str): The mode to set, which should be 'train' or 'eval'.
Raises:
ValueError: When the mode is not recognized.
Examples:
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
>>> summary_record.set_mode('eval')
"""
mode_spec = 'train', 'eval'
if mode not in mode_spec:
raise ValueError(f'{repr(mode)} is not a recognized mode.')
self._mode = mode
def add_value(self, plugin, name, value):
"""
Add a value to be recorded later.
When the plugin is 'tensor', 'scalar', 'image' or 'histogram',
the name should be the tag name, and the value should be a Tensor.
When the plugin is 'graph', the value should be a GraphProto.
When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage',
or 'custom_lineage_data', the value should be a proto message.
Args:
plugin (str): The plugin for the value.
name (str): The name for the value.
value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
The value to store.
- GraphProto: The 'value' should be a serialized string of this type when the plugin is 'graph'.
- Tensor: The 'value' should be this type when the plugin is 'scalar', 'image', 'tensor' or 'histogram'.
- TrainLineage: The 'value' should be this type when the plugin is 'train_lineage'.
- EvaluationLineage: The 'value' should be this type when the plugin is 'eval_lineage'.
- DatasetGraph: The 'value' should be this type when the plugin is 'dataset_graph'.
- UserDefinedInfo: The 'value' should be this type when the plugin is 'custom_lineage_data'.
Raises:
ValueError: When the name is not valid.
TypeError: When the value is not a Tensor.
Examples:
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
>>> summary_record.add_value('scalar', 'loss', Tensor(0.1))
"""
if plugin in ('tensor', 'scalar', 'image', 'histogram'):
if not name or not isinstance(name, str):
raise ValueError(f'{repr(name)} is not a valid tag name.')
if not isinstance(value, Tensor):
raise TypeError(f'Expected the value to be a Tensor, but got {type(value).__name__}.')
np_value = _check_to_numpy(plugin, value)
self._data_pool[plugin].append(dict(tag=name, mode=self._mode, value=np_value))
elif plugin in ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data'):
_check_lineage_value(plugin, value)
self._data_pool[plugin].append(dict(mode=self._mode, value=value.SerializeToString()))
elif plugin == 'graph':
package_graph_event(value)
self._data_pool[plugin].append(dict(mode=self._mode, value=value))
else:
raise ValueError(f'No such plugin: {repr(plugin)}')
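Putting set_mode and add_value together (the directory and values are illustrative; assumes the public import path mindspore.train.summary.SummaryRecord):

# Sketch: recording extra values alongside network summaries.
import numpy as np
from mindspore import Tensor
from mindspore.train.summary import SummaryRecord  # path assumed

with SummaryRecord(log_dir='/tmp/summary') as rec:
    rec.set_mode('train')
    rec.add_value('scalar', 'loss', Tensor(np.array(0.1, np.float32)))
    rec.record(step=1)    # stamps pooled values with the step and hands them to the writer pool

    rec.set_mode('eval')  # network summary data is only collected in 'train' mode
    rec.add_value('scalar', 'accuracy', Tensor(np.array(0.9, np.float32)))
    rec.record(step=1)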
def record(self, step, train_network=None):
"""
Record the summary.
@@ -149,12 +221,12 @@
step (int): Represents training step number.
train_network (Cell): The network that called the callback.
Returns:
bool, whether the record process is successful or not.
Examples:
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
>>> summary_record.record(step=2)
Returns:
bool, whether the record process is successful or not.
"""
logger.info("SummaryRecord step is %r.", step)
if self._closed:
@@ -163,10 +235,6 @@
if not isinstance(step, int) or isinstance(step, bool):
raise ValueError("`step` should be int")
# Set the current summary of train step
if not self._event_writer:
self._event_writer = self._init_event_writer()
logger.warning('SummaryRecord should be used as context manager for a with statement.')
if self.network is not None and not self.has_graph:
graph_proto = self.network.get_func_graph_proto()
if graph_proto is None and train_network is not None:
@@ -174,39 +242,48 @@
if graph_proto is None:
logger.error("Failed to get proto for graph")
else:
self._event_writer.write(package_graph_event(graph_proto).SerializeToString())
self._event_writer.write({'graph': [{'step': step, 'value': graph_proto}]})
self.has_graph = True
if not _summary_tensor_cache:
return True
data = _get_summary_tensor_data()
if not data:
logger.info("The step(%r) does not have record data.", step)
return False
if self.queue_max_size > 0 and len(data) > self.queue_max_size:
logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data),
self.queue_max_size)
if self._mode == 'train':
self._add_summary_tensor_data()
# process the data
result = self._data_convert(data)
if not result:
logger.error("The step(%r) summary data is invalid.", step)
return False
self._event_writer.write((result, step))
logger.debug("Send the summary data to scheduler for saving, step = %d", step)
self._event_writer.write(self._consume_data_pool(step))
return True
def _add_summary_tensor_data(self):
summary_data = _get_summary_tensor_data()
if not summary_data:
logger.debug('No summary data bubbled up from the network.')
for name, tensor in summary_data.items():
tag, plugin = SummaryRecord._parse_from(name)
if (tag, plugin) == (None, None):
logger.warning("The name(%r) is invalid, expected 'TAG[:TYPE]'.", name)
else:
self.add_value(plugin.lower(), tag, tensor)
def _consume_data_pool(self, step):
try:
for values in self._data_pool.values():
for value in values:
value['step'] = step
return self._data_pool
finally:
self._data_pool = _dictlist()
@property
def log_dir(self):
"""
Get the full path of the log file.
Returns:
str, the full path of log file.
Examples:
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
>>> print(summary_record.log_dir)
Returns:
String, the full path of log file.
"""
return self.full_file_name
@@ -235,46 +312,19 @@
"""
if not self._closed and self._event_writer:
# event writer flush and close
logger.info('Please wait, it may take some time to finish writing and closing.')
self._event_writer.close()
self._closed = True
def __del__(self) -> None:
self.close()
def _data_convert(self, summary):
"""Convert the data."""
# convert the summary to numpy
result = []
for name, data in summary.items():
# confirm the data is valid
summary_tag, summary_type = SummaryRecord._parse_from(name)
if summary_tag is None:
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
return None
if isinstance(data, Tensor):
result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type})
else:
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
return None
return result
@staticmethod
def _parse_from(name: str = None):
"""
Parse the tag and type from name.
Args:
name (str): Format: TAG[:TYPE].
Returns:
Tuple, (summary_tag, summary_type).
"""
if name is None:
logger.error("The name is None")
"""Parse the tag and type from name."""
if not isinstance(name, str):
return None, None
match = re.match(r'(.+)\[:(.+)\]', name)
if match:
return match.groups()
logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name)
return None, None
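The 'TAG[:TYPE]' naming convention the parser expects, for reference (a runnable sketch):

import re

# Names bubbled up from summary operators follow 'TAG[:TYPE]'.
match = re.match(r'(.+)\[:(.+)\]', 'loss[:Scalar]')
print(match.groups())  # ('loss', 'Scalar') -> routed as add_value('scalar', 'loss', tensor)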

View File

@@ -84,21 +84,6 @@ def test_histogram_multi_summary():
event = reader.read_event()
assert event.summary.value[0].histogram.count == size
def test_histogram_summary_scalar_tensor():
"""Test histogram summary, input is a scalar tensor."""
with tempfile.TemporaryDirectory() as tmp_dir:
with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
test_data = _wrap_test_data(Tensor(1))
_cache_summary_tensor_data(test_data)
test_writer.record(step=1)
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
with SummaryReader(file_name) as reader:
event = reader.read_event()
assert event.summary.value[0].histogram.count == 1
def test_histogram_summary_empty_tensor():
"""Test histogram summary, input is an empty tensor."""
with tempfile.TemporaryDirectory() as tmp_dir: