diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py
index 67e36c30cf..85fd6fa189 100644
--- a/mindspore/train/_utils.py
+++ b/mindspore/train/_utils.py
@@ -14,7 +14,10 @@
 # ============================================================================
 """Train utility."""
 import os
+from collections.abc import Iterable
+
 import numpy as np
+
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
 from mindspore.common import dtype as mstype
@@ -213,6 +216,7 @@ def _check_to_numpy(plugin, tensor):
         raise ValueError('The tensor should not be empty.')
     return np_value
 
+
 def _check_lineage_value(plugin, value):
     """Check the lineage value."""
     def raises(plugin, prototype):
@@ -229,3 +233,20 @@ def _check_lineage_value(plugin, value):
 
     if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo):
         raises(plugin, UserDefinedInfo)
+
+
+def check_value_type(arg_name, arg_value, valid_types):
+    """Check whether a value is an instance of the given types."""
+    valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,)
+    is_valid = True
+
+    # bool is a subclass of int, so a bool value needs an extra check
+    if isinstance(arg_value, int) and isinstance(arg_value, bool) and bool not in valid_types:
+        is_valid = False
+
+    if not isinstance(arg_value, valid_types):
+        is_valid = False
+
+    if not is_valid:
+        raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
+                        f'but got {type(arg_value).__name__}.')
diff --git a/mindspore/train/callback/__init__.py b/mindspore/train/callback/__init__.py
index 4e0dd729fb..fd2e4760b3 100644
--- a/mindspore/train/callback/__init__.py
+++ b/mindspore/train/callback/__init__.py
@@ -22,7 +22,8 @@ from ._checkpoint import CheckpointConfig
 from ._checkpoint import CheckpointManager as _CheckpointManager
 from ._checkpoint import ModelCheckpoint
 from ._loss_monitor import LossMonitor
-from ._summary_step import SummaryStep
 from ._time_monitor import TimeMonitor
+from ._summary_collector import SummaryCollector
 
-__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]
+__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
+           "SummaryCollector", "CheckpointConfig", "RunContext"]
diff --git a/mindspore/train/callback/_dataset_graph.py b/mindspore/train/callback/_dataset_graph.py
new file mode 100644
index 0000000000..e8c8dcb2ba
--- /dev/null
+++ b/mindspore/train/callback/_dataset_graph.py
@@ -0,0 +1,128 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Define dataset graph related operations."""
+import json
+from importlib import import_module
+
+from mindspore.train import lineage_pb2
+
+
+class DatasetGraph:
+    """Handle the dataset graph and package it into binary data."""
+    def package_dataset_graph(self, dataset):
+        """
+        Package the dataset graph into binary data.
+
+        Args:
+            dataset (MindData): Refer to MindDataset.
+
+        Returns:
+            DatasetGraph, an object of lineage_pb2.DatasetGraph.
+        """
+        dataset_package = import_module('mindspore.dataset')
+        dataset_dict = dataset_package.serialize(dataset)
+        json_str = json.dumps(dataset_dict, indent=2)
+        dataset_dict = json.loads(json_str)
+        dataset_graph_proto = lineage_pb2.DatasetGraph()
+        if "children" in dataset_dict:
+            children = dataset_dict.pop("children")
+            if children:
+                self._package_children(children=children, message=dataset_graph_proto)
+            self._package_current_dataset(operation=dataset_dict, message=dataset_graph_proto)
+        return dataset_graph_proto
+
+    def _package_children(self, children, message):
+        """
+        Package children in dataset operation.
+
+        Args:
+            children (list[dict]): Child operations.
+            message (DatasetGraph): Children proto message.
+        """
+        for child in children:
+            if child:
+                child_graph_message = getattr(message, "children").add()
+                grandson = child.pop("children")
+                if grandson:
+                    self._package_children(children=grandson, message=child_graph_message)
+                # package other parameters
+                self._package_current_dataset(operation=child, message=child_graph_message)
+
+    def _package_current_dataset(self, operation, message):
+        """
+        Package operation parameters in event message.
+
+        Args:
+            operation (dict): Operation dict.
+            message (Operation): Operation proto message.
+        """
+        for key, value in operation.items():
+            if value and key == "operations":
+                for operator in value:
+                    self._package_enhancement_operation(
+                        operator,
+                        message.operations.add()
+                    )
+            elif value and key == "sampler":
+                self._package_enhancement_operation(
+                    value,
+                    message.sampler
+                )
+            else:
+                self._package_parameter(key, value, message.parameter)
+
+    def _package_enhancement_operation(self, operation, message):
+        """
+        Package enhancement operation in MapDataset.
+
+        Args:
+            operation (dict): Enhancement operation.
+            message (Operation): Enhancement operation proto message.
+        """
+        for key, value in operation.items():
+            if isinstance(value, list):
+                if all(isinstance(ele, int) for ele in value):
+                    message.size.extend(value)
+                else:
+                    message.weights.extend(value)
+            else:
+                self._package_parameter(key, value, message.operationParam)
+
+    @staticmethod
+    def _package_parameter(key, value, message):
+        """
+        Package parameters in operation.
+
+        Args:
+            key (str): Operation name.
+            value (Union[str, bool, int, float, list, None]): Operation args.
+            message (OperationParameter): Operation proto message.
+        """
+        if isinstance(value, str):
+            message.mapStr[key] = value
+        elif isinstance(value, bool):
+            message.mapBool[key] = value
+        elif isinstance(value, int):
+            message.mapInt[key] = value
+        elif isinstance(value, float):
+            message.mapDouble[key] = value
+        elif isinstance(value, list) and key != "operations":
+            if value:
+                replace_value_list = list(map(lambda x: "" if x is None else x, value))
+                message.mapStrList[key].strValue.extend(replace_value_list)
+        elif value is None:
+            message.mapStr[key] = "None"
+        else:
+            raise ValueError(f"Parameter {key} is not supported in event package.")
diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py
new file mode 100644
index 0000000000..e2e4a9cc2d
--- /dev/null
+++ b/mindspore/train/callback/_summary_collector.py
@@ -0,0 +1,786 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Summary collector callback."""
+
+import os
+import re
+import json
+
+from importlib import import_module
+
+import numpy as np
+
+from mindspore import log as logger
+from mindspore.common.tensor import Tensor
+from mindspore.common.parameter import Parameter
+from mindspore.train.summary.summary_record import SummaryRecord
+from mindspore.train.summary.enum import PluginEnum, ModeEnum
+from mindspore.train.callback import Callback, ModelCheckpoint
+from mindspore.train import lineage_pb2
+from mindspore.train.callback._dataset_graph import DatasetGraph
+from mindspore.nn.optim.optimizer import Optimizer
+from mindspore.nn.loss.loss import _Loss
+from mindspore.train._utils import check_value_type
+
+
+class LineageMetadata:
+    """Initialize parameters used in model lineage management."""
+    train_dataset_path = 'train_dataset_path'
+    valid_dataset_path = 'valid_dataset_path'
+    train_network = 'train_network'
+    loss_function = 'loss_function'
+    loss = 'loss'
+    optimizer = 'optimizer'
+    learning_rate = 'learning_rate'
+    epoch = 'epoch'
+    step_num = 'step_num'
+    parallel_mode = 'parallel_mode'
+    device_num = 'device_num'
+    batch_size = 'batch_size'
+    model_path = 'model_path'
+    model_ckpt = 'model_ckpt'
+    model_size = 'model_size'
+    metrics = 'metrics'
+    train_dataset_size = 'train_dataset_size'
+    valid_dataset_size = 'valid_dataset_size'
+
+
+class SummaryCollector(Callback):
+    """
+    SummaryCollector can help you to collect some common information.
+
+    It can help you to collect loss, learning rate, computational graph and so on.
+    SummaryCollector also persists data collected by the summary operator into a summary file.
+
+    Note:
+        1. Multiple SummaryCollector instances in callback list are not allowed.
+        2. Not all information is collected at the training phase or at the eval phase.
+        3. SummaryCollector always records the data collected by the summary operator.
+
+    Args:
+        summary_dir (str): The collected data will be persisted to this directory.
+            If the directory does not exist, it will be created automatically.
+        collect_freq (int): Set the frequency of data collection; it should be greater than zero,
+            and the unit is `step`. Default: 10.
+            It is important to note that if the data sink mode is used, the unit will become the `epoch`.
+            It is not recommended to collect data too frequently, which can affect performance.
+        collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
+            By default, if set to None, all data is collected as the default behavior.
+            If you want to customize the data collected, you can do so with a dictionary.
+            For example, you can set {'collect_metric': False} to disable metric collection.
+            The data that supports control is shown below.
+
+            - collect_metric: Whether to collect training metrics, currently only loss is collected.
+              Optional: True/False. Default: True.
+            - collect_graph: Whether to collect computational graph, currently only
+              training computational graph is collected. Optional: True/False. Default: True.
+            - collect_train_lineage: Whether to collect lineage data for the training phase,
+              this field will be displayed on the lineage page of MindInsight. Optional: True/False. Default: True.
+            - collect_eval_lineage: Whether to collect lineage data for the eval phase,
+              this field will be displayed on the lineage page of MindInsight. Optional: True/False. Default: True.
+            - collect_input_data: Whether to collect the input dataset for training. Currently only image data
+              is supported. Optional: True/False. Default: True.
+            - collect_dataset_graph: Whether to collect dataset graph for the training phase.
+              Optional: True/False. Default: True.
+            - histogram_regular: Collect weight and bias for parameter distribution page display in MindInsight.
+              This field allows regular strings to control which parameters to collect.
+              Default: None, it means only the first five parameters are collected.
+              It is not recommended to collect too many parameters at once, as it can affect performance.
+              Note that if you collect too many parameters and run out of memory, the training will fail.
+        keep_default_action (bool): This field affects the collection behavior of the 'collect_specified_data' field.
+            Optional: True/False, Default: True.
+            True: means that after specified data is set, non-specified data is collected as the default behavior.
+            False: means that after specified data is set, only the specified data is collected,
+            and the others are not collected.
+        custom_lineage_data (Union[dict, None]): Allows you to customize the data and present it on the MindInsight
+            lineage page. In the custom data, the key type supports str, and the value type supports str/int/float.
+            Default: None, it means there is no custom data.
+
+    Raises:
+        ValueError: If the parameter value is not expected.
+        TypeError: If the parameter type is not expected.
+        RuntimeError: If an error occurs during data collection.
+
+    Examples:
+        >>> # Simple usage:
+        >>> summary_collector = SummaryCollector(summary_dir='./summary_dir')
+        >>> model.train(epoch, dataset, callbacks=summary_collector)
+        >>>
+        >>> # Do not collect metric and collect the first layer parameter, others are collected by default
+        >>> specified = {'collect_metric': False, 'histogram_regular': '^conv1.*'}
+        >>> summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_specified_data=specified)
+        >>> model.train(epoch, dataset, callbacks=summary_collector)
+        >>>
+        >>> # Only collect metric, custom lineage data and record data collected by the summary operator,
+        >>> # others are not collected
+        >>> specified = {'collect_metric': True}
+        >>> summary_collector = SummaryCollector('./summary_dir',
+        >>>                                      collect_specified_data=specified,
+        >>>                                      keep_default_action=False,
+        >>>                                      custom_lineage_data={'version': 'resnet50_v1'})
+        >>> model.train(epoch, dataset, callbacks=summary_collector)
+    """
+
+    _DEFAULT_SPECIFIED_DATA = {
+        'collect_metric': True,
+        'collect_graph': True,
+        'collect_train_lineage': True,
+        'collect_eval_lineage': True,
+        'collect_input_data': True,
+        'collect_dataset_graph': True,
+        'histogram_regular': None
+    }
+
+    # _OPTIMIZER_FAILED means finding the optimizer failed, so we will not collect data about the optimizer.
+    _OPTIMIZER_FAILED = 'Failed'
+
+    def __init__(self, summary_dir, collect_freq=10, collect_specified_data=None,
+                 keep_default_action=True, custom_lineage_data=None):
+        super(SummaryCollector, self).__init__()
+
+        self._summary_dir = self._process_summary_dir(summary_dir)
+        self._record = None
+
+        self._check_collect_freq(collect_freq)
+        self._collect_freq = collect_freq
+
+        self._check_action(keep_default_action)
+
+        self._collect_specified_data = self._process_specified_data(collect_specified_data, keep_default_action)
+        logger.info(f"For `collect_specified_data` the value after processing is: {self._collect_specified_data}.")
+
+        self._check_custom_lineage_data(custom_lineage_data)
+        self._custom_lineage_data = custom_lineage_data
+
+        self._optimizer = None
+        self._has_saved_train_network = False
+        self._has_saved_custom_data = False
+        self._is_parse_loss_success = True
+
+    def __enter__(self):
+        self._record = SummaryRecord(log_dir=self._summary_dir)
+        return self
+
+    def __exit__(self, *err):
+        self._record.close()
+
+    @staticmethod
+    def _process_summary_dir(summary_dir):
+        """Check the summary dir, and create a new directory if it does not exist."""
+        check_value_type('summary_dir', summary_dir, str)
+        summary_dir = summary_dir.strip()
+        if not summary_dir:
+            raise ValueError('For `summary_dir` the value should be a valid string of path, but got empty string.')
+
+        summary_dir = os.path.realpath(summary_dir)
+        if not os.path.exists(summary_dir):
+            os.makedirs(summary_dir, exist_ok=True)
+        else:
+            if not os.path.isdir(summary_dir):
+                raise NotADirectoryError('For `summary_dir` it should be a directory path.')
+
+        return summary_dir
+
+    @staticmethod
+    def _check_collect_freq(freq):
+        """Check collect freq type and value."""
+        check_value_type('collect_freq', freq, int)
+        if freq <= 0:
+            raise ValueError(f'For `collect_freq` the value should be greater than 0, but got `{freq}`.')
+
+    @staticmethod
+    def _check_custom_lineage_data(custom_lineage_data):
+        """
+        Check user custom lineage data.
+
+        Args:
+            custom_lineage_data (dict): The user custom defined data.
+
+        Raises:
+            TypeError: If the type of parameters is invalid.
+        """
+        if custom_lineage_data is None:
+            return
+
+        check_value_type('custom_lineage_data', custom_lineage_data, [dict, type(None)])
+        for key, value in custom_lineage_data.items():
+            check_value_type(f'custom_lineage_data -> {key}', key, str)
+            check_value_type(f'the value of custom_lineage_data -> {key}', value, (int, str, float))
+
+    @staticmethod
+    def _check_action(action):
+        """Check action type."""
+        check_value_type('keep_default_action', action, bool)
+
+    def _process_specified_data(self, specified_data, action):
+        """Check specified data type and value."""
+        if specified_data is None:
+            if action:
+                return self._DEFAULT_SPECIFIED_DATA
+            return None
+
+        check_value_type('collect_specified_data', specified_data, [dict, type(None)])
+
+        for param_name in specified_data:
+            check_value_type(param_name, param_name, [str])
+
+        unexpected_params = set(specified_data) - set(self._DEFAULT_SPECIFIED_DATA)
+        if unexpected_params:
+            raise ValueError(f'For `collect_specified_data` the keys {unexpected_params} are unsupported.')
+
+        if 'histogram_regular' in specified_data:
+            check_value_type('histogram_regular', specified_data.get('histogram_regular'), (str, type(None)))
+
+        bool_items = set(self._DEFAULT_SPECIFIED_DATA) - {'histogram_regular'}
+        for item in bool_items:
+            if item in specified_data:
+                check_value_type(item, specified_data.get(item), bool)
+
+        if action:
+            # Note: dict.update() returns None, so the merged dict must be built first.
+            result = dict(self._DEFAULT_SPECIFIED_DATA)
+            result.update(specified_data)
+        else:
+            result = specified_data
+        return result
+
+    def begin(self, run_context):
+        cb_params = run_context.original_args()
+        self._check_callbacks(cb_params)
+
+        if cb_params.mode not in ModeEnum.to_list():
+            raise ValueError(f'Only support `train` (model.train) and `eval` (model.eval) mode, '
+                             f'but got `{cb_params.mode}` mode.')
+
+        self._record.set_mode(cb_params.mode)
+        if cb_params.mode == ModeEnum.TRAIN.value:
+            # Note: if model.init is not executed then the computational graph will not be obtained here.
+            # The purpose of recording the graph here is that, even if `collect_freq` is set to a large
+            # value, the graph can still be seen soon after compilation.
+            self._collect_graphs(cb_params)
+
+            self._collect_dataset_graph(cb_params)
+
+        if self._custom_lineage_data and not self._has_saved_custom_data:
+            packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
+            self._record.add_value('custom_lineage_data', 'custom_lineage_data', packaged_custom_data)
+            self._has_saved_custom_data = True
+
+        # There's nothing special about setting step to 0 here, just to satisfy the interface call
+        self._record.record(step=0)
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+
+        if cb_params.mode == ModeEnum.TRAIN.value:
+            if cb_params.cur_step_num % self._collect_freq:
+                return
+
+            if not self._has_saved_train_network:
+                self._collect_graphs(cb_params)
+
+            self._collect_input_data(cb_params)
+            self._collect_metric(cb_params)
+            self._collect_histogram(cb_params)
+
+        self._record.record(cb_params.cur_step_num)
+
+    def end(self, run_context):
+        cb_params = run_context.original_args()
+        if cb_params.mode == ModeEnum.TRAIN.value:
+            self._collect_train_lineage(cb_params)
+        else:
+            self._collect_eval_lineage(cb_params)
+
+        # There's nothing special about setting step to 0 here, just to satisfy the interface call
+        self._record.record(step=0)
+
+    def _check_callbacks(self, cb_params):
+        """Check if there is more than one SummaryCollector instance in the callback list."""
+        callbacks = cb_params.list_callback
+
+        is_find = False
+        for callback in callbacks:
+            if type(callback).__name__ == self.__class__.__name__:
+                if not is_find:
+                    is_find = True
+                    continue
+                raise ValueError(f"There are more than one {self.__class__.__name__} instance in the callback list, "
+                                 f"but expected only one {self.__class__.__name__} instance.")
+
+    @staticmethod
+    def _package_custom_lineage_data(custom_lineage_data):
+        """
+        Package user-defined lineage data into binary data.
+
+        Args:
+            custom_lineage_data (dict): User custom lineage data.
+
+        Returns:
+            UserDefinedInfo, an object of lineage_pb2.UserDefinedInfo.
+        """
+        user_defined_info = lineage_pb2.UserDefinedInfo()
+        for key, value in custom_lineage_data.items():
+            if isinstance(value, int):
+                attr_name = "map_int32"
+            elif isinstance(value, float):
+                attr_name = "map_double"
+            else:
+                attr_name = "map_str"
+
+            user_info = user_defined_info.user_info.add()
+            getattr(user_info, attr_name)[key] = value
+
+        return user_defined_info
+
+    def _collect_input_data(self, cb_params):
+        """Collect input data; currently only image data is supported."""
+        if not self._collect_specified_data.get('collect_input_data'):
+            return
+
+        input_data = getattr(cb_params, 'train_dataset_element', None)
+        if input_data is None:
+            self._collect_specified_data['collect_input_data'] = False
+            logger.info("There is no `train_dataset_element` in cb_params.")
+            return
+
+        if isinstance(input_data, (list, tuple)):
+            input_data = input_data[0]
+        try:
+            self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data)
+        except ValueError:
+            self._collect_specified_data['collect_input_data'] = False
+            return
+
+    def _collect_dataset_graph(self, cb_params):
+        """Only collect train dataset graph."""
+        if not self._collect_specified_data.get('collect_dataset_graph'):
+            return
+
+        # After analysis, we think that the validation dataset graph and the training dataset graph
+        # should be consistent under normal scenarios, so only the training dataset graph is collected.
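+        # The graph is obtained by serializing the dataset with mindspore.dataset.serialize and
+        # packaging the result into a lineage_pb2.DatasetGraph proto (see DatasetGraph in
+        # _dataset_graph.py).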
+        if cb_params.mode == ModeEnum.TRAIN.value:
+            train_dataset = cb_params.train_dataset
+            dataset_graph = DatasetGraph()
+            graph_bytes = dataset_graph.package_dataset_graph(train_dataset)
+            self._record.add_value('dataset_graph', 'train_dataset', graph_bytes)
+
+    def _collect_graphs(self, cb_params):
+        """Collect the graph of train network and eval network."""
+        if not self._collect_specified_data.get('collect_graph'):
+            return
+
+        network = cb_params.train_network if cb_params.mode == ModeEnum.TRAIN.value else cb_params.eval_network
+        graph_proto = network.get_func_graph_proto()
+        if graph_proto is None:
+            return
+
+        self._has_saved_train_network = True
+        self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)
+
+    def _collect_metric(self, cb_params):
+        """Collect metrics; currently only loss collection is supported."""
+        if not self._collect_specified_data.get('collect_metric'):
+            return
+
+        loss = self._get_loss(cb_params)
+        if loss is None:
+            return
+        self._record.add_value(PluginEnum.SCALAR.value, 'loss/auto', loss)
+
+    def _get_loss(self, cb_params):
+        """
+        Get loss from the network output.
+
+        Args:
+            cb_params (_InternalCallbackParam): Callback parameters.
+
+        Returns:
+            Union[Tensor, None], if the loss is parsed successfully, return a Tensor value (shape is [1]),
+            else return None.
+        """
+        if not self._is_parse_loss_success:
+            # If parsing has failed before, avoid repeating it
+            return None
+
+        output = cb_params.net_outputs
+        if output is None:
+            logger.warning("Cannot find any output from this network.")
+            self._is_parse_loss_success = False
+            return None
+
+        if isinstance(output, (int, float)):
+            loss = output
+        elif isinstance(output, (list, tuple)):
+            # If the output is a list, since the default network returns loss first,
+            # we assume that the first one is loss.
+            loss = output[0]
+        elif isinstance(output, Tensor) and (not output.shape or output.shape == [1]):
+            loss_numpy = output.asnumpy()
+            loss = float(np.atleast_1d(loss_numpy)[0])
+        else:
+            logger.warning("The output type could not be identified, so no loss was recorded in SummaryCollector.")
+            self._is_parse_loss_success = False
+            return None
+
+        if not isinstance(loss, Tensor):
+            loss = Tensor(loss)
+
+        return loss
+
+    def _get_optimizer(self, cb_params):
+        """
+        Get optimizer from the cb_params or parse from the network.
+
+        Args:
+            cb_params (_InternalCallbackParam): Callback parameters.
+
+        Returns:
+            Union[Optimizer, None], if the optimizer is parsed successfully, return an optimizer, else return None.
+        """
+        if self._optimizer == self._OPTIMIZER_FAILED:
+            return None
+
+        if self._optimizer is not None:
+            return self._optimizer
+
+        optimizer = cb_params.optimizer
+        if optimizer is None:
+            network = cb_params.train_network if cb_params.mode == 'train' else cb_params.eval_network
+            optimizer = self._parse_optimizer_by_network(network)
+
+        if optimizer is None or not isinstance(optimizer, Optimizer):
+            logger.warning("Cannot find the optimizer in the network, or the optimizer does not inherit from "
+                           "MindSpore's Optimizer, so we will not collect data about the optimizer "
+                           "in SummaryCollector.")
+            # Remember the failure so subsequent calls short-circuit, and return None so
+            # callers can rely on an `is None` check instead of the string sentinel.
+            self._optimizer = self._OPTIMIZER_FAILED
+            return None
+
+        self._optimizer = optimizer
+        return optimizer
+
+    @staticmethod
+    def _parse_optimizer_by_network(network):
+        """Parse the optimizer from the network; if parsing succeeds return an optimizer, else return None."""
+        optimizer = None
+        for _, cell in network.cells_and_names():
+            try:
+                optimizer = getattr(cell, 'optimizer')
+            except AttributeError:
+                continue
+
+            if not isinstance(optimizer, Optimizer):
+                continue
+
+            # Optimizer found successfully
+            break
+
+        return optimizer
+
+    def _collect_histogram(self, cb_params):
+        """Collect histogram data, containing the parameter weights and biases."""
+        # Note: if there is not a key named `histogram_regular` in `self._collect_specified_data`,
+        # it means we will not collect histogram data.
+        if 'histogram_regular' not in self._collect_specified_data:
+            return
+
+        optimizer = self._get_optimizer(cb_params)
+        if optimizer is None:
+            return
+
+        parameters = optimizer.parameters
+        regular = self._collect_specified_data.get('histogram_regular')
+        if regular is not None:
+            for parameter in parameters:
+                if re.match(regular, parameter.name):
+                    self._record.add_value(PluginEnum.HISTOGRAM.value, parameter.name+'/auto', parameter.data)
+            return
+
+        # Note: If `histogram_regular` in `self._collect_specified_data` and the value is None,
+        # we will collect the first five parameters.
+        default_parameter_count = 5
+        for parameter in parameters[:default_parameter_count]:
+            self._record.add_value(PluginEnum.HISTOGRAM.value, parameter.name+'/auto', parameter.data)
+
+    @staticmethod
+    def _get_learning_rate(optimizer):
+        """
+        Parse the learning rate from the optimizer.
+
+        Args:
+            optimizer (Optimizer): An optimizer which inherits from the MindSpore Optimizer class.
+
+        Returns:
+            Union[Tensor, None], if the learning rate is parsed successfully, return a Tensor, else return None.
+        """
+        learning_rate = optimizer.learning_rate
+        if not isinstance(learning_rate, Parameter):
+            logger.info("The learning rate detected in the optimizer is not a Parameter type, so it is not recorded.")
+            return None
+        return learning_rate.data
+
+    def _collect_train_lineage(self, cb_params):
+        """Collect train lineage data; for details refer to lineage_pb2.TrainLineage."""
+        if not self._collect_specified_data.get('collect_train_lineage'):
+            return
+        train_lineage = {}
+        loss = self._get_loss(cb_params)
+        if loss:
+            loss_numpy = loss.asnumpy()
+            loss = float(np.atleast_1d(loss_numpy)[0])
+            train_lineage[LineageMetadata.loss] = loss
+        else:
+            train_lineage[LineageMetadata.loss] = None
+
+        optimizer = self._get_optimizer(cb_params)
+        learning_rate = self._get_learning_rate(optimizer) if optimizer is not None else None
+
+        if learning_rate is not None:
+            train_lineage[LineageMetadata.learning_rate] = list(np.atleast_1d(learning_rate.asnumpy()))[0]
+        else:
+            train_lineage[LineageMetadata.learning_rate] = None
+        train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None
+        train_lineage[LineageMetadata.train_network] = self._get_backbone(cb_params.train_network)
+
+        loss_fn = self._get_loss_fn(cb_params)
+        train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None
+
+        train_lineage[LineageMetadata.epoch] = cb_params.epoch_num
+        train_lineage[LineageMetadata.step_num] = cb_params.cur_step_num
+        train_lineage[LineageMetadata.parallel_mode] = cb_params.parallel_mode
+        train_lineage[LineageMetadata.device_num] = cb_params.device_number
+        train_lineage[LineageMetadata.batch_size] = cb_params.batch_num
+
+        ckpt_file_path = self._get_ckpt_file_path(cb_params)
+        train_lineage[LineageMetadata.model_path] = json.dumps(dict(ckpt=ckpt_file_path))
+
+        model_size = os.path.getsize(ckpt_file_path) if ckpt_file_path else 0
+        train_lineage[LineageMetadata.model_size] = model_size
+
+        self._parse_dataset(cb_params, train_lineage)
+
+        train_lineage_message = self._package_train_lineage_message(train_lineage)
+
+        self._record.add_value(PluginEnum.TRAIN_LINEAGE.value, 'train_lineage', train_lineage_message)
+
+    @staticmethod
+    def _package_train_lineage_message(train_lineage):
+        """
+        Package train lineage data into binary data.
+
+        Args:
+            train_lineage (dict): The train lineage dict, refer to the attribute of `_collect_train_lineage` method.
+
+        Returns:
+            TrainLineage, an object of lineage_pb2.TrainLineage.
+        """
+        lineage_message = lineage_pb2.TrainLineage()
+
+        if train_lineage.get(LineageMetadata.train_network) is not None:
+            lineage_message.algorithm.network = train_lineage.get(LineageMetadata.train_network)
+        if train_lineage.get(LineageMetadata.loss) is not None:
+            lineage_message.algorithm.loss = train_lineage.get(LineageMetadata.loss)
+
+        # Construct train_dataset message.
+        if train_lineage.get(LineageMetadata.train_dataset_path) is not None:
+            lineage_message.train_dataset.train_dataset_path = train_lineage.get(LineageMetadata.train_dataset_path)
+        if train_lineage.get(LineageMetadata.train_dataset_size) is not None:
+            lineage_message.train_dataset.train_dataset_size = train_lineage.get(LineageMetadata.train_dataset_size)
+
+        # Construct model message
+        lineage_message.model.path = train_lineage.get(LineageMetadata.model_path)
+        lineage_message.model.size = train_lineage.get(LineageMetadata.model_size)
+
+        # Construct hyper_parameters message.
+        if train_lineage.get(LineageMetadata.learning_rate) is not None:
+            lineage_message.hyper_parameters.learning_rate = train_lineage.get(LineageMetadata.learning_rate)
+        if train_lineage.get(LineageMetadata.optimizer) is not None:
+            lineage_message.hyper_parameters.optimizer = train_lineage.get(LineageMetadata.optimizer)
+        if train_lineage.get(LineageMetadata.loss_function) is not None:
+            lineage_message.hyper_parameters.loss_function = train_lineage.get(LineageMetadata.loss_function)
+        if train_lineage.get(LineageMetadata.parallel_mode) is not None:
+            lineage_message.hyper_parameters.parallel_mode = train_lineage.get(LineageMetadata.parallel_mode)
+
+        lineage_message.hyper_parameters.epoch = train_lineage.get(LineageMetadata.epoch)
+        lineage_message.hyper_parameters.device_num = train_lineage.get(LineageMetadata.device_num)
+        lineage_message.hyper_parameters.batch_size = train_lineage.get(LineageMetadata.batch_size)
+
+        return lineage_message
+
+    def _parse_dataset(self, cb_params, lineage_dict):
+        """
+        Analyze the dataset to get the dataset path and dataset size.
+
+        Args:
+            cb_params (_InternalCallbackParam): Callback parameters.
+            lineage_dict (dict): The lineage dict, refer to the attribute
+                of `_collect_train_lineage` method or `_collect_eval_lineage`.
+
+        Returns:
+            dict, the lineage metadata.
+        """
+        dataset = cb_params.train_dataset if cb_params.mode == ModeEnum.TRAIN.value else cb_params.valid_dataset
+
+        try:
+            dataset_path = self._get_dataset_path(dataset)
+        except IndexError:
+            dataset_path = None
+
+        if dataset_path and os.path.isfile(dataset_path):
+            dataset_dir = os.path.dirname(dataset_path)
+        else:
+            dataset_dir = dataset_path
+
+        batch_num = dataset.get_dataset_size()
+        batch_size = dataset.get_batch_size()
+        dataset_size = int(batch_num * batch_size)
+
+        if cb_params.mode == ModeEnum.TRAIN.value:
+            lineage_dict[LineageMetadata.train_dataset_path] = dataset_dir
+            lineage_dict[LineageMetadata.train_dataset_size] = dataset_size
+        else:
+            lineage_dict[LineageMetadata.valid_dataset_path] = dataset_dir
+            lineage_dict[LineageMetadata.valid_dataset_size] = dataset_size
+
+        return lineage_dict
+
+    def _get_dataset_path(self, output_dataset):
+        """
+        Get dataset path of MindDataset object.
+
+        Args:
+            output_dataset (Union[Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset,
+                VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]):
+                Refer to mindspore.dataset.Dataset.
+
+        Returns:
+            str, dataset path.
+
+        Raises:
+            IndexError: Raised when getting the dataset path fails.
+        """
+        dataset_package = import_module('mindspore.dataset')
+        dataset_dir_set = (dataset_package.ImageFolderDatasetV2, dataset_package.MnistDataset,
+                           dataset_package.Cifar10Dataset, dataset_package.Cifar100Dataset,
+                           dataset_package.VOCDataset, dataset_package.CelebADataset)
+        dataset_file_set = (dataset_package.MindDataset, dataset_package.ManifestDataset)
+        dataset_files_set = (dataset_package.TFRecordDataset, dataset_package.TextFileDataset)
+
+        if isinstance(output_dataset, dataset_file_set):
+            return output_dataset.dataset_file
+        if isinstance(output_dataset, dataset_dir_set):
+            return output_dataset.dataset_dir
+        if isinstance(output_dataset, dataset_files_set):
+            return output_dataset.dataset_files[0]
+        return self._get_dataset_path(output_dataset.input[0])
+
+    @staticmethod
+    def _get_ckpt_file_path(cb_params):
+        """
+        Get checkpoint file path from MindSpore callback list.
+
+        Args:
+            cb_params (_InternalCallbackParam): Callback parameters.
+
+        Returns:
+            Union[str, None], if parsing succeeds, return the checkpoint file's absolute path, else return None.
+        """
+        callbacks = cb_params.list_callback
+        ckpt_file_path = None
+        for callback in callbacks:
+            if isinstance(callback, ModelCheckpoint):
+                ckpt_file_path = callback.latest_ckpt_file_name
+
+        if ckpt_file_path:
+            ckpt_file_path = os.path.realpath(ckpt_file_path)
+
+        return ckpt_file_path
+
+    @staticmethod
+    def _get_backbone(network):
+        """
+        Get the name of backbone network.
+
+        Args:
+            network (Cell): The train network.
+
+        Returns:
+            Union[str, None], if parsing succeeds, return the name of the backbone network, else return None.
+        """
+        backbone_name = None
+        backbone_key = '_backbone'
+
+        for _, cell in network.cells_and_names():
+            if hasattr(cell, backbone_key):
+                backbone_network = getattr(cell, backbone_key)
+                backbone_name = type(backbone_network).__name__
+
+        if backbone_name is None and network is not None:
+            backbone_name = type(network).__name__
+
+        return backbone_name
+
+    @staticmethod
+    def _get_loss_fn(cb_params):
+        """
+        Get the loss function from cb_params or by analyzing the network.
+
+        Args:
+            cb_params (_InternalCallbackParam): Callback parameters.
+
+        Returns:
+            Union[Loss_fn, None], a Cell object; if parsing fails, return None.
+        """
+        loss_fn = cb_params.loss_fn
+        if loss_fn is not None:
+            return loss_fn
+
+        if cb_params.mode == ModeEnum.TRAIN.value:
+            network = cb_params.train_network
+        else:
+            network = cb_params.eval_network
+
+        for _, cell in network.cells_and_names():
+            if isinstance(cell, _Loss):
+                loss_fn = cell
+                break
+        return loss_fn
+
+    def _collect_eval_lineage(self, cb_params):
+        """Collect eval lineage data; for details refer to lineage_pb2.EvaluationLineage."""
+        if not self._collect_specified_data.get('collect_eval_lineage'):
+            return
+        eval_lineage = dict()
+
+        eval_lineage[LineageMetadata.metrics] = json.dumps(cb_params.metrics)
+        self._parse_dataset(cb_params, eval_lineage)
+
+        eval_lineage_message = self._package_eval_lineage_message(eval_lineage)
+        self._record.add_value(PluginEnum.EVAL_LINEAGE.value, 'eval_lineage', eval_lineage_message)
+
+    @staticmethod
+    def _package_eval_lineage_message(eval_lineage):
+        """
+        Package eval lineage data into binary data.
+
+        Args:
+            eval_lineage (dict): The eval lineage dict, refer to the attribute of `_collect_eval_lineage` method.
+
+        Returns:
+            EvaluationLineage, an object of lineage_pb2.EvaluationLineage.
+        """
+        lineage_message = lineage_pb2.EvaluationLineage()
+
+        if eval_lineage.get(LineageMetadata.metrics) is not None:
+            lineage_message.metric = eval_lineage.get(LineageMetadata.metrics)
+        if eval_lineage.get(LineageMetadata.valid_dataset_path) is not None:
+            lineage_message.valid_dataset.valid_dataset_path = eval_lineage.get(LineageMetadata.valid_dataset_path)
+        if eval_lineage.get(LineageMetadata.valid_dataset_size) is not None:
+            lineage_message.valid_dataset.valid_dataset_size = eval_lineage.get(LineageMetadata.valid_dataset_size)
+
+        return lineage_message
diff --git a/mindspore/train/callback/_summary_step.py b/mindspore/train/callback/_summary_step.py
deleted file mode 100644
index 0a4fbca80d..0000000000
--- a/mindspore/train/callback/_summary_step.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""SummaryStep Callback class.""" - -from ._callback import Callback - - -class SummaryStep(Callback): - """ - The summary callback class. - - Args: - summary (Object): Summary recode object. - flush_step (int): Number of interval steps to execute. Default: 10. - """ - - def __init__(self, summary, flush_step=10): - super(SummaryStep, self).__init__() - if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0: - raise ValueError("`flush_step` should be int and greater than 0") - self._summary = summary - self._flush_step = flush_step - - def __enter__(self): - self._summary.__enter__() - return self - - def __exit__(self, *err): - return self._summary.__exit__(*err) - - def step_end(self, run_context): - """ - Save summary. - - Args: - run_context (RunContext): Context of the train running. - """ - cb_params = run_context.original_args() - if cb_params.cur_step_num % self._flush_step == 0: - self._summary.record(cb_params.cur_step_num, cb_params.train_network) - - @property - def summary_file_name(self): - return self._summary.full_file_name diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 3ce71d76ce..2c08fa195b 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================ """Model.""" +from collections.abc import Iterable + import numpy as np from mindspore import log as logger @@ -345,7 +347,8 @@ class Model: cb_params.parallel_mode = self._parallel_mode cb_params.device_number = self._device_number cb_params.train_dataset = train_dataset - cb_params.list_callback = callbacks + cb_params.list_callback = self._transform_callbacks(callbacks) + cb_params.train_dataset_element = None # build callback list with _CallbackManager(callbacks) as list_callback: @@ -358,6 +361,17 @@ class Model: else: self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) + @staticmethod + def _transform_callbacks(callbacks): + """Transform callback to a list.""" + if callbacks is None: + return [] + + if isinstance(callbacks, Iterable): + return list(callbacks) + + return [callbacks] + def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): """ Training process. The data would be passed to network through dataset channel. 
@@ -449,6 +463,7 @@ class Model: scaling_sens = self._get_scaling_sens() next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),) + cb_params.train_dataset_element = next_element outputs = self._train_network(*next_element) cb_params.net_outputs = outputs if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): @@ -628,6 +643,7 @@ class Model: cb_params.batch_num = valid_dataset.get_dataset_size() cb_params.mode = "eval" cb_params.cur_step_num = 0 + cb_params.list_callback = self._transform_callbacks(callbacks) self._eval_network.set_train(mode=False) self._eval_network.phase = 'eval' diff --git a/mindspore/train/summary/enum.py b/mindspore/train/summary/enum.py new file mode 100644 index 0000000000..84044eab6c --- /dev/null +++ b/mindspore/train/summary/enum.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Summary's enumeration file.""" +from enum import Enum + + +class BaseEnum(Enum): + """The base enum class.""" + + @classmethod + def to_list(cls): + """Converts the enumeration into a list.""" + return [member.value for member in cls.__members__.values()] + + +class PluginEnum(BaseEnum): + """The list of plugins currently supported by the summary.""" + GRAPH = 'graph' + SCALAR = 'scalar' + IMAGE = 'image' + TENSOR = 'tensor' + HISTOGRAM = 'histogram' + TRAIN_LINEAGE = 'train_lineage' + EVAL_LINEAGE = 'eval_lineage' + DATASET_GRAPH = 'dataset_graph' + + +class ModeEnum(BaseEnum): + """The modes currently supported by the summary.""" + TRAIN = 'train' + EVAL = 'eval' diff --git a/tests/st/summary/test_gpu_summary.py b/tests/st/summary/test_gpu_summary.py index a1e8ca17d8..9b4095b8d9 100644 --- a/tests/st/summary/test_gpu_summary.py +++ b/tests/st/summary/test_gpu_summary.py @@ -75,7 +75,7 @@ class TestGpuSummary: if not os.path.exists(self.summary_dir): os.mkdir(self.summary_dir) - def teardown_emthod(self): + def teardown_method(self): """Run after method.""" if os.path.exists(self.summary_dir): shutil.rmtree(self.summary_dir) diff --git a/tests/ut/python/train/summary/test_graph_summary.py b/tests/ut/python/train/summary/test_graph_summary.py index 79a56ba748..643ddbdea2 100644 --- a/tests/ut/python/train/summary/test_graph_summary.py +++ b/tests/ut/python/train/summary/test_graph_summary.py @@ -20,8 +20,8 @@ import numpy as np import mindspore.nn as nn from mindspore import Model, context from mindspore.nn.optim import Momentum -from mindspore.train.callback import SummaryStep -from mindspore.train.summary.summary_record import SummaryRecord +from mindspore.train.summary import SummaryRecord +from mindspore.train.callback import SummaryCollector from .....dataset_mock import MindData CUR_DIR = os.getcwd() @@ -107,16 +107,9 @@ def test_graph_summary_sample(): model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", 
network=model._train_network) as test_writer: model.train(2, dataset) - # step 2: create the Event for i in range(1, 5): test_writer.record(i) - # step 3: send the event to mq - - # step 4: accept the event and write the file - - log.debug("finished test_graph_summary_sample") - def test_graph_summary_callback(): dataset = get_dataset() @@ -125,18 +118,8 @@ def test_graph_summary_callback(): optim = Momentum(net.trainable_params(), 0.1, 0.9) context.set_context(mode=context.GRAPH_MODE) model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) - with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer: - summary_cb = SummaryStep(test_writer, 1) - model.train(2, dataset, callbacks=summary_cb) - - -def test_graph_summary_callback2(): - dataset = get_dataset() - net = Net() - loss = nn.SoftmaxCrossEntropyWithLogits() - optim = Momentum(net.trainable_params(), 0.1, 0.9) - context.set_context(mode=context.GRAPH_MODE) - model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) - with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net) as test_writer: - summary_cb = SummaryStep(test_writer, 1) - model.train(2, dataset, callbacks=summary_cb) + summary_collector = SummaryCollector(SUMMARY_DIR, + collect_freq=1, + keep_default_action=False, + collect_specified_data={'collect_graph': True}) + model.train(1, dataset, callbacks=[summary_collector]) diff --git a/tests/ut/python/train/summary/test_image_summary.py b/tests/ut/python/train/summary/test_image_summary.py index e650442cd3..addeaec212 100644 --- a/tests/ut/python/train/summary/test_image_summary.py +++ b/tests/ut/python/train/summary/test_image_summary.py @@ -26,9 +26,8 @@ import mindspore.nn as nn from mindspore import Model, context from mindspore import Tensor from mindspore.nn.optim import Momentum -from mindspore.train.callback import SummaryStep -from mindspore.train.summary.summary_record import SummaryRecord, \ - _cache_summary_tensor_data +from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data +from mindspore.train.callback import Callback from .....dataset_mock import MindData CUR_DIR = os.getcwd() @@ -155,7 +154,8 @@ def get_dataset(): return dataset -class ImageSummaryCallback: +class ImageSummaryCallback(Callback): + """Image summary callback.""" def __init__(self, summary_record): self._summary_record = summary_record @@ -164,9 +164,10 @@ class ImageSummaryCallback: return self def __exit__(self, *err): - pass + self._summary_record.close() def record(self, step, train_network=None): + """record data.""" self._summary_record.record(step, train_network) self._summary_record.flush() @@ -183,9 +184,8 @@ def test_image_summary_train(): # step 2: create the Event model = get_model() - fn = ImageSummaryCallback(test_writer) - summary_recode = SummaryStep(fn, 1) - model.train(2, dataset, callbacks=summary_recode) + callback = ImageSummaryCallback(test_writer) + model.train(2, dataset, callbacks=[callback]) # step 3: send the event to mq diff --git a/tests/ut/python/train/summary/test_summary.py b/tests/ut/python/train/summary/test_summary.py index d26924a499..b069d89954 100644 --- a/tests/ut/python/train/summary/test_summary.py +++ b/tests/ut/python/train/summary/test_summary.py @@ -24,11 +24,9 @@ import random import numpy as np import pytest - import mindspore.nn as nn from mindspore.common.tensor import Tensor from mindspore.ops import operations as P -from mindspore.train.callback import SummaryStep from 
mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data CUR_DIR = os.getcwd() @@ -192,16 +190,6 @@ def test_scalar_summary_with_ge_2(): def test_validate(): with SummaryRecord(SUMMARY_DIR) as sr: - with pytest.raises(ValueError): - SummaryStep(sr, 0) - with pytest.raises(ValueError): - SummaryStep(sr, -1) - with pytest.raises(ValueError): - SummaryStep(sr, 1.2) - with pytest.raises(ValueError): - SummaryStep(sr, True) - with pytest.raises(ValueError): - SummaryStep(sr, "str") sr.record(1) with pytest.raises(ValueError): sr.record(False) @@ -215,17 +203,3 @@ def test_validate(): sr.record("str") with pytest.raises(ValueError): sr.record(sr) - - SummaryStep(sr, 1) - with pytest.raises(ValueError): - SummaryStep(sr, 1.2) - with pytest.raises(ValueError): - SummaryStep(sr, False) - with pytest.raises(ValueError): - SummaryStep(sr, "str") - with pytest.raises(ValueError): - SummaryStep(sr, (1, 2)) - with pytest.raises(ValueError): - SummaryStep(sr, [3, 4]) - with pytest.raises(ValueError): - SummaryStep(sr, sr) diff --git a/tests/ut/python/train/summary/test_summary_abnormal_input.py b/tests/ut/python/train/summary/test_summary_abnormal_input.py index 95cbf7549b..388952feca 100644 --- a/tests/ut/python/train/summary/test_summary_abnormal_input.py +++ b/tests/ut/python/train/summary/test_summary_abnormal_input.py @@ -59,7 +59,8 @@ def test_summaryrecord_input_null_string(): log.debug("begin test_summaryrecord_input_null_string") # step 0: create the thread try: - SummaryRecord("") + with SummaryRecord(""): + pass except: assert True else: @@ -71,7 +72,8 @@ def test_summaryrecord_input_None(): log.debug("begin test_summaryrecord_input_None") # step 0: create the thread try: - SummaryRecord(None) + with SummaryRecord(None): + pass except: assert True else: @@ -83,7 +85,8 @@ def test_summaryrecord_input_relative_dir_1(): log.debug("begin test_summaryrecord_input_relative_dir_1") # step 0: create the thread try: - SummaryRecord("./test_temp_summary_event_file/") + with SummaryRecord("./test_temp_summary_event_file/"): + pass except: assert False else: @@ -95,7 +98,8 @@ def test_summaryrecord_input_relative_dir_2(): log.debug("begin test_summaryrecord_input_relative_dir_2") # step 0: create the thread try: - SummaryRecord("../summary/") + with SummaryRecord("../summary/"): + pass except: assert False else: @@ -107,7 +111,8 @@ def test_summaryrecord_input_invalid_type_dir(): log.debug("begin test_summaryrecord_input_invalid_type_dir") # step 0: create the thread try: - SummaryRecord(32) + with SummaryRecord(32): + pass except: assert True else: @@ -119,7 +124,8 @@ def test_mulit_layer_directory(): log.debug("begin test_mulit_layer_directory") # step 0: create the thread try: - SummaryRecord("./test_temp_summary_event_file/test/t1/") + with SummaryRecord("./test_temp_summary_event_file/test/t1/"): + pass except: assert False else: diff --git a/tests/ut/python/train/summary/test_summary_collector.py b/tests/ut/python/train/summary/test_summary_collector.py new file mode 100644 index 0000000000..5e7f8e662c --- /dev/null +++ b/tests/ut/python/train/summary/test_summary_collector.py @@ -0,0 +1,184 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test the exception parameter scenario for summary collector."""
+import os
+import tempfile
+import shutil
+import pytest
+
+from mindspore.train.callback import SummaryCollector
+
+
+class TestSummaryCollector:
+    """Test the exception parameter for summary collector."""
+    base_summary_dir = ''
+
+    def setup_class(self):
+        """Run before testing this class."""
+        self.base_summary_dir = tempfile.mkdtemp(suffix='summary')
+
+    def teardown_class(self):
+        """Run after testing this class."""
+        if os.path.exists(self.base_summary_dir):
+            shutil.rmtree(self.base_summary_dir)
+
+    @pytest.mark.parametrize("summary_dir", [1234, None, True, ''])
+    def test_params_with_summary_dir_value_error(self, summary_dir):
+        """Test the exception scenario for summary dir."""
+        if isinstance(summary_dir, str):
+            with pytest.raises(ValueError) as exc:
+                SummaryCollector(summary_dir=summary_dir)
+            assert str(exc.value) == 'For `summary_dir` the value should be a valid string of path, ' \
+                                     'but got empty string.'
+        else:
+            with pytest.raises(TypeError) as exc:
+                SummaryCollector(summary_dir=summary_dir)
+            assert 'For `summary_dir` the type should be a valid type' in str(exc.value)
+
+    def test_params_with_summary_dir_not_dir(self):
+        """Test the given summary dir parameter is not a directory."""
+        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
+        summary_file = os.path.join(summary_dir, 'temp_file.txt')
+        with open(summary_file, 'w') as file_handle:
+            file_handle.write('temp')
+        with pytest.raises(NotADirectoryError):
+            SummaryCollector(summary_dir=summary_file)
+
+    @pytest.mark.parametrize("collect_freq", [None, 0, 0.01])
+    def test_params_with_collect_freq_exception(self, collect_freq):
+        """Test the exception scenario for collect freq."""
+        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
+        if isinstance(collect_freq, int):
+            with pytest.raises(ValueError) as exc:
+                SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
+            expected_msg = f'For `collect_freq` the value should be greater than 0, but got `{collect_freq}`.'
+            assert expected_msg == str(exc.value)
+        else:
+            with pytest.raises(TypeError) as exc:
+                SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
+            expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \
+                           f'but got {type(collect_freq).__name__}.'
+            assert expected_msg == str(exc.value)
+
+    @pytest.mark.parametrize("action", [None, 123, '', '123'])
+    def test_params_with_action_exception(self, action):
+        """Test the exception scenario for action."""
+        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
+        with pytest.raises(TypeError) as exc:
+            SummaryCollector(summary_dir=summary_dir, keep_default_action=action)
+        expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \
+                       f"but got {type(action).__name__}."
+        assert expected_msg == str(exc.value)
+
+    @pytest.mark.parametrize("collect_specified_data", [123])
+    def test_params_with_collect_specified_data_type_error(self, collect_specified_data):
+        """Test type error scenario for collect specified data param."""
+        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
+        with pytest.raises(TypeError) as exc:
+            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
+
+        expected_msg = f"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], " \
+                       f"but got {type(collect_specified_data).__name__}."
+
+        assert expected_msg == str(exc.value)
+
+    @pytest.mark.parametrize("collect_specified_data", [
+        {
+            123: 123
+        },
+        {
+            None: True
+        }
+    ])
+    def test_params_with_collect_specified_data_key_type_error(self, collect_specified_data):
+        """Test the key of collect specified data param."""
+        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
+        with pytest.raises(TypeError) as exc:
+            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
+
+        param_name = list(collect_specified_data)[0]
+        expected_msg = f"For `{param_name}` the type should be a valid type of ['str'], " \
+                       f"but got {type(param_name).__name__}."
+        assert expected_msg == str(exc.value)
+
+    @pytest.mark.parametrize("collect_specified_data", [
+        {
+            'collect_metric': None
+        },
+        {
+            'collect_graph': 123
+        },
+        {
+            'histogram_regular': 123
+        },
+    ])
+    def test_params_with_collect_specified_data_value_type_error(self, collect_specified_data):
+        """Test the value of collect specified data param."""
+        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
+        with pytest.raises(TypeError) as exc:
+            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
+
+        param_name = list(collect_specified_data)[0]
+        param_value = collect_specified_data[param_name]
+        expected_type = "['bool']" if param_name != 'histogram_regular' else "['str', 'NoneType']"
+        expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
+                       f'but got {type(param_value).__name__}.'
+
+        assert expected_msg == str(exc.value)
+
+    def test_params_with_collect_specified_data_unexpected_key(self):
+        """Test the collect_specified_data parameter with unexpected key."""
+        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
+        data = {'unexpected_key': True}
+        with pytest.raises(ValueError) as exc:
+            SummaryCollector(summary_dir, collect_specified_data=data)
+        expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported."
+        assert expected_msg == str(exc.value)
+
+    @pytest.mark.parametrize("custom_lineage_data", [
+        123,
+        {
+            'custom': {}
+        },
+        {
+            'custom': None
+        },
+        {
+            123: 'custom'
+        }
+    ])
+    def test_params_with_custom_lineage_data_type_error(self, custom_lineage_data):
+        """Test the custom lineage data parameter type error."""
+        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
+        with pytest.raises(TypeError) as exc:
+            SummaryCollector(summary_dir, custom_lineage_data=custom_lineage_data)
+
+        if not isinstance(custom_lineage_data, dict):
+            expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \
+                           f"but got {type(custom_lineage_data).__name__}."
+        else:
+            param_name = list(custom_lineage_data)[0]
+            param_value = custom_lineage_data[param_name]
+            if not isinstance(param_name, str):
+                arg_name = f'custom_lineage_data -> {param_name}'
+                expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \
+                               f'but got {type(param_name).__name__}.'
+            else:
+                arg_name = f'the value of custom_lineage_data -> {param_name}'
+                expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \
+                               f'but got {type(param_value).__name__}.'
+
+        assert expected_msg == str(exc.value)
diff --git a/tests/ut/python/train/test_training.py b/tests/ut/python/train/test_training.py
index a007d18571..ad26951104 100644
--- a/tests/ut/python/train/test_training.py
+++ b/tests/ut/python/train/test_training.py
@@ -20,8 +20,8 @@ import pytest
 import mindspore.nn as nn
 from mindspore import Model, context
 from mindspore import Tensor
+from mindspore.train.callback import Callback
 from mindspore.nn.optim import Momentum
-from mindspore.train.callback import SummaryStep
 from ..ut_filter import non_graph_engine
 from ....dataset_mock import MindData
@@ -174,7 +174,7 @@ class TestGraphMode:
         model.train(1, dataset)
 
 
-class CallbackTest:
+class CallbackTest(Callback):
     """ CallbackTest definition """
 
     def __init__(self):
@@ -186,19 +186,19 @@ class CallbackTest:
     def __exit__(self, *err):
         pass
 
-    def record(self, step, *args):
-        print(step, args)
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        print(cb_params.cur_epoch_num, cb_params.cur_step_num)
 
 
 def test_train_callback(test_with_simu):
     """ test_train_callback """
     dataset = get_dataset()
     model = get_model()
-    fn = CallbackTest()
-    summary_recode = SummaryStep(fn, 2)
+    callback = CallbackTest()
     if test_with_simu:
         return
-    model.train(2, dataset, callbacks=summary_recode)
+    model.train(2, dataset, callbacks=callback)
 
 
 log = logging.getLogger("test")
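
Reviewer note: a minimal, self-contained sketch of the `check_value_type` semantics this patch adds to `mindspore/train/_utils.py`, including the bool-vs-int subtlety the helper guards against. The printed message mirrors the format in the source above; treat the snippet as illustrative, not as part of the patch.

    from mindspore.train._utils import check_value_type

    check_value_type('collect_freq', 10, int)             # passes: int is a valid type
    check_value_type('keep_default_action', True, bool)   # passes: bool is a valid type

    try:
        # bool is a subclass of int, so a bare isinstance() test would wrongly
        # accept True here; the helper rejects it because bool is not among
        # the listed valid types.
        check_value_type('collect_freq', True, int)
    except TypeError as err:
        print(err)  # For `collect_freq` the type should be a valid type of ['int'], but got bool.

The same helper backs every parameter check in `SummaryCollector.__init__`, which is why the unit tests in `test_summary_collector.py` assert on this exact message format.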