forked from mindspore-Ecosystem/mindspore
restore the ability to collect network graph
This commit is contained in:
parent
514c5d98b2
commit
879409ef97
|
@ -182,6 +182,7 @@ class SummaryCollector(Callback):
|
|||
self._custom_lineage_data = custom_lineage_data
|
||||
|
||||
self._temp_optimizer = None
|
||||
self._has_saved_train_network = False
|
||||
self._has_saved_custom_data = False
|
||||
self._is_parse_loss_success = True
|
||||
self._first_step = True
|
||||
|
@ -215,7 +216,7 @@ class SummaryCollector(Callback):
|
|||
@staticmethod
|
||||
def _check_positive(name, value, allow_none=False):
|
||||
"""Check if the value to be int type and positive."""
|
||||
if allow_none:
|
||||
if allow_none and value is None:
|
||||
return
|
||||
check_value_type(name, value, int)
|
||||
if value <= 0:
|
||||
|
@ -294,8 +295,9 @@ class SummaryCollector(Callback):
|
|||
|
||||
self._collect_dataset_graph(cb_params)
|
||||
if self._collect_tensor_freq is None:
|
||||
default_tensor_summary_limit = 50
|
||||
total_step = cb_params.epoch_num * cb_params.batch_num
|
||||
self._collect_tensor_freq = max(self._collect_freq, total_step // 50)
|
||||
self._collect_tensor_freq = max(self._collect_freq, total_step // default_tensor_summary_limit)
|
||||
|
||||
if self._custom_lineage_data and not self._has_saved_custom_data:
|
||||
packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
|
||||
|
@ -309,6 +311,8 @@ class SummaryCollector(Callback):
|
|||
cb_params = run_context.original_args()
|
||||
if cb_params.mode != ModeEnum.TRAIN.value:
|
||||
return
|
||||
if not self._has_saved_train_network:
|
||||
self._collect_graphs(cb_params)
|
||||
if self._first_step:
|
||||
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
|
||||
self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num
|
||||
|
@ -424,6 +428,7 @@ class SummaryCollector(Callback):
|
|||
if graph_proto is None:
|
||||
return
|
||||
|
||||
self._has_saved_train_network = True
|
||||
self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)
|
||||
|
||||
def _collect_metric(self, cb_params):
|
||||
|
|
Loading…
Reference in New Issue