forked from mindspore-Ecosystem/mindspore
!10881 add the tensor collection feature when record summary
From: @jiang-shuqiang Reviewed-by: Signed-off-by:
commit fb97aa327c

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -78,7 +78,7 @@ def _make_directory(path: str):
     """Make directory."""
     if path is None or not isinstance(path, str) or path.strip() == "":
         logger.error("The path(%r) is invalid type.", path)
-        raise TypeError("Input path is invaild type")
+        raise TypeError("Input path is invalid type")

     path = os.path.realpath(path)
     logger.debug("The abs path is %r", path)

@@ -27,7 +27,7 @@ from mindspore import log as logger
 from mindspore import context
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
-from mindspore.train.summary.summary_record import SummaryRecord
+from mindspore.train.summary.summary_record import SummaryRecord, process_export_options
 from mindspore.train.summary.enums import PluginEnum, ModeEnum
 from mindspore.train.callback import Callback, ModelCheckpoint
 from mindspore.train import lineage_pb2

@@ -89,21 +89,21 @@ class SummaryCollector(Callback):
             For example, you can set {'collect_metric': False} to control not collecting metrics.
             The data that supports control is shown below.

-            - collect_metric: Whether to collect training metrics, currently only the loss is collected.
+            - collect_metric (bool): Whether to collect training metrics, currently only the loss is collected.
              The first output will be treated as the loss and it will be averaged.
              Optional: True/False. Default: True.
-            - collect_graph: Whether to collect the computational graph. Currently, only
+            - collect_graph (bool): Whether to collect the computational graph. Currently, only
              training computational graph is collected. Optional: True/False. Default: True.
-            - collect_train_lineage: Whether to collect lineage data for the training phase,
+            - collect_train_lineage (bool): Whether to collect lineage data for the training phase,
              this field will be displayed on the lineage page of Mindinsight. Optional: True/False. Default: True.
-            - collect_eval_lineage: Whether to collect lineage data for the evaluation phase,
+            - collect_eval_lineage (bool): Whether to collect lineage data for the evaluation phase,
              this field will be displayed on the lineage page of Mindinsight. Optional: True/False. Default: True.
-            - collect_input_data: Whether to collect dataset for each training. Currently only image data is supported.
+            - collect_input_data (bool): Whether to collect dataset for each training.
+              Currently only image data is supported. Optional: True/False. Default: True.
-            - collect_dataset_graph: Whether to collect dataset graph for the training phase.
-              Optional: True/False. Default: True.
+            - collect_dataset_graph (bool): Whether to collect dataset graph for the training phase.
+              Optional: True/False. Default: True.
-            - histogram_regular: Collect weight and bias for parameter distribution page and displayed in MindInsight.
-              This field allows regular strings to control which parameters to collect.
+            - histogram_regular (Union[str, None]): Collect weight and bias for parameter distribution page
+              and displayed in MindInsight. This field allows regular strings to control which parameters to collect.
              Default: None, it means only the first five parameters are collected.
              It is not recommended to collect too many parameters at once, as it can affect performance.
              Note that if you collect too many parameters and run out of memory, the training will fail.

@@ -127,6 +127,13 @@ class SummaryCollector(Callback):
         max_file_size (Optional[int]): The maximum size in bytes of each file that can be written to the disk.
             Default: None, which means no limit. For example, to write not larger than 4GB,
             specify `max_file_size=4 * 1024**3`.
+        export_options (Union[None, dict]): Perform custom operations on the export data.
+            Default: None, it means there is no export data.
+            You can customize the export data with a dictionary. For example, you can set {'tensor_format': 'npy'}
+            to export tensor as npy file. The data that supports control is shown below.
+
+            - tensor_format (Union[str, None]): Customize the export tensor format.
+              Default: None, it means there is no export tensor.

     Raises:
         ValueError: If the parameter value is not expected.

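Usage sketch of the new option (editor's note, not part of the diff; assumes an existing `model` and `ds_train`, and the summary directory name is a placeholder):

from mindspore.train.callback import SummaryCollector

# Record summaries as before, and additionally export collected tensors as .npy files.
summary_collector = SummaryCollector(summary_dir='./summary_dir',
                                     export_options={'tensor_format': 'npy'})
model.train(epoch=1, train_dataset=ds_train, callbacks=[summary_collector])
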
@@ -175,7 +182,8 @@ class SummaryCollector(Callback):
                  keep_default_action=True,
                  custom_lineage_data=None,
                  collect_tensor_freq=None,
-                 max_file_size=None):
+                 max_file_size=None,
+                 export_options=None):
         super(SummaryCollector, self).__init__()

         self._summary_dir = self._process_summary_dir(summary_dir)

@@ -191,6 +199,8 @@ class SummaryCollector(Callback):
         self._check_positive('max_file_size', max_file_size, allow_none=True)
         self._max_file_size = max_file_size

+        self._export_options = process_export_options(export_options)
+
         self._check_action(keep_default_action)

         self._collect_specified_data = self._process_specified_data(collect_specified_data, keep_default_action)

@@ -209,7 +219,8 @@ class SummaryCollector(Callback):
     def __enter__(self):
         self._record = SummaryRecord(log_dir=self._summary_dir,
                                      max_file_size=self._max_file_size,
-                                     raise_exception=False)
+                                     raise_exception=False,
+                                     export_options=self._export_options)
         self._first_step, self._dataset_sink_mode = True, True
         return self

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -36,13 +36,14 @@ EVENT_FILE_INIT_VERSION = 1
 F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max


-def get_event_file_name(prefix, suffix):
+def get_event_file_name(prefix, suffix, seconds=None):
     """
     Create file name: file_prefix + EVENT_FILE_NAME_MARK + time(seconds) + "." + Hostname + file_suffix.

     Args:
         prefix (str): The prefix of file name.
         suffix (str): The suffix of file name.
+        seconds (str): The time stamp of file name.

     Returns:
         String, the name of event log file.

@@ -51,6 +52,8 @@ def get_event_file_name(prefix, suffix):
     Validator.check_str_by_regular(suffix)
     file_name = ""
     time_second = str(int(time.time()))
+    if seconds is not None:
+        time_second = seconds
     hostname = platform.node()

     if prefix is not None:

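A hedged sketch of the naming scheme from the docstring above (EVENT_FILE_NAME_MARK is an internal constant; the value below is a placeholder for illustration):

import time
import platform

EVENT_FILE_NAME_MARK = '.summary.'  # placeholder, not the real constant

def sketch_event_file_name(prefix, suffix, seconds=None):
    # Mirrors the diff: an explicit `seconds` overrides the current time.
    time_second = seconds if seconds is not None else str(int(time.time()))
    return '{}{}{}.{}{}'.format(prefix, EVENT_FILE_NAME_MARK, time_second, platform.node(), suffix)

# Passing the same `seconds` into repeated calls lets SummaryRecord give the event
# file and the export directory one shared timestamp (see summary_record.py below).
print(sketch_event_file_name('events', '_MS', seconds='1609836987'))
# -> events.summary.1609836987.<hostname>_MS
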
@@ -96,8 +99,9 @@ def package_summary_event(data_list, step, wall_time):
     Package the summary to event protobuffer.

     Args:
-        data_id (Number): Summary data id.
+        data_list (list): Summary data list.
         step (Number): The record step index.
         wall_time (float): The wall time.

     Returns:
         Summary, the summary event.

@@ -19,11 +19,12 @@ import signal
 from collections import deque

 import mindspore.log as logger
+from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum

 from ._lineage_adapter import serialize_to_lineage_event
 from ._summary_adapter import package_graph_event, package_summary_event
 from ._explain_adapter import package_explain_event
-from .writer import LineageWriter, SummaryWriter, ExplainWriter
+from .writer import LineageWriter, SummaryWriter, ExplainWriter, ExportWriter

 try:
     from multiprocessing import get_context

@@ -37,17 +38,24 @@ def _pack_data(datadict, wall_time):
     result, summaries, step = [], [], None
     for plugin, datalist in datadict.items():
         for data in datalist:
-            if plugin == 'graph':
+            if plugin == PluginEnum.GRAPH.value:
                 result.append([plugin, package_graph_event(data.get('value')).SerializeToString()])
-            elif plugin in ('train_lineage', 'eval_lineage', 'custom_lineage_data', 'dataset_graph'):
+            elif plugin in (PluginEnum.TRAIN_LINEAGE.value, PluginEnum.EVAL_LINEAGE.value,
+                            PluginEnum.CUSTOM_LINEAGE_DATA.value, PluginEnum.DATASET_GRAPH.value):
                 result.append([plugin, serialize_to_lineage_event(plugin, data.get('value'))])
-            elif plugin in ('scalar', 'tensor', 'histogram', 'image'):
+            elif plugin in (PluginEnum.SCALAR.value, PluginEnum.TENSOR.value, PluginEnum.HISTOGRAM.value,
+                            PluginEnum.IMAGE.value):
                 summaries.append({'_type': plugin.title(), 'name': data.get('tag'), 'data': data.get('value')})
                 step = data.get('step')
-            elif plugin == 'explainer':
+            elif plugin == PluginEnum.EXPLAINER.value:
                 result.append([plugin, package_explain_event(data.get('value'))])
+
+            if 'export_option' in data:
+                result.append([WriterPluginEnum.EXPORTER.value, data])
+
     if summaries:
-        result.append(['summary', package_summary_event(summaries, step, wall_time).SerializeToString()])
+        result.append(
+            [WriterPluginEnum.SUMMARY.value, package_summary_event(summaries, step, wall_time).SerializeToString()])
     return result

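For orientation, a sketch of the `datadict` shape this function consumes (field names come from the diff; the values are made up, and `np_value` stands for a numpy array):

datadict = {
    'tensor': [
        # 'export_option' is attached by SummaryRecord.record() when the user
        # passed export_options={'tensor_format': 'npy'} (see summary_record.py below).
        {'tag': 'tensor', 'step': 1, 'value': np_value, 'export_option': 'npy'},
    ],
}
# Besides the usual [WriterPluginEnum.SUMMARY.value, ...] entry, _pack_data now also
# emits an extra [WriterPluginEnum.EXPORTER.value, data] entry that ExportWriter handles.
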
@@ -60,6 +68,7 @@ class WriterPool(ctx.Process):
         max_file_size (Optional[int]): The maximum size of each file that can be written to disk in bytes.
         raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs
             in recording data. Default: False, this means that error logs are printed and no exception is thrown.
+        export_options (Union[None, dict]): Perform custom operations on the export data. Default: None.
         filedict (dict): The mapping from plugin to filename.
     """

@@ -114,12 +123,14 @@ class WriterPool(ctx.Process):
             self._writers_ = []
             for plugin, filename in self._filedict.items():
                 filepath = os.path.join(self._base_dir, filename)
-                if plugin == 'summary':
+                if plugin == WriterPluginEnum.SUMMARY.value:
                     self._writers_.append(SummaryWriter(filepath, self._max_file_size))
-                elif plugin == 'lineage':
+                elif plugin == WriterPluginEnum.LINEAGE.value:
                     self._writers_.append(LineageWriter(filepath, self._max_file_size))
-                elif plugin == 'explainer':
+                elif plugin == WriterPluginEnum.EXPLAINER.value:
                     self._writers_.append(ExplainWriter(filepath, self._max_file_size))
+                elif plugin == WriterPluginEnum.EXPORTER.value:
+                    self._writers_.append(ExportWriter(filepath, self._max_file_size))
             return self._writers_

     def _write(self, plugin, data):

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -34,7 +34,17 @@ class PluginEnum(BaseEnum):
     HISTOGRAM = 'histogram'
     TRAIN_LINEAGE = 'train_lineage'
     EVAL_LINEAGE = 'eval_lineage'
     CUSTOM_LINEAGE_DATA = 'custom_lineage_data'
     DATASET_GRAPH = 'dataset_graph'
     EXPLAINER = 'explainer'
+
+
+class WriterPluginEnum(Enum):
+    """The list of extra plugins."""
+    EXPORTER = 'exporter'
+    EXPLAINER = 'explainer'
+    SUMMARY = 'summary'
+    LINEAGE = 'lineage'


 class ModeEnum(BaseEnum):

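A quick sanity sketch of how these enum values line up with the former string literals in the dispatch code above:

assert WriterPluginEnum.SUMMARY.value == 'summary'
assert WriterPluginEnum.EXPORTER.value == 'exporter'
assert PluginEnum.GRAPH.value == 'graph'
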
@@ -17,6 +17,7 @@ import atexit
 import os
+import re
 import threading
 import time
 from collections import defaultdict

 from mindspore import log as logger

@@ -24,7 +25,7 @@ from mindspore.nn import Cell

 from ..._c_expression import Tensor
 from ..._checkparam import Validator
-from .._utils import _check_lineage_value, _check_to_numpy, _make_directory
+from .._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
 from ._summary_adapter import get_event_file_name, package_graph_event
 from ._explain_adapter import check_explain_proto
 from ._writer_pool import WriterPool

@@ -34,6 +35,9 @@ from ._writer_pool import WriterPool
 _summary_lock = threading.Lock()
 # cache the summary data
 _summary_tensor_cache = {}
+_DEFAULT_EXPORT_OPTIONS = {
+    'tensor_format': 'npy',
+}


 def _cache_summary_tensor_data(summary):

@@ -57,6 +61,27 @@ def _get_summary_tensor_data():
     return data


+def process_export_options(export_options):
+    """Check specified data type and value."""
+    if export_options is None:
+        return None
+
+    check_value_type('export_options', export_options, [dict, type(None)])
+
+    for param_name in export_options:
+        check_value_type(param_name, param_name, [str])
+
+    unexpected_params = set(export_options) - set(_DEFAULT_EXPORT_OPTIONS)
+    if unexpected_params:
+        raise ValueError(f'For `export_options` the keys {unexpected_params} are unsupported, '
+                         f'expect the follow keys: {list(_DEFAULT_EXPORT_OPTIONS.keys())}')
+
+    for item in set(export_options):
+        check_value_type(item, export_options.get(item), [str, type(None)])
+
+    return export_options
+
+
 class SummaryRecord:
     """
     SummaryRecord is used to record the summary data and lineage data.

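Behavior sketch of the validation above (a hypothetical session, not part of the commit):

process_export_options(None)                      # -> None, exporting disabled
process_export_options({'tensor_format': 'npy'})  # -> {'tensor_format': 'npy'}
process_export_options({'tensor_format': None})   # -> allowed: values may be str or None
process_export_options({'bad_key': 'npy'})        # raises ValueError: unsupported key
process_export_options(123)                       # raises TypeError via check_value_type
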
@@ -81,6 +106,13 @@ class SummaryRecord:
             Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`.
         raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs
             in recording data. Default: False, this means that error logs are printed and no exception is thrown.
+        export_options (Union[None, dict]): Perform custom operations on the export data.
+            Default: None, it means there is no export data.
+            You can customize the export data with a dictionary. For example, you can set {'tensor_format': 'npy'}
+            to export tensor as npy file. The data that supports control is shown below.
+
+            - tensor_format (Union[str, None]): Customize the export tensor format.
+              Default: None, it means there is no export tensor.

     Raises:
         TypeError: If the parameter type is incorrect.

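A hedged usage sketch of the new keyword (the directory name is a placeholder; SummaryRecord is used as a context manager, as elsewhere in MindSpore):

with SummaryRecord(log_dir='./summary_dir', export_options={'tensor_format': 'npy'}) as summary_record:
    ...  # call summary_record.record(step) during training; tensor summaries are also dumped as .npy
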
@@ -99,7 +131,7 @@ class SummaryRecord:
     """

     def __init__(self, log_dir, file_prefix="events", file_suffix="_MS",
-                 network=None, max_file_size=None, raise_exception=False):
+                 network=None, max_file_size=None, raise_exception=False, export_options=None):

         self._closed, self._event_writer = False, None
         self._mode, self._data_pool = 'train', defaultdict(list)

@@ -126,13 +158,20 @@ class SummaryRecord:
         self.network = network
         self.has_graph = False

+        seconds = str(int(time.time()))
         # create the summary writer file
-        self.event_file_name = get_event_file_name(self.prefix, self.suffix)
+        self.event_file_name = get_event_file_name(self.prefix, self.suffix, seconds)
         self.full_file_name = os.path.join(self.log_path, self.event_file_name)

+        self._export_options = process_export_options(export_options)
+        export_dir = ''
+        if self._export_options is not None:
+            export_dir = "export_{}".format(seconds)
+
         filename_dict = dict(summary=self.full_file_name,
                              lineage=get_event_file_name(self.prefix, '_lineage'),
-                             explainer=get_event_file_name(self.prefix, '_explain'))
+                             explainer=get_event_file_name(self.prefix, '_explain'),
+                             exporter=export_dir)
         self._event_writer = WriterPool(log_dir,
                                         max_file_size,
                                         raise_exception,

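The resulting layout on disk, assuming export_options={'tensor_format': 'npy'} (the timestamp and file names below are illustrative; export_<seconds> shares the event file's timestamp):

summary_dir/
├── <event file>             # summary events, named by get_event_file_name(prefix, suffix, seconds)
├── <lineage event file>
└── export_1609836987/
    └── tensor/
        └── tensor_1.npy     # <tag>_<step>.npy, written by ExportWriter
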
@@ -211,7 +250,11 @@ class SummaryRecord:
                 if name in {item['tag'] for item in self._data_pool[plugin]}:
                     entry = repr(f'{name}/{plugin}')
                     logger.warning(f'{entry} has duplicate values. Only the newest one will be recorded.')
-                self._data_pool[plugin].append(dict(tag=name, value=np_value))
+                data = dict(tag=name, value=np_value)
+                export_plugin = '{}_format'.format(plugin)
+                if self._export_options is not None and export_plugin in self._export_options:
+                    data['export_option'] = self._export_options.get(export_plugin)
+                self._data_pool[plugin].append(data)

             elif plugin in ('train_lineage', 'eval_lineage', 'dataset_graph', 'custom_lineage_data'):
                 _check_lineage_value(plugin, value)

@@ -15,12 +15,19 @@
 """Writes events to disk in a logdir."""
 import os
 import stat
+from urllib.parse import quote
+from shutil import disk_usage
+
+import numpy as np
+
+from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum

 from .._utils import _make_directory
 from ..._c_expression import EventWriter_
 from ._summary_adapter import package_init_event

 FREE_DISK_SPACE_TIMES = 32
+FILE_MODE = 0o600


 class BaseWriter:

@@ -79,11 +86,11 @@ class SummaryWriter(BaseWriter):

     def init_writer(self):
         """Write some metadata etc."""
-        self.write('summary', package_init_event().SerializeToString())
+        self.write(WriterPluginEnum.SUMMARY.value, package_init_event().SerializeToString())

     def write(self, plugin, data):
         """Write data to file."""
-        if plugin in ('summary', 'graph'):
+        if plugin in (WriterPluginEnum.SUMMARY.value, PluginEnum.GRAPH.value):
             super().write(plugin, data)

@@ -92,7 +99,8 @@ class LineageWriter(BaseWriter):

     def write(self, plugin, data):
         """Write data to file."""
-        if plugin in ('dataset_graph', 'train_lineage', 'eval_lineage', 'custom_lineage_data'):
+        if plugin in (PluginEnum.DATASET_GRAPH.value, PluginEnum.TRAIN_LINEAGE.value, PluginEnum.EVAL_LINEAGE.value,
+                      PluginEnum.CUSTOM_LINEAGE_DATA.value):
             super().write(plugin, data)

@@ -101,5 +109,64 @@ class ExplainWriter(BaseWriter):

     def write(self, plugin, data):
         """Write data to file."""
-        if plugin == 'explainer':
+        if plugin == WriterPluginEnum.EXPLAINER.value:
             super().write(plugin, data)
+
+
+class ExportWriter(BaseWriter):
+    """ExportWriter for export data."""
+
+    def write(self, plugin, data):
+        """Write data to file."""
+        if plugin == WriterPluginEnum.EXPORTER.value:
+            self.export_data(data, data.get('export_option'))
+
+    def flush(self):
+        """Flush the writer."""
+
+    def close(self):
+        """Close the writer."""
+
+    def export_data(self, data, export_option):
+        """
+        export the tensor data.
+
+        Args:
+            data (dict): Export data info.
+            export_option (Union[None, str]): The export options.
+        """
+        options = {
+            'npy': self._export_npy
+        }
+
+        if export_option in options:
+            options[export_option](data, self._filepath, self._max_file_size)
+
+    @staticmethod
+    def _export_npy(data, export_dir, max_file_size):
+        """
+        export the tensor data as npy.
+
+        Args:
+            data (dict): Export data info.
+            export_dir (str): The path of export dir.
+            max_file_size (Optional[int]): The maximum size in bytes of each file that can be written to the disk.
+        """
+        tag = quote(data.get('tag'), safe="")
+        step = int(data.get('step'))
+        np_value = data.get('value')
+        path = _make_directory(os.path.join(export_dir, 'tensor'))
+
+        # 128 is the typical length of header of npy file
+        metadata_length = 128
+        required_length = np_value.nbytes + metadata_length
+        if disk_usage(path).free < required_length * FREE_DISK_SPACE_TIMES:
+            raise RuntimeError(f"The disk space may be soon exhausted by the '{path}'.")
+
+        if max_file_size is not None and max_file_size < required_length:
+            raise RuntimeWarning(f"'max_file_size' reached: There are {max_file_size} bytes remaining, "
+                                 f"but the '{path}' requires to write {required_length} bytes.")
+
+        np_path = "{}/{}_{}.npy".format(path, tag, step)
+        np.save(np_path, np_value)
+        os.chmod(np_path, FILE_MODE)

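A sketch of the file naming above: the tag is URL-quoted so a tag containing '/' still yields a single flat file name (values here are illustrative):

from urllib.parse import quote

tag, step = 'conv1.weight/auto', 1
print('{}_{}.npy'.format(quote(tag, safe=''), step))
# -> conv1.weight%2Fauto_1.npy
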
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -170,13 +170,17 @@ class TestSummary:
     def test_summarycollector_user_defind(self):
         """Test SummaryCollector with user defind."""
         summary_dir = self._run_network(dataset_sink_mode=True, num_samples=2,
-                                        custom_lineage_data={'test': 'self test'})
+                                        custom_lineage_data={'test': 'self test'},
+                                        export_options={'tensor_format': 'npy'})

         tag_list = self._list_summary_tags(summary_dir)
+        file_list = self._list_tensor_files(summary_dir)
         # There will not record input data when dataset sink mode is True
         expected_tags = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto',
                          'fc2.weight/auto', 'loss/auto', 'histogram', 'image', 'scalar', 'tensor'}
         assert set(expected_tags) == set(tag_list)
+        expected_files = {'tensor_1.npy'}
+        assert set(expected_files) == set(file_list)

     @staticmethod
     def _list_summary_tags(summary_dir):

@@ -198,3 +202,21 @@ class TestSummary:
             for value in summary_event.summary.value:
                 tags.append(value.tag)
         return tags
+
+    @staticmethod
+    def _list_tensor_files(summary_dir):
+        """list tensor tags."""
+        export_file_path = ''
+        for file in os.listdir(summary_dir):
+            if re.search("export_", file):
+                export_file_path = os.path.join(summary_dir, file)
+                break
+        assert export_file_path
+        tensor_file_path = os.path.join(export_file_path, "tensor")
+        assert tensor_file_path
+
+        tensors = list()
+        for file in os.listdir(tensor_file_path):
+            tensors.append(file)
+
+        return tensors

@@ -143,6 +143,18 @@ class TestSummaryCollector:

         assert expected_msg == str(exc.value)

+    @pytest.mark.parametrize("export_options", [123])
+    def test_params_with_export_options_type_error(self, export_options):
+        """Test type error scenario for the export_options param."""
+        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
+        with pytest.raises(TypeError) as exc:
+            SummaryCollector(summary_dir, export_options=export_options)
+
+        expected_msg = f"For `export_options` the type should be a valid type of ['dict', 'NoneType'], " \
+                       f"but got {type(export_options).__name__}."
+
+        assert expected_msg == str(exc.value)
+
     @pytest.mark.parametrize("collect_specified_data", [
         {
             123: 123

@@ -204,6 +216,15 @@ class TestSummaryCollector:
         expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported"
         assert expected_msg in str(exc.value)

+    def test_params_with_export_options_unexpected_key(self):
+        """Test the export_options parameter with unexpected key."""
+        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
+        data = {'unexpected_key': "value"}
+        with pytest.raises(ValueError) as exc:
+            SummaryCollector(summary_dir, export_options=data)
+        expected_msg = f"For `export_options` the keys {set(data)} are unsupported"
+        assert expected_msg in str(exc.value)
+
     @pytest.mark.parametrize("custom_lineage_data", [
         123,
         {