Add device id to summary file name

To prevent write conflicts in multi-card scenarios, the device_id is now appended to the summary file name on each card.
ougongchang 2021-02-02 15:49:22 +08:00
parent c2d9e1f396
commit 3240b2d8e1
5 changed files with 30 additions and 7 deletions
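
The net effect on the files each card writes, in a minimal sketch (the timestamp, hostname, and exact file-name marker are illustrative, not taken from the commit):

# Illustrative sketch only: before this commit, every card produced the same
# summary file name, so concurrent cards overwrote each other's data.
before = "events.summary.1612252162.hostname_MS"               # identical on all cards
after = [f"events.summary.1612252162.{device_id}.hostname_MS"  # device id makes each name unique
         for device_id in range(2)]
print(before)
print(after)  # ['events.summary.1612252162.0.hostname_MS', 'events.summary.1612252162.1.hostname_MS']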

View File

@@ -33,7 +33,13 @@ class DatasetGraph:
             DatasetGraph, an object of lineage_pb2.DatasetGraph.
         """
         dataset_package = import_module('mindspore.dataset')
-        dataset_dict = dataset_package.serialize(dataset)
+        try:
+            dataset_dict = dataset_package.serialize(dataset)
+        except (TypeError, OSError) as exc:
+            logger.warning("Summary cannot collect the dataset graph because an error occurred "
+                           "inside the dataset, detail: %s.", str(exc))
+            return None
         dataset_graph_proto = lineage_pb2.DatasetGraph()
         if not isinstance(dataset_dict, dict):
             logger.warning("The dataset graph serialized from dataset object is not a dict. "

View File

@@ -518,6 +518,8 @@ class SummaryCollector(Callback):
         train_dataset = cb_params.train_dataset
         dataset_graph = DatasetGraph()
         graph_bytes = dataset_graph.package_dataset_graph(train_dataset)
+        if graph_bytes is None:
+            return
         self._record.add_value('dataset_graph', 'train_dataset', graph_bytes)

     def _collect_graphs(self, cb_params):

View File

@@ -20,6 +20,8 @@ import numpy as np
 from PIL import Image

 from mindspore import log as logger
+from mindspore import context
+from mindspore.communication.management import get_rank

 from ..._checkparam import Validator
 from ..anf_ir_pb2 import DataType, ModelProto
@@ -53,10 +55,18 @@ def get_event_file_name(prefix, suffix, time_second):
     file_name = ""
     hostname = platform.node()
-    if prefix is not None:
-        file_name = file_name + prefix
+    device_num = context.get_auto_parallel_context('device_num')
+    device_id = context.get_context('device_id')
+    if device_num > 1:
+        # Notice:
+        # In the GPU distributed training scenario, get_context('device_id') will not work,
+        # so we use get_rank instead of get_context.
+        device_id = get_rank()
-    file_name = file_name + EVENT_FILE_NAME_MARK + time_second + "." + hostname
+    file_name = f'{file_name}{EVENT_FILE_NAME_MARK}{time_second}.{device_id}.{hostname}'
+    if prefix is not None:
+        file_name = prefix + file_name
     if suffix is not None:
         file_name = file_name + suffix
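
To make the new naming concrete, a standalone sketch that mirrors the logic above (the value of EVENT_FILE_NAME_MARK is an assumption; the real constant is defined elsewhere in this file):

# Re-creation of the naming logic for illustration, not the real function.
EVENT_FILE_NAME_MARK = ".summary."  # assumption: placeholder for the module constant
prefix, suffix = "events", "_MS"
time_second, hostname = "1612252162", "machine-0"
device_id = 1                       # get_rank() on multi-card GPU, get_context('device_id') otherwise

file_name = f"{EVENT_FILE_NAME_MARK}{time_second}.{device_id}.{hostname}"
file_name = prefix + file_name + suffix
print(file_name)                    # events.summary.1612252162.1.machine-0_MS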

View File

@@ -97,6 +97,11 @@ class WriterPool(ctx.Process):
         with ctx.Pool(min(ctx.cpu_count(), 32)) as pool:
             deq = deque()
             while True:
+                if not self._writers:
+                    logger.warning("Cannot find any writer to write summary data, "
+                                   "so SummaryRecord will not record data.")
+                    break
+
                 while deq and deq[0].ready():
                     for plugin, data in deq.popleft().get():
                         self._write(plugin, data)

View File

@@ -112,8 +112,8 @@ class SummaryRecord:
         network (Cell): Obtain a pipeline through network for saving graph summary. Default: None.
         max_file_size (int, optional): The maximum size of each file that can be written to disk (in bytes). \
             Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`.
-        raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs
-            in recording data. Default: False, this means that error logs are printed and no exception is thrown.
+        raise_exception (bool, optional): Sets whether to throw an exception when a RuntimeError or OSError exception
+            occurs in recording data. Default: False, this means that error logs are printed and no exception is thrown.
         export_options (Union[None, dict]): Perform custom operations on the export data.
             Default: None, it means there is no export data.
             Note that the size of export files is not limited by the max_file_size.
@@ -177,7 +177,7 @@ class SummaryRecord:
         if self._export_options is not None:
             export_dir = "export_{}".format(time_second)

-        filename_dict = dict(summary=self.full_file_name,
+        filename_dict = dict(summary=self.event_file_name,
                              lineage=get_event_file_name(self.prefix, '_lineage', time_second),
                              explainer=get_event_file_name(self.prefix, '_explain', time_second),
                              exporter=export_dir)
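
For reference, a hedged usage sketch of the documented default (the directory and values are made up): with raise_exception=False, a RuntimeError or OSError raised while recording is logged rather than propagated.

import numpy as np
from mindspore import Tensor
from mindspore.train.summary import SummaryRecord

with SummaryRecord('./summary_dir', raise_exception=False) as summary_record:
    summary_record.add_value('scalar', 'loss', Tensor(np.array(0.1, np.float32)))
    summary_record.record(step=1)  # errors here are logged, not raised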