!3296 Fix collecting bert network name faild in MindInsight lineage.

Merge pull request !3296 from ougongchang/fix_lineage_bug
This commit is contained in:
mindspore-ci-bot 2020-07-22 09:40:38 +08:00 committed by Gitee
commit 6ea74a3669
2 changed files with 5 additions and 26 deletions

View File

@ -73,7 +73,8 @@ class SummaryCollector(Callback):
summary_dir (str): The collected data will be persisted to this directory. summary_dir (str): The collected data will be persisted to this directory.
If the directory does not exist, it will be created automatically. If the directory does not exist, it will be created automatically.
collect_freq (int): Set the frequency of data collection, it should be greater then zero, collect_freq (int): Set the frequency of data collection, it should be greater then zero,
and the unit is `step`. Default: 10. The first step will be recorded at any time. and the unit is `step`. Default: 10. If a frequency is set, we will collect data
at (current steps % freq) == 0, and the first step will be collected at any time.
It is important to note that if the data sink mode is used, the unit will become the `epoch`. It is important to note that if the data sink mode is used, the unit will become the `epoch`.
It is not recommended to collect data too frequently, which can affect performance. It is not recommended to collect data too frequently, which can affect performance.
collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None. collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
@ -593,7 +594,7 @@ class SummaryCollector(Callback):
else: else:
train_lineage[LineageMetadata.learning_rate] = None train_lineage[LineageMetadata.learning_rate] = None
train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None
train_lineage[LineageMetadata.train_network] = self._get_backbone(cb_params.train_network) train_lineage[LineageMetadata.train_network] = type(cb_params.network).__name__
loss_fn = self._get_loss_fn(cb_params) loss_fn = self._get_loss_fn(cb_params)
train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None
@ -750,30 +751,6 @@ class SummaryCollector(Callback):
return ckpt_file_path return ckpt_file_path
@staticmethod
def _get_backbone(network):
"""
Get the name of backbone network.
Args:
network (Cell): The train network.
Returns:
Union[str, None], If parse success, will return the name of the backbone network, else return None.
"""
backbone_name = None
backbone_key = '_backbone'
for _, cell in network.cells_and_names():
if hasattr(cell, backbone_key):
backbone_network = getattr(cell, backbone_key)
backbone_name = type(backbone_network).__name__
if backbone_name is None and network is not None:
backbone_name = type(network).__name__
return backbone_name
@staticmethod @staticmethod
def _get_loss_fn(cb_params): def _get_loss_fn(cb_params):
""" """

View File

@ -355,6 +355,7 @@ class Model:
cb_params.train_dataset = train_dataset cb_params.train_dataset = train_dataset
cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.train_dataset_element = None cb_params.train_dataset_element = None
cb_params.network = self._network
ms_role = os.getenv("MS_ROLE") ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"): if ms_role in ("MS_PSERVER", "MS_SCHED"):
epoch = 1 epoch = 1
@ -660,6 +661,7 @@ class Model:
cb_params.mode = "eval" cb_params.mode = "eval"
cb_params.cur_step_num = 0 cb_params.cur_step_num = 0
cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.network = self._network
self._eval_network.set_train(mode=False) self._eval_network.set_train(mode=False)
self._eval_network.phase = 'eval' self._eval_network.phase = 'eval'