forked from mindspore-Ecosystem/mindspore
Add device id to summary file name
To prevent write-data conflicts in multi-card scenarios, the summary file name on each card now includes the device_id
This commit is contained in:
parent
c2d9e1f396
commit
3240b2d8e1
|
@ -33,7 +33,13 @@ class DatasetGraph:
|
|||
DatasetGraph, a object of lineage_pb2.DatasetGraph.
|
||||
"""
|
||||
dataset_package = import_module('mindspore.dataset')
|
||||
dataset_dict = dataset_package.serialize(dataset)
|
||||
try:
|
||||
dataset_dict = dataset_package.serialize(dataset)
|
||||
except (TypeError, OSError) as exc:
|
||||
logger.warning("Summary can not collect dataset graph, there is an error in dataset internal, "
|
||||
"detail: %s.", str(exc))
|
||||
return None
|
||||
|
||||
dataset_graph_proto = lineage_pb2.DatasetGraph()
|
||||
if not isinstance(dataset_dict, dict):
|
||||
logger.warning("The dataset graph serialized from dataset object is not a dict. "
|
||||
|
|
|
@ -518,6 +518,8 @@ class SummaryCollector(Callback):
|
|||
train_dataset = cb_params.train_dataset
|
||||
dataset_graph = DatasetGraph()
|
||||
graph_bytes = dataset_graph.package_dataset_graph(train_dataset)
|
||||
if graph_bytes is None:
|
||||
return
|
||||
self._record.add_value('dataset_graph', 'train_dataset', graph_bytes)
|
||||
|
||||
def _collect_graphs(self, cb_params):
|
||||
|
|
|
@ -20,6 +20,8 @@ import numpy as np
|
|||
from PIL import Image
|
||||
|
||||
from mindspore import log as logger
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import get_rank
|
||||
|
||||
from ..._checkparam import Validator
|
||||
from ..anf_ir_pb2 import DataType, ModelProto
|
||||
|
@ -53,10 +55,18 @@ def get_event_file_name(prefix, suffix, time_second):
|
|||
file_name = ""
|
||||
hostname = platform.node()
|
||||
|
||||
if prefix is not None:
|
||||
file_name = file_name + prefix
|
||||
device_num = context.get_auto_parallel_context('device_num')
|
||||
device_id = context.get_context('device_id')
|
||||
if device_num > 1:
|
||||
# Notice:
|
||||
# In GPU distribute training scene, get_context('device_id') will not work,
|
||||
# so we use get_rank instead of get_context.
|
||||
device_id = get_rank()
|
||||
|
||||
file_name = file_name + EVENT_FILE_NAME_MARK + time_second + "." + hostname
|
||||
file_name = f'{file_name}{EVENT_FILE_NAME_MARK}{time_second}.{device_id}.{hostname}'
|
||||
|
||||
if prefix is not None:
|
||||
file_name = prefix + file_name
|
||||
|
||||
if suffix is not None:
|
||||
file_name = file_name + suffix
|
||||
|
|
|
@ -97,6 +97,11 @@ class WriterPool(ctx.Process):
|
|||
with ctx.Pool(min(ctx.cpu_count(), 32)) as pool:
|
||||
deq = deque()
|
||||
while True:
|
||||
if not self._writers:
|
||||
logger.warning("Can not find any writer to write summary data, "
|
||||
"so SummaryRecord will not record data.")
|
||||
break
|
||||
|
||||
while deq and deq[0].ready():
|
||||
for plugin, data in deq.popleft().get():
|
||||
self._write(plugin, data)
|
||||
|
|
|
@ -112,8 +112,8 @@ class SummaryRecord:
|
|||
network (Cell): Obtain a pipeline through network for saving graph summary. Default: None.
|
||||
max_file_size (int, optional): The maximum size of each file that can be written to disk (in bytes). \
|
||||
Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`.
|
||||
raise_exception (bool, optional): Sets whether to throw an exception when an RuntimeError exception occurs
|
||||
in recording data. Default: False, this means that error logs are printed and no exception is thrown.
|
||||
raise_exception (bool, optional): Sets whether to throw an exception when a RuntimeError or OSError exception
|
||||
occurs in recording data. Default: False, this means that error logs are printed and no exception is thrown.
|
||||
export_options (Union[None, dict]): Perform custom operations on the export data.
|
||||
Default: None, it means there is no export data.
|
||||
Note that the size of export files is not limited by the max_file_size.
|
||||
|
@ -177,7 +177,7 @@ class SummaryRecord:
|
|||
if self._export_options is not None:
|
||||
export_dir = "export_{}".format(time_second)
|
||||
|
||||
filename_dict = dict(summary=self.full_file_name,
|
||||
filename_dict = dict(summary=self.event_file_name,
|
||||
lineage=get_event_file_name(self.prefix, '_lineage', time_second),
|
||||
explainer=get_event_file_name(self.prefix, '_explain', time_second),
|
||||
exporter=export_dir)
|
||||
|
|
Loading…
Reference in New Issue