forked from OSSInnovation/mindspore
!3296 Fix failure to collect the BERT network name in MindInsight lineage.
Merge pull request !3296 from ougongchang/fix_lineage_bug
This commit is contained in:
commit
6ea74a3669
|
@ -73,7 +73,8 @@ class SummaryCollector(Callback):
|
||||||
summary_dir (str): The collected data will be persisted to this directory.
|
summary_dir (str): The collected data will be persisted to this directory.
|
||||||
If the directory does not exist, it will be created automatically.
|
If the directory does not exist, it will be created automatically.
|
||||||
collect_freq (int): Set the frequency of data collection, it should be greater than zero,
|
collect_freq (int): Set the frequency of data collection, it should be greater than zero,
|
||||||
and the unit is `step`. Default: 10. The first step will be recorded at any time.
|
and the unit is `step`. Default: 10. If a frequency is set, we will collect data
|
||||||
|
at (current steps % freq) == 0, and the first step will be collected at any time.
|
||||||
It is important to note that if the data sink mode is used, the unit will become the `epoch`.
|
It is important to note that if the data sink mode is used, the unit will become the `epoch`.
|
||||||
It is not recommended to collect data too frequently, which can affect performance.
|
It is not recommended to collect data too frequently, which can affect performance.
|
||||||
collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
|
collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
|
||||||
|
@ -593,7 +594,7 @@ class SummaryCollector(Callback):
|
||||||
else:
|
else:
|
||||||
train_lineage[LineageMetadata.learning_rate] = None
|
train_lineage[LineageMetadata.learning_rate] = None
|
||||||
train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None
|
train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None
|
||||||
train_lineage[LineageMetadata.train_network] = self._get_backbone(cb_params.train_network)
|
train_lineage[LineageMetadata.train_network] = type(cb_params.network).__name__
|
||||||
|
|
||||||
loss_fn = self._get_loss_fn(cb_params)
|
loss_fn = self._get_loss_fn(cb_params)
|
||||||
train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None
|
train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None
|
||||||
|
@ -750,30 +751,6 @@ class SummaryCollector(Callback):
|
||||||
|
|
||||||
return ckpt_file_path
|
return ckpt_file_path
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_backbone(network):
|
|
||||||
"""
|
|
||||||
Get the name of backbone network.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): The train network.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Union[str, None], the name of the backbone network if parsing succeeds, otherwise None.
|
|
||||||
"""
|
|
||||||
backbone_name = None
|
|
||||||
backbone_key = '_backbone'
|
|
||||||
|
|
||||||
for _, cell in network.cells_and_names():
|
|
||||||
if hasattr(cell, backbone_key):
|
|
||||||
backbone_network = getattr(cell, backbone_key)
|
|
||||||
backbone_name = type(backbone_network).__name__
|
|
||||||
|
|
||||||
if backbone_name is None and network is not None:
|
|
||||||
backbone_name = type(network).__name__
|
|
||||||
|
|
||||||
return backbone_name
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_loss_fn(cb_params):
|
def _get_loss_fn(cb_params):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -355,6 +355,7 @@ class Model:
|
||||||
cb_params.train_dataset = train_dataset
|
cb_params.train_dataset = train_dataset
|
||||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||||
cb_params.train_dataset_element = None
|
cb_params.train_dataset_element = None
|
||||||
|
cb_params.network = self._network
|
||||||
ms_role = os.getenv("MS_ROLE")
|
ms_role = os.getenv("MS_ROLE")
|
||||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||||
epoch = 1
|
epoch = 1
|
||||||
|
@ -660,6 +661,7 @@ class Model:
|
||||||
cb_params.mode = "eval"
|
cb_params.mode = "eval"
|
||||||
cb_params.cur_step_num = 0
|
cb_params.cur_step_num = 0
|
||||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||||
|
cb_params.network = self._network
|
||||||
|
|
||||||
self._eval_network.set_train(mode=False)
|
self._eval_network.set_train(mode=False)
|
||||||
self._eval_network.phase = 'eval'
|
self._eval_network.phase = 'eval'
|
||||||
|
|
Loading…
Reference in New Issue