forked from mindspore-Ecosystem/mindspore
!2147 Add a callback named SummaryCollector and delete SummaryStep callback
Merge pull request !2147 from ougongchang/master
commit 9514b52a23
|
@ -14,7 +14,10 @@
|
|||
# ============================================================================
|
||||
"""Train utility."""
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
|
||||
from mindspore.common import dtype as mstype
|
||||
|
@ -213,6 +216,7 @@ def _check_to_numpy(plugin, tensor):
|
|||
raise ValueError('The tensor should not be empty.')
|
||||
return np_value
|
||||
|
||||
|
||||
def _check_lineage_value(plugin, value):
|
||||
"""Check the lineage value."""
|
||||
def raises(plugin, prototype):
|
||||
|
@ -229,3 +233,20 @@ def _check_lineage_value(plugin, value):
|
|||
|
||||
if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo):
|
||||
raises(plugin, UserDefinedInfo)
|
||||
|
||||
|
||||
def check_value_type(arg_name, arg_value, valid_types):
|
||||
"""Checks whether a value is instance of some types."""
|
||||
valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,)
|
||||
is_valid = True
|
||||
|
||||
# bool is a subclass of int, so a bool value needs an extra check
|
||||
if isinstance(arg_value, int) and isinstance(arg_value, bool) and bool not in valid_types:
|
||||
is_valid = False
|
||||
|
||||
if not isinstance(arg_value, valid_types):
|
||||
is_valid = False
|
||||
|
||||
if not is_valid:
|
||||
raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
|
||||
f'but got {type(arg_value).__name__}.')
|
||||
|
|
|
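A minimal usage sketch of check_value_type (editorial, not part of the diff); the import path is the one used by _summary_collector.py below, and it shows why the extra bool check matters, since bool is a subclass of int:

from mindspore.train._utils import check_value_type

check_value_type('collect_freq', 10, int)             # passes
check_value_type('keep_default_action', True, bool)   # passes
try:
    check_value_type('collect_freq', True, int)       # rejected: bool is not in valid_types
except TypeError as err:
    print(err)  # For `collect_freq` the type should be a valid type of ['int'], but got bool.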
@ -22,7 +22,8 @@ from ._checkpoint import CheckpointConfig
|
|||
from ._checkpoint import CheckpointManager as _CheckpointManager
|
||||
from ._checkpoint import ModelCheckpoint
|
||||
from ._loss_monitor import LossMonitor
|
||||
from ._summary_step import SummaryStep
|
||||
from ._time_monitor import TimeMonitor
|
||||
from ._summary_collector import SummaryCollector
|
||||
|
||||
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]
|
||||
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
|
||||
"SummaryCollector", "CheckpointConfig", "RunContext"]
|
||||
|
|
|
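With SummaryStep dropped from __all__, a hypothetical migration sketch (model and dataset are assumed to exist) replaces the old SummaryStep/SummaryRecord pairing with the new callback:

from mindspore.train.callback import SummaryCollector

# Previously: with SummaryRecord('./summary_dir') as record:
#                 model.train(2, dataset, callbacks=SummaryStep(record, flush_step=1))
# Now SummaryCollector owns its SummaryRecord and also collects loss, graphs and lineage.
summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_freq=1)
model.train(2, dataset, callbacks=[summary_collector])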
@ -0,0 +1,128 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define dataset graph related operations."""
|
||||
import json
|
||||
from importlib import import_module
|
||||
|
||||
from mindspore.train import lineage_pb2
|
||||
|
||||
|
||||
class DatasetGraph:
|
||||
"""Handle the data graph and packages it into binary data."""
|
||||
def package_dataset_graph(self, dataset):
|
||||
"""
|
||||
Package the dataset graph into binary data.
|
||||
|
||||
Args:
|
||||
dataset (MindData): refer to MindDataset
|
||||
|
||||
Returns:
|
||||
DatasetGraph, an object of lineage_pb2.DatasetGraph.
|
||||
"""
|
||||
dataset_package = import_module('mindspore.dataset')
|
||||
dataset_dict = dataset_package.serialize(dataset)
|
||||
json_str = json.dumps(dataset_dict, indent=2)
|
||||
dataset_dict = json.loads(json_str)
|
||||
dataset_graph_proto = lineage_pb2.DatasetGraph()
|
||||
if "children" in dataset_dict:
|
||||
children = dataset_dict.pop("children")
|
||||
if children:
|
||||
self._package_children(children=children, message=dataset_graph_proto)
|
||||
self._package_current_dataset(operation=dataset_dict, message=dataset_graph_proto)
|
||||
return dataset_graph_proto
|
||||
|
||||
def _package_children(self, children, message):
|
||||
"""
|
||||
Package children in dataset operation.
|
||||
|
||||
Args:
|
||||
children (list[dict]): Child operations.
|
||||
message (DatasetGraph): Children proto message.
|
||||
"""
|
||||
for child in children:
|
||||
if child:
|
||||
child_graph_message = getattr(message, "children").add()
|
||||
grandson = child.pop("children")
|
||||
if grandson:
|
||||
self._package_children(children=grandson, message=child_graph_message)
|
||||
# package other parameters
|
||||
self._package_current_dataset(operation=child, message=child_graph_message)
|
||||
|
||||
def _package_current_dataset(self, operation, message):
|
||||
"""
|
||||
Package operation parameters in event message.
|
||||
|
||||
Args:
|
||||
operation (dict): Operation dict.
|
||||
message (Operation): Operation proto message.
|
||||
"""
|
||||
for key, value in operation.items():
|
||||
if value and key == "operations":
|
||||
for operator in value:
|
||||
self._package_enhancement_operation(
|
||||
operator,
|
||||
message.operations.add()
|
||||
)
|
||||
elif value and key == "sampler":
|
||||
self._package_enhancement_operation(
|
||||
value,
|
||||
message.sampler
|
||||
)
|
||||
else:
|
||||
self._package_parameter(key, value, message.parameter)
|
||||
|
||||
def _package_enhancement_operation(self, operation, message):
|
||||
"""
|
||||
Package enhancement operation in MapDataset.
|
||||
|
||||
Args:
|
||||
operation (dict): Enhancement operation.
|
||||
message (Operation): Enhancement operation proto message.
|
||||
"""
|
||||
for key, value in operation.items():
|
||||
if isinstance(value, list):
|
||||
if all(isinstance(ele, int) for ele in value):
|
||||
message.size.extend(value)
|
||||
else:
|
||||
message.weights.extend(value)
|
||||
else:
|
||||
self._package_parameter(key, value, message.operationParam)
|
||||
|
||||
@staticmethod
|
||||
def _package_parameter(key, value, message):
|
||||
"""
|
||||
Package parameters in operation.
|
||||
|
||||
Args:
|
||||
key (str): Operation name.
|
||||
value (Union[str, bool, int, float, list, None]): Operation args.
|
||||
message (OperationParameter): Operation proto message.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
message.mapStr[key] = value
|
||||
elif isinstance(value, bool):
|
||||
message.mapBool[key] = value
|
||||
elif isinstance(value, int):
|
||||
message.mapInt[key] = value
|
||||
elif isinstance(value, float):
|
||||
message.mapDouble[key] = value
|
||||
elif isinstance(value, list) and key != "operations":
|
||||
if value:
|
||||
replace_value_list = list(map(lambda x: "" if x is None else x, value))
|
||||
message.mapStrList[key].strValue.extend(replace_value_list)
|
||||
elif value is None:
|
||||
message.mapStr[key] = "None"
|
||||
else:
|
||||
raise ValueError(f"Parameter {key} is not supported in event package.")
|
|
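A minimal editorial sketch of driving DatasetGraph on its own; train_dataset is assumed to be any mindspore.dataset pipeline built elsewhere:

from mindspore.train.callback._dataset_graph import DatasetGraph

def dump_dataset_graph(train_dataset):
    """Package the dataset pipeline into a lineage_pb2.DatasetGraph message."""
    graph_proto = DatasetGraph().package_dataset_graph(train_dataset)
    # SummaryCollector passes this message to SummaryRecord.add_value('dataset_graph', ...).
    return graph_proto.SerializeToString()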
@ -0,0 +1,786 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Summary collector callback."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import log as logger
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from mindspore.train.summary.enum import PluginEnum, ModeEnum
|
||||
from mindspore.train.callback import Callback, ModelCheckpoint
|
||||
from mindspore.train import lineage_pb2
|
||||
from mindspore.train.callback._dataset_graph import DatasetGraph
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.train._utils import check_value_type
|
||||
|
||||
|
||||
class LineageMetadata:
|
||||
"""Initialize parameters used in model lineage management."""
|
||||
train_dataset_path = 'train_dataset_path'
|
||||
valid_dataset_path = 'valid_dataset_path'
|
||||
train_network = 'train_network'
|
||||
loss_function = 'loss_function'
|
||||
loss = 'loss'
|
||||
optimizer = 'optimizer'
|
||||
learning_rate = 'learning_rate'
|
||||
epoch = 'epoch'
|
||||
step_num = 'step_num'
|
||||
parallel_mode = 'parallel_mode'
|
||||
device_num = 'device_num'
|
||||
batch_size = 'batch_size'
|
||||
model_path = 'model_path'
|
||||
model_ckpt = 'model_ckpt'
|
||||
model_size = 'model_size'
|
||||
metrics = 'metrics'
|
||||
train_dataset_size = 'train_dataset_size'
|
||||
valid_dataset_size = 'valid_dataset_size'
|
||||
|
||||
|
||||
class SummaryCollector(Callback):
|
||||
"""
|
||||
SummaryCollector can help you to collect some common information.
|
||||
|
||||
It can help you to collect loss, learning rate, computational graph and so on.
|
||||
SummaryCollector also persists data collected by the summary operator into a summary file.
|
||||
|
||||
Note:
|
||||
1. Multiple SummaryCollector instances in callback list are not allowed.
|
||||
2. Not all information is collected at the training phase or at the eval phase.
|
||||
3. SummaryCollector always records the data collected by the summary operator.
|
||||
|
||||
Args:
|
||||
summary_dir (str): The collected data will be persisted to this directory.
|
||||
If the directory does not exist, it will be created automatically.
|
||||
collect_freq (int): Set the frequency of data collection, it should be greater than zero,
|
||||
and the unit is `step`. Default: 10.
|
||||
It is important to note that if the data sink mode is used, the unit will become the `epoch`.
|
||||
It is not recommended to collect data too frequently, which can affect performance.
|
||||
collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
|
||||
If set to None, all data is collected as the default behavior.
|
||||
If you want to customize the data collected, you can do so with a dictionary.
|
||||
For example, you can set {'collect_metric': False} to disable metric collection.
|
||||
The data that can be controlled is shown below.
|
||||
|
||||
- collect_metric: Whether to collect training metrics, currently only loss is collected.
|
||||
Optional: True/False. Default: True.
|
||||
- collect_graph: Whether to collect computational graph, currently only
|
||||
training computational graph is collected. Optional: True/False. Default: True.
|
||||
- collect_train_lineage: Whether to collect lineage data for the training phase,
|
||||
this field will be displayed on the lineage page of MindInsight. Optional: True/False. Default: True.
|
||||
- collect_eval_lineage: Whether to collect lineage data for the eval phase,
|
||||
this field will be displayed on the lineage page of MindInsight. Optional: True/False. Default: True.
|
||||
- collect_input_data: Whether to collect dataset for each training. Currently only image data is supported.
|
||||
Optional: True/False. Default: True.
|
||||
- collect_dataset_graph: Whether to collect dataset graph for the training phase.
|
||||
Optional: True/False. Default: True.
|
||||
- histogram_regular: Collect weight and bias for parameter distribution page display in MindInsight.
|
||||
This field allows regular strings to control which parameters to collect.
|
||||
Default: None, it means only the first five parameters are collected.
|
||||
It is not recommended to collect too many parameters at once, as it can affect performance.
|
||||
Note that if you collect too many parameters and run out of memory, the training will fail.
|
||||
keep_default_action (bool): This field affects the collection behavior of the 'collect_specified_data' field.
|
||||
Optional: True/False, Default: True.
|
||||
True: means that after specified data is set, non-specified data is collected as the default behavior.
|
||||
False: means that after specified data is set, only the specified data is collected,
|
||||
and the others are not collected.
|
||||
custom_lineage_data (Union[dict, None]): Allows you to customize the data and present it on the MindInsight
|
||||
lineage page. In the custom data, the key type supports str, and the value type supports str/int/float.
|
||||
Default: None, it means there is no custom data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the parameter value is not expected.
|
||||
TypeError: If the parameter type is not expected.
|
||||
RuntimeError: If an error occurs during data collection.
|
||||
|
||||
Examples:
|
||||
>>> # Simple usage:
|
||||
>>> summary_collector = SummaryCollector(summary_dir='./summary_dir')
|
||||
>>> model.train(epoch, dataset, callbacks=summary_collector)
|
||||
>>>
|
||||
>>> # Do not collect metrics and only collect the first layer parameters; others are collected by default
|
||||
>>> specified={'collect_metric': False, 'histogram_regular': '^conv1.*'}
|
||||
>>> summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_specified_data=specified)
|
||||
>>> model.train(epoch, dataset, callbacks=summary_collector)
|
||||
>>>
|
||||
>>> # Only collect metrics and custom lineage data, and record data collected by the summary operator;
|
||||
>>> # others are not collected
|
||||
>>> specified = {'collect_metric': True}
>>> summary_collector = SummaryCollector('./summary_dir',
>>> collect_specified_data=specified,
>>> keep_default_action=False,
>>> custom_lineage_data={'version': 'resnet50_v1'})
|
||||
>>> model.train(epoch, dataset, callbacks=summary_collector)
|
||||
"""
|
||||
|
||||
_DEFAULT_SPECIFIED_DATA = {
|
||||
'collect_metric': True,
|
||||
'collect_graph': True,
|
||||
'collect_train_lineage': True,
|
||||
'collect_eval_lineage': True,
|
||||
'collect_input_data': True,
|
||||
'collect_dataset_graph': True,
|
||||
'histogram_regular': None
|
||||
}
|
||||
|
||||
# _OPTIMIZER_FAILED means finding the optimizer failed, so data about the optimizer will not be collected.
|
||||
_OPTIMIZER_FAILED = 'Failed'
|
||||
|
||||
def __init__(self, summary_dir, collect_freq=10, collect_specified_data=None,
|
||||
keep_default_action=True, custom_lineage_data=None):
|
||||
super(SummaryCollector, self).__init__()
|
||||
|
||||
self._summary_dir = self._process_summary_dir(summary_dir)
|
||||
self._record = None
|
||||
|
||||
self._check_collect_freq(collect_freq)
|
||||
self._collect_freq = collect_freq
|
||||
|
||||
self._check_action(keep_default_action)
|
||||
|
||||
self._collect_specified_data = self._process_specified_data(collect_specified_data, keep_default_action)
|
||||
logger.info(f"For `collect_specified_data` the value after processing is: {self._collect_specified_data}.")
|
||||
|
||||
self._check_custom_lineage_data(custom_lineage_data)
|
||||
self._custom_lineage_data = custom_lineage_data
|
||||
|
||||
self._optimizer = None
|
||||
self._has_saved_train_network = False
|
||||
self._has_saved_custom_data = False
|
||||
self._is_parse_loss_success = True
|
||||
|
||||
def __enter__(self):
|
||||
self._record = SummaryRecord(log_dir=self._summary_dir)
|
||||
return self
|
||||
|
||||
def __exit__(self, *err):
|
||||
self._record.close()
|
||||
|
||||
@staticmethod
|
||||
def _process_summary_dir(summary_dir):
|
||||
"""Check the summary dir, and create a new directory if it not exists."""
|
||||
check_value_type('summary_dir', summary_dir, str)
|
||||
summary_dir = summary_dir.strip()
|
||||
if not summary_dir:
|
||||
raise ValueError('For `summary_dir` the value should be a valid string of path, but got empty string.')
|
||||
|
||||
summary_dir = os.path.realpath(summary_dir)
|
||||
if not os.path.exists(summary_dir):
|
||||
os.makedirs(summary_dir, exist_ok=True)
|
||||
else:
|
||||
if not os.path.isdir(summary_dir):
|
||||
raise NotADirectoryError('For `summary_dir` it should be a directory path.')
|
||||
|
||||
return summary_dir
|
||||
|
||||
@staticmethod
|
||||
def _check_collect_freq(freq):
|
||||
"""Check collect freq type and value."""
|
||||
check_value_type('collect_freq', freq, int)
|
||||
if freq <= 0:
|
||||
raise ValueError(f'For `collect_freq` the value should be greater than 0, but got `{freq}`.')
|
||||
|
||||
@staticmethod
|
||||
def _check_custom_lineage_data(custom_lineage_data):
|
||||
"""
|
||||
Check user custom lineage data.
|
||||
|
||||
Args:
|
||||
custom_lineage_data (dict): The user custom defined data.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of parameters is invalid.
|
||||
"""
|
||||
if custom_lineage_data is None:
|
||||
return
|
||||
|
||||
check_value_type('custom_lineage_data', custom_lineage_data, [dict, type(None)])
|
||||
for key, value in custom_lineage_data.items():
|
||||
check_value_type(f'custom_lineage_data -> {key}', key, str)
|
||||
check_value_type(f'the value of custom_lineage_data -> {key}', value, (int, str, float))
|
||||
|
||||
@staticmethod
|
||||
def _check_action(action):
|
||||
"""Check action type."""
|
||||
check_value_type('keep_default_action', action, bool)
|
||||
|
||||
def _process_specified_data(self, specified_data, action):
|
||||
"""Check specified data type and value."""
|
||||
if specified_data is None:
|
||||
if action:
|
||||
return self._DEFAULT_SPECIFIED_DATA
|
||||
return None
|
||||
|
||||
check_value_type('collect_specified_data', specified_data, [dict, type(None)])
|
||||
|
||||
for param_name in specified_data:
|
||||
check_value_type(param_name, param_name, [str])
|
||||
|
||||
unexpected_params = set(specified_data) - set(self._DEFAULT_SPECIFIED_DATA)
|
||||
if unexpected_params:
|
||||
raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported.')
|
||||
|
||||
if 'histogram_regular' in specified_data:
|
||||
check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None)))
|
||||
|
||||
bool_items = set(self._DEFAULT_SPECIFIED_DATA) - {'histogram_regular'}
|
||||
for item in bool_items:
|
||||
if item in specified_data:
|
||||
check_value_type(item, specified_data.get(item), bool)
|
||||
|
||||
if action:
|
||||
result = dict(self._DEFAULT_SPECIFIED_DATA)
result.update(specified_data)
|
||||
else:
|
||||
result = specified_data
|
||||
return result
|
||||
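# Editor's illustration (not part of this commit): with keep_default_action=True the
# user-specified keys are merged over _DEFAULT_SPECIFIED_DATA, e.g.
# _process_specified_data({'collect_metric': False}, True) keeps all default keys but
# sets 'collect_metric' to False; with keep_default_action=False only the
# user-specified keys ({'collect_metric': False}) are kept.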
|
||||
def begin(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
self._check_callbacks(cb_params)
|
||||
|
||||
if cb_params.mode not in ModeEnum.to_list():
|
||||
raise ValueError('Only support `train` (model.train) and `eval` (model.eval) mode, '
|
||||
f'but got `{cb_params.mode}` mode.')
|
||||
|
||||
self._record.set_mode(cb_params.mode)
|
||||
if cb_params.mode == ModeEnum.TRAIN.value:
|
||||
# Note: if model.init is not executed then the computed graph will not be obtained here
|
||||
# The purpose of recording the graph here is that, even if collect_freq is set to a large value,
# the graph can still be seen soon after compilation.
|
||||
self._collect_graphs(cb_params)
|
||||
|
||||
self._collect_dataset_graph(cb_params)
|
||||
|
||||
if self._custom_lineage_data and not self._has_saved_custom_data:
|
||||
packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
|
||||
self._record.add_value('custom_lineage_data', 'custom_lineage_data', packaged_custom_data)
|
||||
self._has_saved_custom_data = True
|
||||
|
||||
# There's nothing special about setting step to 0 here, just to satisfy the interface call
|
||||
self._record.record(step=0)
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
|
||||
if cb_params.mode == ModeEnum.TRAIN.value:
|
||||
if cb_params.cur_step_num % self._collect_freq:
|
||||
return
|
||||
|
||||
if not self._has_saved_train_network:
|
||||
self._collect_graphs(cb_params)
|
||||
|
||||
self._collect_input_data(cb_params)
|
||||
self._collect_metric(cb_params)
|
||||
self._collect_histogram(cb_params)
|
||||
|
||||
self._record.record(cb_params.cur_step_num)
|
||||
|
||||
def end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
if cb_params.mode == ModeEnum.TRAIN.value:
|
||||
self._collect_train_lineage(cb_params)
|
||||
else:
|
||||
self._collect_eval_lineage(cb_params)
|
||||
|
||||
# There's nothing special about setting step to 0 here, just to satisfy the interface call
|
||||
self._record.record(step=0)
|
||||
|
||||
def _check_callbacks(self, cb_params):
|
||||
"""Check there if there are duplicate instances of SummaryCollector."""
|
||||
callbacks = cb_params.list_callback
|
||||
|
||||
is_find = False
|
||||
for callback in callbacks:
|
||||
if type(callback).__name__ == self.__class__.__name__:
|
||||
if not is_find:
|
||||
is_find = True
|
||||
continue
|
||||
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
|
||||
f"but expected only one {self.__class__.__name__} instance.")
|
||||
|
||||
@staticmethod
|
||||
def _package_custom_lineage_data(custom_lineage_data):
|
||||
"""
|
||||
Package user-defined lineage data into binary data.
|
||||
|
||||
Args:
|
||||
custom_lineage_data (dict): User custom lineage data.
|
||||
|
||||
Returns:
|
||||
UserDefinedInfo, an object of lineage_pb2.UserDefinedInfo.
|
||||
"""
|
||||
user_defined_info = lineage_pb2.UserDefinedInfo()
|
||||
for key, value in custom_lineage_data.items():
|
||||
if isinstance(value, int):
|
||||
attr_name = "map_int32"
|
||||
elif isinstance(value, float):
|
||||
attr_name = "map_double"
|
||||
else:
|
||||
attr_name = "map_str"
|
||||
|
||||
user_info = user_defined_info.user_info.add()
|
||||
getattr(user_info, attr_name)[key] = value
|
||||
|
||||
return user_defined_info
|
||||
|
||||
def _collect_input_data(self, cb_params):
|
||||
"""Only support to collect image data."""
|
||||
if not self._collect_specified_data.get('collect_input_data'):
|
||||
return
|
||||
|
||||
input_data = getattr(cb_params, 'train_dataset_element', None)
|
||||
if input_data is None:
|
||||
self._collect_specified_data['collect_input_data'] = False
|
||||
logger.info("There is not a `train_dataset_element` in cb_params.")
|
||||
return
|
||||
|
||||
if isinstance(input_data, (list, tuple)):
|
||||
input_data = input_data[0]
|
||||
try:
|
||||
self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data)
|
||||
except ValueError:
|
||||
self._collect_specified_data['collect_input_data'] = False
|
||||
return
|
||||
|
||||
def _collect_dataset_graph(self, cb_params):
|
||||
"""Only collect train dataset graph."""
|
||||
if not self._collect_specified_data.get('collect_dataset_graph'):
|
||||
return
|
||||
|
||||
# After analysis, we think that the validation dataset graph and the training dataset graph
|
||||
# should be consistent under normal scenarios, so only the training dataset graph is collected.
|
||||
if cb_params.mode == ModeEnum.TRAIN.value:
|
||||
train_dataset = cb_params.train_dataset
|
||||
dataset_graph = DatasetGraph()
|
||||
graph_bytes = dataset_graph.package_dataset_graph(train_dataset)
|
||||
self._record.add_value('dataset_graph', 'train_dataset', graph_bytes)
|
||||
|
||||
def _collect_graphs(self, cb_params):
|
||||
"""Collect the graph of train network and eval network."""
|
||||
if not self._collect_specified_data.get('collect_graph'):
|
||||
return
|
||||
|
||||
network = cb_params.train_network if cb_params.mode == ModeEnum.TRAIN.value else cb_params.eval_network
|
||||
graph_proto = network.get_func_graph_proto()
|
||||
if graph_proto is None:
|
||||
return
|
||||
|
||||
self._has_saved_train_network = True
|
||||
self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)
|
||||
|
||||
def _collect_metric(self, cb_params):
|
||||
"""Collect metric, currently only collection Loss is supported."""
|
||||
if not self._collect_specified_data.get('collect_metric'):
|
||||
return
|
||||
|
||||
loss = self._get_loss(cb_params)
|
||||
if loss is None:
|
||||
return
|
||||
self._record.add_value(PluginEnum.SCALAR.value, 'loss/auto', loss)
|
||||
|
||||
def _get_loss(self, cb_params):
|
||||
"""
|
||||
Get loss from the network output.
|
||||
|
||||
Args:
|
||||
cb_params (_InternalCallbackParam): Callback parameters.
|
||||
|
||||
Returns:
|
||||
Union[Tensor, None], if the loss is parsed successfully, return a Tensor with shape [1], else return None.
|
||||
"""
|
||||
if not self._is_parse_loss_success:
|
||||
# If parsing has failed before, avoid repeating it
|
||||
return None
|
||||
|
||||
output = cb_params.net_outputs
|
||||
if output is None:
|
||||
logger.warning("Can not find any output by this network.")
|
||||
self._is_parse_loss_success = False
|
||||
return None
|
||||
|
||||
if isinstance(output, (int, float)):
|
||||
loss = output
|
||||
elif isinstance(output, (list, tuple)):
|
||||
# If the output is a list, since the default network returns loss first,
|
||||
# we assume that the first one is loss.
|
||||
loss = output[0]
|
||||
elif isinstance(output, Tensor) and (not output.shape or output.shape == [1]):
|
||||
loss_numpy = output.asnumpy()
|
||||
loss = float(np.atleast_1d(loss_numpy)[0])
|
||||
else:
|
||||
logger.warning("The output type could not be identified, so no loss was recorded in SummaryCollector.")
|
||||
self._is_parse_loss_success = False
|
||||
return None
|
||||
|
||||
if not isinstance(loss, Tensor):
|
||||
loss = Tensor(loss)
|
||||
|
||||
return loss
|
||||
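# Editor's illustration (not part of this commit): _get_loss normalizes the common
# shapes of cb_params.net_outputs. A python float, the first element of a
# (loss, ...) tuple, or a scalar Tensor all end up wrapped in a Tensor; any other
# output type logs a warning and disables further loss parsing for this run.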
|
||||
def _get_optimizer(self, cb_params):
|
||||
"""
|
||||
Get optimizer from the cb_params or parse from the network.
|
||||
|
||||
Args:
|
||||
cb_params (_InternalCallbackParam): Callback parameters.
|
||||
|
||||
Returns:
|
||||
Union[Optimizer, None], if the optimizer is parsed successfully, return an optimizer, else return None.
|
||||
"""
|
||||
if self._optimizer == self._OPTIMIZER_FAILED:
|
||||
return None
|
||||
|
||||
if self._optimizer is not None:
|
||||
return self._optimizer
|
||||
|
||||
optimizer = cb_params.optimizer
|
||||
if optimizer is None:
|
||||
network = cb_params.train_network if cb_params.mode == 'train' else cb_params.eval_network
|
||||
optimizer = self._parse_optimizer_by_network(network)
|
||||
|
||||
if optimizer is None or not isinstance(optimizer, Optimizer):
|
||||
logger.warning("Can not find optimizer in network, or the optimizer does not inherit Mindpore's optimizer, "
|
||||
"so we will not collect data about optimizer in SummaryCollector.")
|
||||
optimizer = self._OPTIMIZER_FAILED
|
||||
|
||||
return optimizer
|
||||
|
||||
@staticmethod
|
||||
def _parse_optimizer_by_network(network):
|
||||
"""Parse optimizer from network, if parse success will return a optimizer, else return None."""
|
||||
optimizer = None
|
||||
for _, cell in network.cells_and_names():
|
||||
try:
|
||||
optimizer = getattr(cell, 'optimizer')
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
continue
|
||||
|
||||
# Optimizer found successfully
|
||||
break
|
||||
|
||||
return optimizer
|
||||
|
||||
def _collect_histogram(self, cb_params):
|
||||
"""Collect histogram data, contain the parameter weight and bias."""
|
||||
# Note: if there is not a key named `histogram_regular` in `self._collect_specified_data`,
|
||||
# it means we will not collect histogram data.
|
||||
if 'histogram_regular' not in self._collect_specified_data:
|
||||
return
|
||||
|
||||
self._optimizer = self._get_optimizer(cb_params)
|
||||
if self._optimizer is None:
|
||||
return
|
||||
|
||||
parameters = self._optimizer.parameters
|
||||
regular = self._collect_specified_data.get('histogram_regular')
|
||||
if regular is not None:
|
||||
for parameter in parameters:
|
||||
if re.match(regular, parameter.name):
|
||||
self._record.add_value(PluginEnum.HISTOGRAM.value, parameter.name+'/auto', parameter.data)
|
||||
return
|
||||
|
||||
# Note: If `histogram_regular` in `self._collect_specified_data` and the value is None,
|
||||
# we will collect the first five parameters.
|
||||
default_parameter_count = 5
|
||||
for parameter in parameters[:default_parameter_count]:
|
||||
self._record.add_value(PluginEnum.HISTOGRAM.value, parameter.name+'/auto', parameter.data)
|
||||
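# Editor's illustration (not part of this commit): with
# collect_specified_data={'histogram_regular': '^conv1.*'} only optimizer parameters
# whose names match the pattern (for example 'conv1.weight') are recorded as
# histograms; if the key is present with value None, the first five parameters are
# recorded instead.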
|
||||
@staticmethod
|
||||
def _get_learning_rate(optimizer):
|
||||
"""
|
||||
Parse the learning rate from the optimizer.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): An optimizer which inherits from the MindSpore Optimizer class.
|
||||
|
||||
Returns:
|
||||
Union[Tensor, None], if the learning rate is parsed successfully, return a Tensor, else return None.
|
||||
"""
|
||||
learning_rate = optimizer.learning_rate
|
||||
if not isinstance(learning_rate, Parameter):
|
||||
logger.info("The learning rate detected in the optimizer is not a Parameter type, so it is not recorded.")
|
||||
return None
|
||||
return learning_rate.data
|
||||
|
||||
def _collect_train_lineage(self, cb_params):
|
||||
"""Collect train lineage data, the detail refer to lineage_pb2.TrainLineage."""
|
||||
if not self._collect_specified_data.get('collect_train_lineage'):
|
||||
return
|
||||
train_lineage = {}
|
||||
loss = self._get_loss(cb_params)
|
||||
if loss:
|
||||
loss_numpy = loss.asnumpy()
|
||||
loss = float(np.atleast_1d(loss_numpy)[0])
|
||||
train_lineage[LineageMetadata.loss] = loss
|
||||
else:
|
||||
train_lineage[LineageMetadata.loss] = None
|
||||
|
||||
optimizer = self._get_optimizer(cb_params)
|
||||
learning_rate = self._get_learning_rate(optimizer)
|
||||
|
||||
if learning_rate is not None:
|
||||
train_lineage[LineageMetadata.learning_rate] = list(np.atleast_1d(learning_rate.asnumpy()))[0]
|
||||
else:
|
||||
train_lineage[LineageMetadata.learning_rate] = None
|
||||
train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None
|
||||
train_lineage[LineageMetadata.train_network] = self._get_backbone(cb_params.train_network)
|
||||
|
||||
loss_fn = self._get_loss_fn(cb_params)
|
||||
train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None
|
||||
|
||||
train_lineage[LineageMetadata.epoch] = cb_params.epoch_num
|
||||
train_lineage[LineageMetadata.step_num] = cb_params.cur_step_num
|
||||
train_lineage[LineageMetadata.parallel_mode] = cb_params.parallel_mode
|
||||
train_lineage[LineageMetadata.device_num] = cb_params.device_number
|
||||
train_lineage[LineageMetadata.batch_size] = cb_params.batch_num
|
||||
|
||||
ckpt_file_path = self._get_ckpt_file_path(cb_params)
|
||||
train_lineage[LineageMetadata.model_path] = json.dumps(dict(ckpt=ckpt_file_path))
|
||||
|
||||
model_size = os.path.getsize(ckpt_file_path) if ckpt_file_path else 0
|
||||
train_lineage[LineageMetadata.model_size] = model_size
|
||||
|
||||
self._parse_dataset(cb_params, train_lineage)
|
||||
|
||||
train_lineage_message = self._package_train_lineage_message(train_lineage)
|
||||
|
||||
self._record.add_value(PluginEnum.TRAIN_LINEAGE.value, 'train_lineage', train_lineage_message)
|
||||
|
||||
@staticmethod
|
||||
def _package_train_lineage_message(train_lineage):
|
||||
"""
|
||||
Package train lineage data into binary data.
|
||||
|
||||
Args:
|
||||
train_lineage (dict): The train lineage dict, refer to the attribute of `_collect_train_lineage` method.
|
||||
|
||||
Returns:
|
||||
TrainLineage, an object of lineage_pb2.TrainLineage.
|
||||
"""
|
||||
lineage_message = lineage_pb2.TrainLineage()
|
||||
|
||||
if train_lineage.get(LineageMetadata.train_network) is not None:
|
||||
lineage_message.algorithm.network = train_lineage.get(LineageMetadata.train_network)
|
||||
if train_lineage.get(LineageMetadata.loss) is not None:
|
||||
lineage_message.algorithm.loss = train_lineage.get(LineageMetadata.loss)
|
||||
|
||||
# Construct train_dataset message.
|
||||
if train_lineage.get(LineageMetadata.train_dataset_path) is not None:
|
||||
lineage_message.train_dataset.train_dataset_path = train_lineage.get(LineageMetadata.train_dataset_path)
|
||||
if train_lineage.get(LineageMetadata.train_dataset_size) is not None:
|
||||
lineage_message.train_dataset.train_dataset_size = train_lineage.get(LineageMetadata.train_dataset_size)
|
||||
|
||||
# Construct model message
|
||||
lineage_message.model.path = train_lineage.get(LineageMetadata.model_path)
|
||||
lineage_message.model.size = train_lineage.get(LineageMetadata.model_size)
|
||||
|
||||
# Construct hyper_parameters message.
|
||||
if train_lineage.get(LineageMetadata.learning_rate) is not None:
|
||||
lineage_message.hyper_parameters.learning_rate = train_lineage.get(LineageMetadata.learning_rate)
|
||||
if train_lineage.get(LineageMetadata.optimizer) is not None:
|
||||
lineage_message.hyper_parameters.optimizer = train_lineage.get(LineageMetadata.optimizer)
|
||||
if train_lineage.get(LineageMetadata.loss_function) is not None:
|
||||
lineage_message.hyper_parameters.loss_function = train_lineage.get(LineageMetadata.loss_function)
|
||||
if train_lineage.get(LineageMetadata.parallel_mode) is not None:
|
||||
lineage_message.hyper_parameters.parallel_mode = train_lineage.get(LineageMetadata.parallel_mode)
|
||||
|
||||
lineage_message.hyper_parameters.epoch = train_lineage.get(LineageMetadata.epoch)
|
||||
lineage_message.hyper_parameters.device_num = train_lineage.get(LineageMetadata.device_num)
|
||||
lineage_message.hyper_parameters.batch_size = train_lineage.get(LineageMetadata.batch_size)
|
||||
|
||||
return lineage_message
|
||||
|
||||
def _parse_dataset(self, cb_params, lineage_dict):
|
||||
"""
|
||||
Analyze Dataset to get the dataset path and dataset size.
|
||||
|
||||
Args:
|
||||
cb_params (_InternalCallbackParam): Callback parameters.
|
||||
lineage_dict (dict): The lineage dict, refer to the attribute
|
||||
of `_collect_train_lineage` method or `_collect_eval_lineage`.
|
||||
|
||||
Returns:
|
||||
dict, the lineage metadata.
|
||||
"""
|
||||
dataset = cb_params.train_dataset if cb_params.mode == ModeEnum.TRAIN.value else cb_params.valid_dataset
|
||||
|
||||
try:
|
||||
dataset_path = self._get_dataset_path(dataset)
|
||||
except IndexError:
|
||||
dataset_path = None
|
||||
|
||||
if dataset_path and os.path.isfile(dataset_path):
|
||||
dataset_dir = os.path.dirname(dataset_path)
|
||||
else:
|
||||
dataset_dir = dataset_path
|
||||
|
||||
batch_num = dataset.get_dataset_size()
|
||||
batch_size = dataset.get_batch_size()
|
||||
dataset_size = int(batch_num * batch_size)
|
||||
|
||||
if cb_params.mode == ModeEnum.TRAIN.value:
|
||||
lineage_dict[LineageMetadata.train_dataset_path] = dataset_dir
|
||||
lineage_dict[LineageMetadata.train_dataset_size] = dataset_size
|
||||
else:
|
||||
lineage_dict[LineageMetadata.valid_dataset_path] = dataset_dir
|
||||
lineage_dict[LineageMetadata.valid_dataset_size] = dataset_size
|
||||
|
||||
return lineage_dict
|
||||
|
||||
def _get_dataset_path(self, output_dataset):
|
||||
"""
|
||||
Get dataset path of MindDataset object.
|
||||
|
||||
Args:
|
||||
output_dataset (Union[Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset,
|
||||
VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]):
|
||||
Refer to mindspore.dataset.Dataset.
|
||||
|
||||
Returns:
|
||||
str, dataset path.
|
||||
|
||||
Raises:
|
||||
IndexError: If getting the dataset path failed.
|
||||
"""
|
||||
dataset_package = import_module('mindspore.dataset')
|
||||
dataset_dir_set = (dataset_package.ImageFolderDatasetV2, dataset_package.MnistDataset,
|
||||
dataset_package.Cifar10Dataset, dataset_package.Cifar100Dataset,
|
||||
dataset_package.VOCDataset, dataset_package.CelebADataset)
|
||||
dataset_file_set = (dataset_package.MindDataset, dataset_package.ManifestDataset)
|
||||
dataset_files_set = (dataset_package.TFRecordDataset, dataset_package.TextFileDataset)
|
||||
|
||||
if isinstance(output_dataset, dataset_file_set):
|
||||
return output_dataset.dataset_file
|
||||
if isinstance(output_dataset, dataset_dir_set):
|
||||
return output_dataset.dataset_dir
|
||||
if isinstance(output_dataset, dataset_files_set):
|
||||
return output_dataset.dataset_files[0]
|
||||
return self._get_dataset_path(output_dataset.input[0])
|
||||
|
||||
@staticmethod
|
||||
def _get_ckpt_file_path(cb_params):
|
||||
"""
|
||||
Get checkpoint file path from MindSpore callback list.
|
||||
|
||||
Args:
|
||||
cb_params (_InternalCallbackParam): Callback parameters.
|
||||
|
||||
Returns:
|
||||
Union[str, None], if parsed successfully, return the checkpoint file's absolute path, else return None.
|
||||
"""
|
||||
callbacks = cb_params.list_callback
|
||||
ckpt_file_path = None
|
||||
for callback in callbacks:
|
||||
if isinstance(callback, ModelCheckpoint):
|
||||
ckpt_file_path = callback.latest_ckpt_file_name
|
||||
|
||||
if ckpt_file_path:
|
||||
ckpt_file_path = os.path.realpath(ckpt_file_path)
|
||||
|
||||
return ckpt_file_path
|
||||
|
||||
@staticmethod
|
||||
def _get_backbone(network):
|
||||
"""
|
||||
Get the name of backbone network.
|
||||
|
||||
Args:
|
||||
network (Cell): The train network.
|
||||
|
||||
Returns:
|
||||
Union[str, None], if parsed successfully, return the name of the backbone network, else return None.
|
||||
"""
|
||||
backbone_name = None
|
||||
backbone_key = '_backbone'
|
||||
|
||||
for _, cell in network.cells_and_names():
|
||||
if hasattr(cell, backbone_key):
|
||||
backbone_network = getattr(cell, backbone_key)
|
||||
backbone_name = type(backbone_network).__name__
|
||||
|
||||
if backbone_name is None and network is not None:
|
||||
backbone_name = type(network).__name__
|
||||
|
||||
return backbone_name
|
||||
|
||||
@staticmethod
|
||||
def _get_loss_fn(cb_params):
|
||||
"""
|
||||
Get loss function by cb_params and analyzing network.
|
||||
|
||||
Args:
|
||||
cb_params (_InternalCallbackParam): Callback parameters.
|
||||
|
||||
Returns:
|
||||
Union[Cell, None], the loss function Cell if found, else None.
|
||||
"""
|
||||
loss_fn = cb_params.loss_fn
|
||||
if loss_fn is not None:
|
||||
return loss_fn
|
||||
|
||||
if cb_params.mode == ModeEnum.TRAIN.value:
|
||||
network = cb_params.train_network
|
||||
else:
|
||||
network = cb_params.eval_network
|
||||
|
||||
for _, cell in network.cells_and_names():
|
||||
if isinstance(cell, _Loss):
|
||||
loss_fn = cell
|
||||
break
|
||||
return loss_fn
|
||||
|
||||
def _collect_eval_lineage(self, cb_params):
|
||||
"""Collect eval lineage data, the detail refer to lineage_pb2.EvaluationLineage."""
|
||||
if not self._collect_specified_data.get('collect_eval_lineage'):
|
||||
return
|
||||
eval_lineage = dict()
|
||||
|
||||
eval_lineage[LineageMetadata.metrics] = json.dumps(cb_params.metrics)
|
||||
self._parse_dataset(cb_params, eval_lineage)
|
||||
|
||||
eval_lineage_message = self._package_eval_lineage_message(eval_lineage)
|
||||
self._record.add_value(PluginEnum.EVAL_LINEAGE.value, 'eval_lineage', eval_lineage_message)
|
||||
|
||||
@staticmethod
|
||||
def _package_eval_lineage_message(eval_lineage):
|
||||
"""
|
||||
Package eval lineage data into binary data.
|
||||
|
||||
Args:
|
||||
eval_lineage (dict): The eval lineage dict, refer to the attribute of `_collect_eval_lineage` method.
|
||||
|
||||
Returns:
|
||||
EvaluationLineage, an object of lineage_pb2.EvaluationLineage.
|
||||
"""
|
||||
lineage_message = lineage_pb2.EvaluationLineage()
|
||||
|
||||
if eval_lineage.get(LineageMetadata.metrics) is not None:
|
||||
lineage_message.metric = eval_lineage.get(LineageMetadata.metrics)
|
||||
if eval_lineage.get(LineageMetadata.valid_dataset_path) is not None:
|
||||
lineage_message.valid_dataset.valid_dataset_path = eval_lineage.get(LineageMetadata.valid_dataset_path)
|
||||
if eval_lineage.get(LineageMetadata.valid_dataset_size) is not None:
|
||||
lineage_message.valid_dataset.valid_dataset_size = eval_lineage.get(LineageMetadata.valid_dataset_size)
|
||||
|
||||
return lineage_message
|
|
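Because _get_ckpt_file_path scans the callback list for a ModelCheckpoint instance, the model_path and model_size lineage fields are only filled when both callbacks run together. A hypothetical sketch (model, dataset and epoch are assumed to exist):

from mindspore.train.callback import ModelCheckpoint, SummaryCollector

ckpt_cb = ModelCheckpoint(prefix='resnet', directory='./ckpt')
summary_collector = SummaryCollector(summary_dir='./summary_dir')
# SummaryCollector reads latest_ckpt_file_name from ModelCheckpoint at the end of training.
model.train(epoch, dataset, callbacks=[ckpt_cb, summary_collector])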
@ -1,56 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""SummaryStep Callback class."""
|
||||
|
||||
from ._callback import Callback
|
||||
|
||||
|
||||
class SummaryStep(Callback):
|
||||
"""
|
||||
The summary callback class.
|
||||
|
||||
Args:
|
||||
summary (Object): Summary recode object.
|
||||
flush_step (int): Number of interval steps to execute. Default: 10.
|
||||
"""
|
||||
|
||||
def __init__(self, summary, flush_step=10):
|
||||
super(SummaryStep, self).__init__()
|
||||
if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
|
||||
raise ValueError("`flush_step` should be int and greater than 0")
|
||||
self._summary = summary
|
||||
self._flush_step = flush_step
|
||||
|
||||
def __enter__(self):
|
||||
self._summary.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *err):
|
||||
return self._summary.__exit__(*err)
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Save summary.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
if cb_params.cur_step_num % self._flush_step == 0:
|
||||
self._summary.record(cb_params.cur_step_num, cb_params.train_network)
|
||||
|
||||
@property
|
||||
def summary_file_name(self):
|
||||
return self._summary.full_file_name
|
|
@ -13,6 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Model."""
|
||||
from collections.abc import Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import log as logger
|
||||
|
@ -345,7 +347,8 @@ class Model:
|
|||
cb_params.parallel_mode = self._parallel_mode
|
||||
cb_params.device_number = self._device_number
|
||||
cb_params.train_dataset = train_dataset
|
||||
cb_params.list_callback = callbacks
|
||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||
cb_params.train_dataset_element = None
|
||||
|
||||
# build callback list
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
|
@ -358,6 +361,17 @@ class Model:
|
|||
else:
|
||||
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
|
||||
|
||||
@staticmethod
|
||||
def _transform_callbacks(callbacks):
|
||||
"""Transform callback to a list."""
|
||||
if callbacks is None:
|
||||
return []
|
||||
|
||||
if isinstance(callbacks, Iterable):
|
||||
return list(callbacks)
|
||||
|
||||
return [callbacks]
|
||||
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Training process. The data would be passed to network through dataset channel.
|
||||
|
@ -449,6 +463,7 @@ class Model:
|
|||
scaling_sens = self._get_scaling_sens()
|
||||
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
|
||||
|
||||
cb_params.train_dataset_element = next_element
|
||||
outputs = self._train_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
|
@ -628,6 +643,7 @@ class Model:
|
|||
cb_params.batch_num = valid_dataset.get_dataset_size()
|
||||
cb_params.mode = "eval"
|
||||
cb_params.cur_step_num = 0
|
||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||
|
||||
self._eval_network.set_train(mode=False)
|
||||
self._eval_network.phase = 'eval'
|
||||
|
|
|
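A short editorial illustration of the callback normalization added by _transform_callbacks above; loss_cb and time_cb are hypothetical callback instances:

Model._transform_callbacks(None)                # -> []
Model._transform_callbacks(loss_cb)             # -> [loss_cb]
Model._transform_callbacks([loss_cb, time_cb])  # -> [loss_cb, time_cb]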
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Summary's enumeration file."""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class BaseEnum(Enum):
|
||||
"""The base enum class."""
|
||||
|
||||
@classmethod
|
||||
def to_list(cls):
|
||||
"""Converts the enumeration into a list."""
|
||||
return [member.value for member in cls.__members__.values()]
|
||||
|
||||
|
||||
class PluginEnum(BaseEnum):
|
||||
"""The list of plugins currently supported by the summary."""
|
||||
GRAPH = 'graph'
|
||||
SCALAR = 'scalar'
|
||||
IMAGE = 'image'
|
||||
TENSOR = 'tensor'
|
||||
HISTOGRAM = 'histogram'
|
||||
TRAIN_LINEAGE = 'train_lineage'
|
||||
EVAL_LINEAGE = 'eval_lineage'
|
||||
DATASET_GRAPH = 'dataset_graph'
|
||||
|
||||
|
||||
class ModeEnum(BaseEnum):
|
||||
"""The modes currently supported by the summary."""
|
||||
TRAIN = 'train'
|
||||
EVAL = 'eval'
|
|
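A small editorial sketch of how the new enums are consumed elsewhere in this commit:

from mindspore.train.summary.enum import ModeEnum, PluginEnum

ModeEnum.to_list()       # -> ['train', 'eval'], used by SummaryCollector.begin to validate the mode
PluginEnum.SCALAR.value  # -> 'scalar', the plugin name passed to SummaryRecord.add_value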
@ -75,7 +75,7 @@ class TestGpuSummary:
|
|||
if not os.path.exists(self.summary_dir):
|
||||
os.mkdir(self.summary_dir)
|
||||
|
||||
def teardown_emthod(self):
|
||||
def teardown_method(self):
|
||||
"""Run after method."""
|
||||
if os.path.exists(self.summary_dir):
|
||||
shutil.rmtree(self.summary_dir)
|
||||
|
|
|
@ -20,8 +20,8 @@ import numpy as np
|
|||
import mindspore.nn as nn
|
||||
from mindspore import Model, context
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.train.callback import SummaryStep
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from mindspore.train.summary import SummaryRecord
|
||||
from mindspore.train.callback import SummaryCollector
|
||||
from .....dataset_mock import MindData
|
||||
|
||||
CUR_DIR = os.getcwd()
|
||||
|
@ -107,16 +107,9 @@ def test_graph_summary_sample():
|
|||
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
|
||||
model.train(2, dataset)
|
||||
# step 2: create the Event
|
||||
for i in range(1, 5):
|
||||
test_writer.record(i)
|
||||
|
||||
# step 3: send the event to mq
|
||||
|
||||
# step 4: accept the event and write the file
|
||||
|
||||
log.debug("finished test_graph_summary_sample")
|
||||
|
||||
|
||||
def test_graph_summary_callback():
|
||||
dataset = get_dataset()
|
||||
|
@ -125,18 +118,8 @@ def test_graph_summary_callback():
|
|||
optim = Momentum(net.trainable_params(), 0.1, 0.9)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
|
||||
summary_cb = SummaryStep(test_writer, 1)
|
||||
model.train(2, dataset, callbacks=summary_cb)
|
||||
|
||||
|
||||
def test_graph_summary_callback2():
|
||||
dataset = get_dataset()
|
||||
net = Net()
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optim = Momentum(net.trainable_params(), 0.1, 0.9)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net) as test_writer:
|
||||
summary_cb = SummaryStep(test_writer, 1)
|
||||
model.train(2, dataset, callbacks=summary_cb)
|
||||
summary_collector = SummaryCollector(SUMMARY_DIR,
|
||||
collect_freq=1,
|
||||
keep_default_action=False,
|
||||
collect_specified_data={'collect_graph': True})
|
||||
model.train(1, dataset, callbacks=[summary_collector])
|
||||
|
|
|
@ -26,9 +26,8 @@ import mindspore.nn as nn
|
|||
from mindspore import Model, context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.train.callback import SummaryStep
|
||||
from mindspore.train.summary.summary_record import SummaryRecord, \
|
||||
_cache_summary_tensor_data
|
||||
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
|
||||
from mindspore.train.callback import Callback
|
||||
from .....dataset_mock import MindData
|
||||
|
||||
CUR_DIR = os.getcwd()
|
||||
|
@ -155,7 +154,8 @@ def get_dataset():
|
|||
return dataset
|
||||
|
||||
|
||||
class ImageSummaryCallback:
|
||||
class ImageSummaryCallback(Callback):
|
||||
"""Image summary callback."""
|
||||
|
||||
def __init__(self, summary_record):
|
||||
self._summary_record = summary_record
|
||||
|
@ -164,9 +164,10 @@ class ImageSummaryCallback:
|
|||
return self
|
||||
|
||||
def __exit__(self, *err):
|
||||
pass
|
||||
self._summary_record.close()
|
||||
|
||||
def record(self, step, train_network=None):
|
||||
"""record data."""
|
||||
self._summary_record.record(step, train_network)
|
||||
self._summary_record.flush()
|
||||
|
||||
|
@ -183,9 +184,8 @@ def test_image_summary_train():
|
|||
# step 2: create the Event
|
||||
|
||||
model = get_model()
|
||||
fn = ImageSummaryCallback(test_writer)
|
||||
summary_recode = SummaryStep(fn, 1)
|
||||
model.train(2, dataset, callbacks=summary_recode)
|
||||
callback = ImageSummaryCallback(test_writer)
|
||||
model.train(2, dataset, callbacks=[callback])
|
||||
|
||||
# step 3: send the event to mq
|
||||
|
||||
|
|
|
@ -24,11 +24,9 @@ import random
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train.callback import SummaryStep
|
||||
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
|
||||
|
||||
CUR_DIR = os.getcwd()
|
||||
|
@ -192,16 +190,6 @@ def test_scalar_summary_with_ge_2():
|
|||
|
||||
def test_validate():
|
||||
with SummaryRecord(SUMMARY_DIR) as sr:
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, 0)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, -1)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, 1.2)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, True)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, "str")
|
||||
sr.record(1)
|
||||
with pytest.raises(ValueError):
|
||||
sr.record(False)
|
||||
|
@ -215,17 +203,3 @@ def test_validate():
|
|||
sr.record("str")
|
||||
with pytest.raises(ValueError):
|
||||
sr.record(sr)
|
||||
|
||||
SummaryStep(sr, 1)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, 1.2)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, False)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, "str")
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, (1, 2))
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, [3, 4])
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, sr)
|
||||
|
|
|
@ -59,7 +59,8 @@ def test_summaryrecord_input_null_string():
|
|||
log.debug("begin test_summaryrecord_input_null_string")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
SummaryRecord("")
|
||||
with SummaryRecord(""):
|
||||
pass
|
||||
except:
|
||||
assert True
|
||||
else:
|
||||
|
@ -71,7 +72,8 @@ def test_summaryrecord_input_None():
|
|||
log.debug("begin test_summaryrecord_input_None")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
SummaryRecord(None)
|
||||
with SummaryRecord(None):
|
||||
pass
|
||||
except:
|
||||
assert True
|
||||
else:
|
||||
|
@ -83,7 +85,8 @@ def test_summaryrecord_input_relative_dir_1():
|
|||
log.debug("begin test_summaryrecord_input_relative_dir_1")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
SummaryRecord("./test_temp_summary_event_file/")
|
||||
with SummaryRecord("./test_temp_summary_event_file/"):
|
||||
pass
|
||||
except:
|
||||
assert False
|
||||
else:
|
||||
|
@ -95,7 +98,8 @@ def test_summaryrecord_input_relative_dir_2():
|
|||
log.debug("begin test_summaryrecord_input_relative_dir_2")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
SummaryRecord("../summary/")
|
||||
with SummaryRecord("../summary/"):
|
||||
pass
|
||||
except:
|
||||
assert False
|
||||
else:
|
||||
|
@ -107,7 +111,8 @@ def test_summaryrecord_input_invalid_type_dir():
|
|||
log.debug("begin test_summaryrecord_input_invalid_type_dir")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
SummaryRecord(32)
|
||||
with SummaryRecord(32):
|
||||
pass
|
||||
except:
|
||||
assert True
|
||||
else:
|
||||
|
@ -119,7 +124,8 @@ def test_mulit_layer_directory():
|
|||
log.debug("begin test_mulit_layer_directory")
|
||||
# step 0: create the thread
|
||||
try:
|
||||
SummaryRecord("./test_temp_summary_event_file/test/t1/")
|
||||
with SummaryRecord("./test_temp_summary_event_file/test/t1/"):
|
||||
pass
|
||||
except:
|
||||
assert False
|
||||
else:
|
||||
|
|
|
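The test changes above wrap each SummaryRecord in a with-statement; a minimal sketch of the reason, assuming the same SummaryRecord API:

# The context manager guarantees the event writer is closed even if the body raises,
# which a bare SummaryRecord(...) construction did not.
with SummaryRecord("./test_temp_summary_event_file/") as record:
    record.record(1)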
@ -0,0 +1,184 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the exception parameter scenario for summary collector."""
import os
import tempfile
import shutil

import pytest

from mindspore.train.callback import SummaryCollector


class TestSummaryCollector:
    """Test the exception parameter for summary collector."""
    base_summary_dir = ''

    def setup_class(self):
        """Run before test this class."""
        self.base_summary_dir = tempfile.mkdtemp(suffix='summary')

    def teardown_class(self):
        """Run after test this class."""
        if os.path.exists(self.base_summary_dir):
            shutil.rmtree(self.base_summary_dir)

    @pytest.mark.parametrize("summary_dir", [1234, None, True, ''])
    def test_params_with_summary_dir_value_error(self, summary_dir):
        """Test the exception scenario for summary dir."""
        if isinstance(summary_dir, str):
            with pytest.raises(ValueError) as exc:
                SummaryCollector(summary_dir=summary_dir)
            assert str(exc.value) == 'For `summary_dir` the value should be a valid string of path, ' \
                                     'but got empty string.'
        else:
            with pytest.raises(TypeError) as exc:
                SummaryCollector(summary_dir=summary_dir)
            assert 'For `summary_dir` the type should be a valid type' in str(exc.value)

    def test_params_with_summary_dir_not_dir(self):
        """Test the given summary dir parameter is not a directory."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        summary_file = os.path.join(summary_dir, 'temp_file.txt')
        with open(summary_file, 'w') as file_handle:
            file_handle.write('temp')
        print(os.path.isfile(summary_file))
        with pytest.raises(NotADirectoryError):
            SummaryCollector(summary_dir=summary_file)

    @pytest.mark.parametrize("collect_freq", [None, 0, 0.01])
    def test_params_with_collect_freq_exception(self, collect_freq):
        """Test the exception scenario for collect freq."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        if isinstance(collect_freq, int):
            with pytest.raises(ValueError) as exc:
                SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
            expected_msg = f'For `collect_freq` the value should be greater than 0, but got `{collect_freq}`.'
            assert expected_msg == str(exc.value)
        else:
            with pytest.raises(TypeError) as exc:
                SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
            expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \
                           f'bug got {type(collect_freq).__name__}.'
            assert expected_msg == str(exc.value)

    @pytest.mark.parametrize("action", [None, 123, '', '123'])
    def test_params_with_action_exception(self, action):
        """Test the exception scenario for action."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir=summary_dir, keep_default_action=action)
        expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \
                       f"bug got {type(action).__name__}."
        assert expected_msg == str(exc.value)

    @pytest.mark.parametrize("collect_specified_data", [123])
    def test_params_with_collect_specified_data_type_error(self, collect_specified_data):
        """Test type error scenario for collect specified data param."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)

        expected_msg = f"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], " \
                       f"bug got {type(collect_specified_data).__name__}."

        assert expected_msg == str(exc.value)

    @pytest.mark.parametrize("collect_specified_data", [
        {
            123: 123
        },
        {
            None: True
        }
    ])
    def test_params_with_collect_specified_data_key_type_error(self, collect_specified_data):
        """Test the key of collect specified data param."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)

        param_name = list(collect_specified_data)[0]
        expected_msg = f"For `{param_name}` the type should be a valid type of ['str'], " \
                       f"bug got {type(param_name).__name__}."
        assert expected_msg == str(exc.value)

    @pytest.mark.parametrize("collect_specified_data", [
        {
            'collect_metric': None
        },
        {
            'collect_graph': 123
        },
        {
            'histogram_regular': 123
        },
    ])
    def test_params_with_collect_specified_data_value_type_error(self, collect_specified_data):
        """Test the value of collect specified data param."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)

        param_name = list(collect_specified_data)[0]
        param_value = collect_specified_data[param_name]
        expected_type = "['bool']" if param_name != 'histogram_regular' else "['str', 'NoneType']"
        expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
                       f'bug got {type(param_value).__name__}.'

        assert expected_msg == str(exc.value)

    def test_params_with_collect_specified_data_unexpected_key(self):
        """Test the collect_specified_data parameter with unexpected key."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        data = {'unexpected_key': True}
        with pytest.raises(ValueError) as exc:
            SummaryCollector(summary_dir, collect_specified_data=data)
        expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported."
        assert expected_msg == str(exc.value)

    @pytest.mark.parametrize("custom_lineage_data", [
        123,
        {
            'custom': {}
        },
        {
            'custom': None
        },
        {
            123: 'custom'
        }
    ])
    def test_params_with_custom_lineage_data_type_error(self, custom_lineage_data):
        """Test the custom lineage data parameter type error."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, custom_lineage_data=custom_lineage_data)

        if not isinstance(custom_lineage_data, dict):
            expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \
                           f"bug got {type(custom_lineage_data).__name__}."
        else:
            param_name = list(custom_lineage_data)[0]
            param_value = custom_lineage_data[param_name]
            if not isinstance(param_name, str):
                arg_name = f'custom_lineage_data -> {param_name}'
                expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \
                               f'bug got {type(param_name).__name__}.'
            else:
                arg_name = f'the value of custom_lineage_data -> {param_name}'
                expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \
                               f'bug got {type(param_value).__name__}.'

        assert expected_msg == str(exc.value)
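As a counterpoint to these failure cases, a hedged sketch of a constructor call that the same checks should accept; the concrete values are illustrative, and only parameters referenced by the tests above are used:

# Hedged sketch: a parameter set the above validation should accept (values illustrative).
summary_collector = SummaryCollector(
    summary_dir='./summary_dir',              # non-empty string pointing at a directory
    collect_freq=10,                          # int greater than 0
    collect_specified_data={
        'collect_metric': True,               # bool
        'collect_graph': True,                # bool
        'histogram_regular': None,            # str pattern or None
    },
    keep_default_action=True,                 # bool
    custom_lineage_data={'version': '1.0'},   # str keys; int/str/float values
)

In training code, such an instance would then be passed through the callbacks argument of Model.train, as the last hunk below illustrates.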
@ -20,8 +20,8 @@ import pytest
import mindspore.nn as nn
from mindspore import Model, context
from mindspore import Tensor
from mindspore.train.callback import Callback
from mindspore.nn.optim import Momentum
from mindspore.train.callback import SummaryStep
from ..ut_filter import non_graph_engine
from ....dataset_mock import MindData
@ -174,7 +174,7 @@ class TestGraphMode:
        model.train(1, dataset)


class CallbackTest:
class CallbackTest(Callback):
    """ CallbackTest definition """

    def __init__(self):
@ -186,19 +186,19 @@ class CallbackTest:
    def __exit__(self, *err):
        pass

    def record(self, step, *args):
        print(step, args)
    def step_end(self, run_context):
        cb_params = run_context.original_args()
        print(cb_params.cur_epoch_num, cb_params.cur_step_num)


def test_train_callback(test_with_simu):
    """ test_train_callback """
    dataset = get_dataset()
    model = get_model()
    fn = CallbackTest()
    summary_recode = SummaryStep(fn, 2)
    callback = CallbackTest()
    if test_with_simu:
        return
    model.train(2, dataset, callbacks=summary_recode)
    model.train(2, dataset, callbacks=callback)


log = logging.getLogger("test")
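To round out this hunk, a hedged sketch of how such a test could combine the user-defined callback with the new SummaryCollector after SummaryStep is gone; the test name, summary directory, and frequency are illustrative assumptions, not part of this diff:

# Hedged sketch only: pass a user callback and SummaryCollector together.
from mindspore.train.callback import SummaryCollector  # exported after this change

def test_train_with_summary_collector():
    """Illustrative only: combine a custom Callback with SummaryCollector."""
    dataset = get_dataset()
    model = get_model()
    callbacks = [CallbackTest(),
                 SummaryCollector(summary_dir='./summary_dir', collect_freq=1)]
    model.train(2, dataset, callbacks=callbacks)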