From 336fca14bc488ff6e44e05f2c4ac06bcc77e4d5a Mon Sep 17 00:00:00 2001 From: ougongchang Date: Tue, 21 Jul 2020 22:08:31 +0800 Subject: [PATCH] Fix collecting bert network name faild in MindInsight lineage. 1. collect the origin network in model, and set it to cb_params 2. collect the origin network name in SummaryCollector 3. Update the SummaryCollector API Doc --- .../train/callback/_summary_collector.py | 29 ++----------------- mindspore/train/model.py | 2 ++ 2 files changed, 5 insertions(+), 26 deletions(-) diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index 6681e2d13a..c76e27d699 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -73,7 +73,8 @@ class SummaryCollector(Callback): summary_dir (str): The collected data will be persisted to this directory. If the directory does not exist, it will be created automatically. collect_freq (int): Set the frequency of data collection, it should be greater then zero, - and the unit is `step`. Default: 10. The first step will be recorded at any time. + and the unit is `step`. Default: 10. If a frequency is set, we will collect data + at (current steps % freq) == 0, and the first step will be collected at any time. It is important to note that if the data sink mode is used, the unit will become the `epoch`. It is not recommended to collect data too frequently, which can affect performance. collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None. @@ -593,7 +594,7 @@ class SummaryCollector(Callback): else: train_lineage[LineageMetadata.learning_rate] = None train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None - train_lineage[LineageMetadata.train_network] = self._get_backbone(cb_params.train_network) + train_lineage[LineageMetadata.train_network] = type(cb_params.network).__name__ loss_fn = self._get_loss_fn(cb_params) train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None @@ -750,30 +751,6 @@ class SummaryCollector(Callback): return ckpt_file_path - @staticmethod - def _get_backbone(network): - """ - Get the name of backbone network. - - Args: - network (Cell): The train network. - - Returns: - Union[str, None], If parse success, will return the name of the backbone network, else return None. - """ - backbone_name = None - backbone_key = '_backbone' - - for _, cell in network.cells_and_names(): - if hasattr(cell, backbone_key): - backbone_network = getattr(cell, backbone_key) - backbone_name = type(backbone_network).__name__ - - if backbone_name is None and network is not None: - backbone_name = type(network).__name__ - - return backbone_name - @staticmethod def _get_loss_fn(cb_params): """ diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 54128c66ce..0726108028 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -355,6 +355,7 @@ class Model: cb_params.train_dataset = train_dataset cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.train_dataset_element = None + cb_params.network = self._network ms_role = os.getenv("MS_ROLE") if ms_role in ("MS_PSERVER", "MS_SCHED"): epoch = 1 @@ -660,6 +661,7 @@ class Model: cb_params.mode = "eval" cb_params.cur_step_num = 0 cb_params.list_callback = self._transform_callbacks(callbacks) + cb_params.network = self._network self._eval_network.set_train(mode=False) self._eval_network.phase = 'eval'