diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index ded0e9a650b..6681e2d13ae 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -108,6 +108,10 @@ class SummaryCollector(Callback): custom_lineage_data (Union[dict, None]): Allows you to customize the data and present it on the MingInsight lineage page. In the custom data, the key type support str, and the value type support str/int/float. Default: None, it means there is no custom data. + collect_tensor_freq (Optional[int]): Same as the `collect_freq`, but controls TensorSummary specifically. + Default: None, which means the frequency is auto-calculated just to collect at most 50 steps TensorSummary. + max_file_size (Optional[int]): The maximum size in bytes each file can be written to the disk. + Default: None, which means no limit. Raises: ValueError: If the parameter value is not expected. @@ -145,16 +149,28 @@ class SummaryCollector(Callback): 'histogram_regular': None } - def __init__(self, summary_dir, collect_freq=10, collect_specified_data=None, - keep_default_action=True, custom_lineage_data=None): + def __init__(self, + summary_dir, + collect_freq=10, + collect_specified_data=None, + keep_default_action=True, + custom_lineage_data=None, + collect_tensor_freq=None, + max_file_size=None): super(SummaryCollector, self).__init__() self._summary_dir = self._process_summary_dir(summary_dir) self._record = None - self._check_collect_freq(collect_freq) + self._check_positive('collect_freq', collect_freq) self._collect_freq = collect_freq + self._check_positive('collect_tensor_freq', collect_tensor_freq, allow_none=True) + self._collect_tensor_freq = collect_tensor_freq + + self._check_positive('max_file_size', max_file_size, allow_none=True) + self._max_file_size = max_file_size + self._check_action(keep_default_action) self._collect_specified_data = self._process_specified_data(collect_specified_data, keep_default_action) @@ -165,16 +181,14 @@ class SummaryCollector(Callback): self._custom_lineage_data = custom_lineage_data self._temp_optimizer = None - self._has_saved_train_network = False self._has_saved_custom_data = False self._is_parse_loss_success = True self._first_step = True self._dataset_sink_mode = True def __enter__(self): - self._first_step = True - self._dataset_sink_mode = True - self._record = SummaryRecord(log_dir=self._summary_dir) + self._record = SummaryRecord(log_dir=self._summary_dir, max_file_size=self._max_file_size) + self._first_step, self._dataset_sink_mode = True, True return self def __exit__(self, *err): @@ -198,11 +212,13 @@ class SummaryCollector(Callback): return summary_dir @staticmethod - def _check_collect_freq(freq): - """Check collect freq type and value.""" - check_value_type('collect_freq', freq, int) - if freq <= 0: - raise ValueError(f'For `collect_freq` the value should be greater than 0, but got `{freq}`.') + def _check_positive(name, value, allow_none=False): + """Check if the value to be int type and positive.""" + if allow_none: + return + check_value_type(name, value, int) + if value <= 0: + raise ValueError(f'For `{name}` the value should be greater than 0, but got `{value}`.') @staticmethod def _check_custom_lineage_data(custom_lineage_data): @@ -276,6 +292,9 @@ class SummaryCollector(Callback): self._collect_graphs(cb_params) self._collect_dataset_graph(cb_params) + if self._collect_tensor_freq is None: + total_step = cb_params.epoch_num * cb_params.batch_num + self._collect_tensor_freq = max(self._collect_freq, total_step // 50) if self._custom_lineage_data and not self._has_saved_custom_data: packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data) @@ -287,24 +306,29 @@ class SummaryCollector(Callback): def step_end(self, run_context): cb_params = run_context.original_args() + if cb_params.mode != ModeEnum.TRAIN.value: + return if self._first_step: # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario - self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num) + self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num + self._collect_at_step_end(cb_params, plugin_filter=None) + self._first_step = False + else: + current = cb_params.cur_epoch_num if self._dataset_sink_mode else cb_params.cur_step_num + if current % self._collect_freq == 0 and current % self._collect_tensor_freq == 0: + self._collect_at_step_end(cb_params, plugin_filter=None) + elif current % self._collect_tensor_freq == 0: + self._collect_at_step_end(cb_params, lambda plugin: plugin == PluginEnum.TENSOR.value) + elif current % self._collect_freq == 0: + self._collect_at_step_end(cb_params, lambda plugin: plugin != PluginEnum.TENSOR.value) - if cb_params.mode == ModeEnum.TRAIN.value: - if not self._is_collect_this_step(cb_params): - return + def _collect_at_step_end(self, cb_params, plugin_filter): + self._collect_input_data(cb_params) + self._collect_metric(cb_params) + self._collect_histogram(cb_params) + self._record.record(cb_params.cur_step_num, plugin_filter=plugin_filter) - if not self._has_saved_train_network: - self._collect_graphs(cb_params) - - self._collect_input_data(cb_params) - self._collect_metric(cb_params) - self._collect_histogram(cb_params) - - self._first_step = False - self._record.record(cb_params.cur_step_num) def end(self, run_context): cb_params = run_context.original_args() @@ -331,18 +355,6 @@ class SummaryCollector(Callback): raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," f"but expected only one {self.__class__.__name__} instance.") - def _is_collect_this_step(self, cb_params): - """Decide whether to collect data for the current step.""" - # Make sure the first step data is recorded - if not self._first_step: - if self._dataset_sink_mode: - if cb_params.cur_epoch_num % self._collect_freq: - return False - else: - if cb_params.cur_step_num % self._collect_freq: - return False - return True - @staticmethod def _package_custom_lineage_data(custom_lineage_data): """ @@ -411,7 +423,6 @@ class SummaryCollector(Callback): if graph_proto is None: return - self._has_saved_train_network = True self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto) def _collect_metric(self, cb_params): diff --git a/mindspore/train/summary/_summary_writer.py b/mindspore/train/summary/_summary_writer.py index a5648fc94e2..622282dcf6e 100644 --- a/mindspore/train/summary/_summary_writer.py +++ b/mindspore/train/summary/_summary_writer.py @@ -24,8 +24,8 @@ from ._summary_adapter import package_init_event class BaseWriter: """BaseWriter to be subclass.""" - def __init__(self, filepath) -> None: - self._filepath = filepath + def __init__(self, filepath, max_file_size=None) -> None: + self._filepath, self._max_file_size = filepath, max_file_size self._writer: EventWriter_ = None def init_writer(self): @@ -46,8 +46,15 @@ class BaseWriter: def write(self, plugin, data): """Write data to file.""" if self.writer and disk_usage(self._filepath).free < len(data) * 32: - raise RuntimeError('The disk space may be soon exhausted.') - self.writer.Write(data) + raise RuntimeError(f'The disk space may be soon exhausted by the {type(self).__name__}.') + if self._max_file_size is None: + self.writer.Write(data) + elif self._max_file_size > 0: + self._max_file_size -= len(data) + self.writer.Write(data) + else: + raise RuntimeError(f"The file written by the {type(self).__name__} " + f"has exceeded the specified max file size.") def flush(self): """Flush the writer.""" diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py index d0cf998b30c..6a653f349da 100644 --- a/mindspore/train/summary/_writer_pool.py +++ b/mindspore/train/summary/_writer_pool.py @@ -51,10 +51,11 @@ class WriterPool(Process): filelist (str): The mapping from short name to long filename. """ - def __init__(self, base_dir, **filedict) -> None: + def __init__(self, base_dir, max_file_size, **filedict) -> None: super().__init__() self._base_dir, self._filedict = base_dir, filedict self._queue, self._writers_ = Queue(cpu_count() * 2), None + self._max_file_size = max_file_size self.start() def run(self): @@ -88,9 +89,9 @@ class WriterPool(Process): for plugin, filename in self._filedict.items(): filepath = os.path.join(self._base_dir, filename) if plugin == 'summary': - self._writers_.append(SummaryWriter(filepath)) + self._writers_.append(SummaryWriter(filepath, self._max_file_size)) elif plugin == 'lineage': - self._writers_.append(LineageWriter(filepath)) + self._writers_.append(LineageWriter(filepath, self._max_file_size)) return self._writers_ def _write(self, plugin, data): @@ -98,9 +99,8 @@ class WriterPool(Process): for writer in self._writers[:]: try: writer.write(plugin, data) - except RuntimeError: - logger.warning(f'The disk space may be soon exhausted by this {type(writer).__name__}, ' - 'so the writer will be closed and not for further writing.') + except RuntimeError as e: + logger.warning(e.args[0]) self._writers.remove(writer) writer.close() diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index 2bc605797fc..48948079ab8 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -75,14 +75,17 @@ class SummaryRecord: Args: log_dir (str): The log_dir is a directory location to save the summary. - queue_max_size (int): The capacity of event queue.(reserved). Default: 0. - flush_time (int): Frequency to flush the summaries to disk, the unit is second. Default: 120. + queue_max_size (int): Deprecated. The capacity of event queue.(reserved). Default: 0. + flush_time (int): Deprecated. Frequency to flush the summaries to disk, the unit is second. Default: 120. file_prefix (str): The prefix of file. Default: "events". file_suffix (str): The suffix of file. Default: "_MS". network (Cell): Obtain a pipeline through network for saving graph summary. Default: None. + max_file_size (Optional[int]): The maximum size in bytes each file can be written to the disk. \ + Unlimited by default. Raises: - TypeError: If `queue_max_size` and `flush_time` is not int, or `file_prefix` and `file_suffix` is not str. + TypeError: If `max_file_size`, `queue_max_size` or `flush_time` is not int, \ + or `file_prefix` and `file_suffix` is not str. RuntimeError: If the log_dir can not be resolved to a canonicalized absolute pathname. Examples: @@ -103,7 +106,8 @@ class SummaryRecord: flush_time=120, file_prefix="events", file_suffix="_MS", - network=None): + network=None, + max_file_size=None): self._closed, self._event_writer = False, None self._mode, self._data_pool = 'train', _dictlist() @@ -113,11 +117,18 @@ class SummaryRecord: self.log_path = _make_directory(log_dir) + if not isinstance(max_file_size, (int, type(None))): + raise TypeError("The 'max_file_size' should be int type.") + if not isinstance(queue_max_size, int) or not isinstance(flush_time, int): raise TypeError("`queue_max_size` and `flush_time` should be int") if not isinstance(file_prefix, str) or not isinstance(file_suffix, str): raise TypeError("`file_prefix` and `file_suffix` should be str.") + if max_file_size is not None and max_file_size < 0: + logger.warning("The 'max_file_size' should be greater than 0.") + max_file_size = None + self.queue_max_size = queue_max_size if queue_max_size < 0: # 0 is not limit @@ -142,6 +153,7 @@ class SummaryRecord: raise RuntimeError(ex) self._event_writer = WriterPool(log_dir, + max_file_size, summary=self.full_file_name, lineage=get_event_file_name('events', '_lineage')) atexit.register(self.close) @@ -152,7 +164,7 @@ class SummaryRecord: raise ValueError('SummaryRecord has been closed.') return self - def __exit__(self, extype, exvalue, traceback): + def __exit__(self, *err): """Exit the context manager.""" self.close() @@ -229,13 +241,15 @@ class SummaryRecord: else: raise ValueError(f'No such plugin of {repr(plugin)}') - def record(self, step, train_network=None): + def record(self, step, train_network=None, plugin_filter=None): """ Record the summary. Args: step (int): Represents training step number. train_network (Cell): The network that called the callback. + plugin_filter (Optional[Callable[[str], bool]]): The filter function, \ + which is used to filter out plugins from being written by return False. Returns: bool, whether the record process is successful or not. @@ -266,7 +280,14 @@ class SummaryRecord: if self._mode == 'train': self._add_summary_tensor_data() - self._event_writer.write(self._consume_data_pool(step)) + if not plugin_filter: + self._event_writer.write(self._consume_data_pool(step)) + else: + filtered = {} + for plugin, datalist in self._consume_data_pool(step).items(): + if plugin_filter(plugin): + filtered[plugin] = datalist + self._event_writer.write(filtered) return True def _add_summary_tensor_data(self):