forked from OSSInnovation/mindspore
!3296 Fix failure to collect the bert network name in MindInsight lineage.
Merge pull request !3296 from ougongchang/fix_lineage_bug
This commit is contained in:
commit
6ea74a3669
|
@ -73,7 +73,8 @@ class SummaryCollector(Callback):
|
|||
summary_dir (str): The collected data will be persisted to this directory.
|
||||
If the directory does not exist, it will be created automatically.
|
||||
collect_freq (int): Set the frequency of data collection, it should be greater than zero,
|
||||
and the unit is `step`. Default: 10. The first step will be recorded at any time.
|
||||
and the unit is `step`. Default: 10. If a frequency is set, we will collect data
|
||||
at (current steps % freq) == 0, and the first step will be collected at any time.
|
||||
It is important to note that if the data sink mode is used, the unit will become the `epoch`.
|
||||
It is not recommended to collect data too frequently, which can affect performance.
|
||||
collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
|
||||
|
@ -593,7 +594,7 @@ class SummaryCollector(Callback):
|
|||
else:
|
||||
train_lineage[LineageMetadata.learning_rate] = None
|
||||
train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None
|
||||
train_lineage[LineageMetadata.train_network] = self._get_backbone(cb_params.train_network)
|
||||
train_lineage[LineageMetadata.train_network] = type(cb_params.network).__name__
|
||||
|
||||
loss_fn = self._get_loss_fn(cb_params)
|
||||
train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None
|
||||
|
@ -750,30 +751,6 @@ class SummaryCollector(Callback):
|
|||
|
||||
return ckpt_file_path
|
||||
|
||||
@staticmethod
|
||||
def _get_backbone(network):
|
||||
"""
|
||||
Get the name of backbone network.
|
||||
|
||||
Args:
|
||||
network (Cell): The train network.
|
||||
|
||||
Returns:
|
||||
Union[str, None], the name of the backbone network if parsing succeeds, otherwise None.
|
||||
"""
|
||||
backbone_name = None
|
||||
backbone_key = '_backbone'
|
||||
|
||||
for _, cell in network.cells_and_names():
|
||||
if hasattr(cell, backbone_key):
|
||||
backbone_network = getattr(cell, backbone_key)
|
||||
backbone_name = type(backbone_network).__name__
|
||||
|
||||
if backbone_name is None and network is not None:
|
||||
backbone_name = type(network).__name__
|
||||
|
||||
return backbone_name
|
||||
|
||||
@staticmethod
|
||||
def _get_loss_fn(cb_params):
|
||||
"""
|
||||
|
|
|
@ -355,6 +355,7 @@ class Model:
|
|||
cb_params.train_dataset = train_dataset
|
||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||
cb_params.train_dataset_element = None
|
||||
cb_params.network = self._network
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
epoch = 1
|
||||
|
@ -660,6 +661,7 @@ class Model:
|
|||
cb_params.mode = "eval"
|
||||
cb_params.cur_step_num = 0
|
||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||
cb_params.network = self._network
|
||||
|
||||
self._eval_network.set_train(mode=False)
|
||||
self._eval_network.phase = 'eval'
|
||||
|
|
Loading…
Reference in New Issue