!4582 Mod_callback_serial
Merge pull request !4582 from wanyiming/mod_callback_serial
This commit is contained in: commit 465390e580
@@ -75,9 +75,9 @@ class Callback:
     """
     Abstract base class used to build a callback class. Callbacks are context managers
     which will be entered and exited when passing into the Model.
-    You can leverage this mechanism to init and release resources automatically.
+    You can use this mechanism to initialize and release resources automatically.
 
-    Callback function will execution some operating to the current step or epoch.
+    Callback function will execute some operations in the current step or epoch.
 
     Examples:
         >>> class Print_info(Callback):
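For context, a minimal sketch of the mechanism this docstring describes: subclass Callback and override one of its hook methods. The hook name step_end and the cb_params field names follow the usual MindSpore callback convention and are assumptions here, not part of this change.

from mindspore.train.callback import Callback

class PrintInfo(Callback):
    """Print the epoch and step number at the end of every step."""
    def step_end(self, run_context):
        # original_args() returns the object holding the model's original
        # arguments, as documented on RunContext below.
        cb_params = run_context.original_args()
        print("epoch:", cb_params.cur_epoch_num, "step:", cb_params.cur_step_num)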
@@ -229,11 +229,11 @@ class RunContext:
     """
     Provides information about the model.
 
-    Run call being made. Provides information about original request to model function.
-    callback objects can stop the loop by calling request_stop() of run_context.
+    Provides information about original request to model function.
+    Callback objects can stop the loop by calling request_stop() of run_context.
 
     Args:
-        original_args (dict): Holding the related information of model etc.
+        original_args (dict): Holding the related information of model.
     """
     def __init__(self, original_args):
         if not isinstance(original_args, dict):
@@ -246,13 +246,13 @@ class RunContext:
         Get the _original_args object.
 
         Returns:
-            Dict, a object holding the original arguments of model.
+            Dict, an object that holds the original arguments of model.
         """
         return self._original_args
 
     def request_stop(self):
         """
-        Sets stop requested during training.
+        Sets stop requirement during training.
 
         Callbacks can use this function to request stop of iterations.
         model.train() checks whether this is called or not.
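To make the request_stop() contract concrete, here is a hedged sketch of an early-stopping callback. The assumption that cb_params.net_outputs holds a scalar loss Tensor is a common convention, not something this hunk guarantees.

from mindspore.train.callback import Callback

class StopAtLoss(Callback):
    """Ask model.train() to stop once the loss falls below a threshold."""
    def __init__(self, threshold=0.05):
        super(StopAtLoss, self).__init__()
        self.threshold = threshold

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs               # assumed: scalar loss Tensor
        if float(loss.asnumpy()) < self.threshold:
            run_context.request_stop()             # model.train() checks this flag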
@@ -70,23 +70,24 @@ def _chg_ckpt_file_name_if_same_exist(directory, prefix):
 
 class CheckpointConfig:
     """
-    The config for model checkpoint.
+    The configuration of model checkpoint.
 
     Note:
         During the training process, if dataset is transmitted through the data channel,
-        suggest set save_checkpoint_steps be an integer multiple of loop_size.
-        Otherwise there may be deviation in the timing of saving checkpoint.
+        It is suggested to set 'save_checkpoint_steps' to an integer multiple of loop_size.
+        Otherwise, the time to save the checkpoint may be biased.
 
     Args:
         save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.
         save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0.
            Can't be used with save_checkpoint_steps at the same time.
-        keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
+        keep_checkpoint_max (int): Maximum number of checkpoint files can be saved. Default: 5.
         keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
            Can't be used with keep_checkpoint_max at the same time.
-        integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
-            Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
-        async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False
+        integrated_save (bool): Whether to perform integrated save function in automatic model parallel scene.
+            Default: True. Integrated save function is only supported in automatic parallel scene, not supported
+            in manual parallel.
+        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
 
     Raises:
         ValueError: If the input_param is None or 0.
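Read as code, the new Args wording corresponds to usage like the following sketch; the values are illustrative:

from mindspore.train.callback import CheckpointConfig

# Save every 100 steps, keep at most 10 checkpoint files, and write files
# asynchronously. save_checkpoint_seconds is left at its default because it
# cannot be combined with save_checkpoint_steps.
config = CheckpointConfig(save_checkpoint_steps=100,
                          keep_checkpoint_max=10,
                          async_save=True)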
@@ -180,9 +181,9 @@ class ModelCheckpoint(Callback):
     It is called to combine with train process and save the model and network parameters after traning.
 
     Args:
-        prefix (str): Checkpoint files names prefix. Default: "CKP".
-        directory (str): Folder path into which checkpoint files will be saved. Default: None.
-        config (CheckpointConfig): Checkpoint strategy config. Default: None.
+        prefix (str): The prefix name of checkpoint files. Default: "CKP".
+        directory (str): The path of the folder which will be saved in the checkpoint file. Default: None.
+        config (CheckpointConfig): Checkpoint strategy configuration. Default: None.
 
     Raises:
         ValueError: If the prefix is invalid.
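A hedged usage sketch tying ModelCheckpoint to its config; model and dataset are placeholders, so the training call is left as a comment:

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory="./checkpoints", config=config)
# model.train(epoch=10, train_dataset=dataset, callbacks=[ckpt_cb])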
@@ -27,13 +27,13 @@ class LossMonitor(Callback):
     If the loss is NAN or INF, it will terminate training.
 
     Note:
-        If per_print_times is 0 do not print loss.
+        If per_print_times is 0, do not print loss.
 
     Args:
-        per_print_times (int): Print loss every times. Default: 1.
+        per_print_times (int): Print the loss each every time. Default: 1.
 
     Raises:
-        ValueError: If print_step is not int or less than zero.
+        ValueError: If print_step is not an integer or less than zero.
     """
 
     def __init__(self, per_print_times=1):
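For example (the interval is arbitrary; model and dataset are placeholders):

from mindspore.train.callback import LossMonitor

loss_cb = LossMonitor(per_print_times=100)   # print the loss every 100 steps
# model.train(epoch=10, train_dataset=dataset, callbacks=[loss_cb])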
@@ -62,7 +62,7 @@ class SummaryCollector(Callback):
     SummaryCollector can help you to collect some common information.
 
     It can help you to collect loss, learning late, computational graph and so on.
-    SummaryCollector also persists data collected by the summary operator into a summary file.
+    SummaryCollector also enables the summary operator to collect data from a summary file.
 
     Note:
         1. Multiple SummaryCollector instances in callback list are not allowed.
@@ -74,51 +74,51 @@ class SummaryCollector(Callback):
             If the directory does not exist, it will be created automatically.
         collect_freq (int): Set the frequency of data collection, it should be greater then zero,
             and the unit is `step`. Default: 10. If a frequency is set, we will collect data
-            at (current steps % freq) == 0, and the first step will be collected at any time.
+            when (current steps % freq) equals to 0, and the first step will be collected at any time.
             It is important to note that if the data sink mode is used, the unit will become the `epoch`.
             It is not recommended to collect data too frequently, which can affect performance.
         collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
             By default, if set to None, all data is collected as the default behavior.
-            If you want to customize the data collected, you can do so with a dictionary.
-            Examples,you can set {'collect_metric': False} to control not collecting metrics.
+            You can customize the collected data with a dictionary.
+            For example, you can set {'collect_metric': False} to control not collecting metrics.
             The data that supports control is shown below.
 
-            - collect_metric: Whether to collect training metrics, currently only loss is collected.
-              The first output will be treated as loss, and it will be averaged.
+            - collect_metric: Whether to collect training metrics, currently only the loss is collected.
+              The first output will be treated as the loss and it will be averaged.
              Optional: True/False. Default: True.
-            - collect_graph: Whether to collect computational graph, currently only
+            - collect_graph: Whether to collect the computational graph. Currently, only
              training computational graph is collected. Optional: True/False. Default: True.
            - collect_train_lineage: Whether to collect lineage data for the training phase,
              this field will be displayed on the lineage page of Mindinsight. Optional: True/False. Default: True.
-            - collect_eval_lineage: Whether to collect lineage data for the eval phase,
+            - collect_eval_lineage: Whether to collect lineage data for the evaluation phase,
              this field will be displayed on the lineage page of Mindinsight. Optional: True/False. Default: True.
            - collect_input_data: Whether to collect dataset for each training. Currently only image data is supported.
              Optional: True/False. Default: True.
            - collect_dataset_graph: Whether to collect dataset graph for the training phase.
              Optional: True/False. Default: True.
-            - histogram_regular: Collect weight and bias for parameter distribution page display in MindInsight.
+            - histogram_regular: Collect weight and bias for parameter distribution page and displayed in MindInsight.
              This field allows regular strings to control which parameters to collect.
              Default: None, it means only the first five parameters are collected.
              It is not recommended to collect too many parameters at once, as it can affect performance.
              Note that if you collect too many parameters and run out of memory, the training will fail.
        keep_default_action (bool): This field affects the collection behavior of the 'collect_specified_data' field.
            Optional: True/False, Default: True.
-            True: means that after specified data is set, non-specified data is collected as the default behavior.
-            False: means that after specified data is set, only the specified data is collected,
+            True: it means that after specified data is set, non-specified data is collected as the default behavior.
+            False: it means that after specified data is set, only the specified data is collected,
            and the others are not collected.
        custom_lineage_data (Union[dict, None]): Allows you to customize the data and present it on the MingInsight
-            lineage page. In the custom data, the key type support str, and the value type support str/int/float.
-            Default: None, it means there is no custom data.
-        collect_tensor_freq (Optional[int]): Same semantic as the `collect_freq`, but controls TensorSummary only.
-            Because TensorSummary data is too large compared to other summary data, this parameter is used to reduce
-            its collection. By default, TensorSummary data will be collected at most 20 steps, but not more than how
-            many steps other summary data will be collected.
+            lineage page. In the custom data, the type of the key supports str, and the type of value supports str, int
+            and float. Default: None, it means there is no custom data.
+        collect_tensor_freq (Optional[int]): The same semantics as the `collect_freq`, but controls TensorSummary only.
+            Because TensorSummary data is too large to be compared with other summary data, this parameter is used to
+            reduce its collection. By default, The maximum number of steps for collecting TensorSummary data is 21,
+            but it will not exceed the number of steps for collecting other summary data.
            Default: None, which means to follow the behavior as described above. For example, given `collect_freq=10`,
            when the total steps is 600, TensorSummary will be collected 20 steps, while other summary data 61 steps,
            but when the total steps is 20, both TensorSummary and other summary will be collected 3 steps.
            Also note that when in parallel mode, the total steps will be splitted evenly, which will
-            affect how many steps TensorSummary will be collected.
-        max_file_size (Optional[int]): The maximum size in bytes each file can be written to the disk.
+            affect the number of steps TensorSummary will be collected.
+        max_file_size (Optional[int]): The maximum size in bytes of each file that can be written to the disk.
            Default: None, which means no limit. For example, to write not larger than 4GB,
            specify `max_file_size=4 * 1024**3`.
 
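The interaction between collect_specified_data and keep_default_action is easiest to see in code. A hedged sketch; the constructor's first parameter name (summary_dir) is not shown in this hunk and is an assumption:

from mindspore.train.callback import SummaryCollector

# Collect everything except metrics every 50 steps; TensorSummary data gets
# its own, slower schedule, and files are capped at 4GB.
summary_cb = SummaryCollector(summary_dir="./summary",
                              collect_freq=50,
                              collect_specified_data={'collect_metric': False},
                              keep_default_action=True,
                              collect_tensor_freq=200,
                              max_file_size=4 * 1024**3)
# model.train(epoch=10, train_dataset=dataset, callbacks=[summary_cb])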
@@ -41,7 +41,7 @@ class FixedLossScaleManager(LossScaleManager):
 
     Args:
         loss_scale (float): Loss scale. Default: 128.0.
-        drop_overflow_update (bool): whether to do optimizer if there is overflow. Default: True.
+        drop_overflow_update (bool): whether to execute optimizer if there is an overflow. Default: True.
 
     Examples:
         >>> loss_scale_manager = FixedLossScaleManager()
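The docstring's example stops at construction; a hedged sketch of how the manager is usually handed on (the Model arguments other than loss_scale_manager are placeholders):

from mindspore.train.loss_scale_manager import FixedLossScaleManager

# drop_overflow_update=True skips the optimizer step when an overflow is
# detected instead of applying a corrupted update.
loss_scale_manager = FixedLossScaleManager(loss_scale=1024.0, drop_overflow_update=True)
# model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager)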
@@ -59,7 +59,7 @@ class FixedLossScaleManager(LossScaleManager):
         return self._loss_scale
 
     def get_drop_overflow_update(self):
-        """Get the flag whether to drop optimizer update when there is overflow happened"""
+        """Get the flag whether to drop optimizer update when there is an overflow."""
         return self._drop_overflow_update
 
     def update_loss_scale(self, overflow):
@@ -82,7 +82,7 @@ class DynamicLossScaleManager(LossScaleManager):
     Dynamic loss-scale manager.
 
     Args:
-        init_loss_scale (float): Init loss scale. Default: 2**24.
+        init_loss_scale (float): Initialize loss scale. Default: 2**24.
         scale_factor (int): Coefficient of increase and decrease. Default: 2.
         scale_window (int): Maximum continuous normal steps when there is no overflow. Default: 2000.
 
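Spelled out with the documented defaults; reading the two parameters as "divide the scale by scale_factor on overflow, multiply it back after scale_window overflow-free steps" is an inference from the Args text, not stated in this hunk:

from mindspore.train.loss_scale_manager import DynamicLossScaleManager

manager = DynamicLossScaleManager(init_loss_scale=2**24,
                                  scale_factor=2,
                                  scale_window=2000)
# model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=manager)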
@@ -135,7 +135,7 @@ class DynamicLossScaleManager(LossScaleManager):
         self.cur_iter += 1
 
     def get_drop_overflow_update(self):
-        """Get the flag whether to drop optimizer update when there is overflow happened"""
+        """Get the flag whether to drop optimizer update when there is an overflow."""
         return True
 
     def get_update_cell(self):
@@ -13,11 +13,11 @@
 # limitations under the License.
 # ============================================================================
 """
-quantization.
+Quantization.
 
 User can use quantization aware to train a model. MindSpore supports quantization aware training,
 which models quantization errors in both the forward and backward passes using fake-quantization
-ops. Note that the entire computation is carried out in floating point. At the end of quantization
+operations. Note that the entire computation is carried out in floating point. At the end of quantization
 aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
 """
 
@@ -145,10 +145,10 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
     Saves checkpoint info to a specified file.
 
     Args:
-        parameter_list (list): Parameters list, each element is a dict
+        parameter_list (list): Parameters list, each element is a dictionary
             like {"name":xx, "type":xx, "shape":xx, "data":xx}.
         ckpt_file_name (str): Checkpoint file name.
-        async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False
+        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
 
     Raises:
         RuntimeError: Failed to save the Checkpoint file.
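The Args entry written out as a call; the parameter name and shape are illustrative, and whether the "type" and "shape" keys are mandatory is not shown in this hunk:

import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint

# One dictionary per parameter; "type" and "shape" are omitted here.
params = [{"name": "conv1.weight",
           "data": Tensor(np.zeros([64, 3, 7, 7], np.float32))}]
save_checkpoint(params, "resnet.ckpt", async_save=True)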
@@ -271,10 +271,10 @@ def load_param_into_net(net, parameter_dict):
 
     Args:
         net (Cell): Cell network.
-        parameter_dict (dict): Parameter dict.
+        parameter_dict (dict): Parameter dictionary.
 
     Raises:
-        TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict.
+        TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
     """
     if not isinstance(net, nn.Cell):
         logger.error("Failed to combine the net and the parameters.")
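In practice the parameter dictionary usually comes from the companion load_checkpoint function in the same module, so the two calls pair up like this (network and file name are illustrative):

import mindspore.nn as nn
from mindspore.train.serialization import load_checkpoint, load_param_into_net

net = nn.Dense(3, 4)                         # any nn.Cell works here
param_dict = load_checkpoint("resnet.ckpt")  # name -> Parameter mapping
load_param_into_net(net, param_dict)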
@@ -450,12 +450,12 @@ def _fill_param_into_net(net, parameter_list):
 
 def export(net, *inputs, file_name, file_format='AIR'):
     """
-    Exports MindSpore predict model to file in specified format.
+    Export the MindSpore prediction model to a file in the specified format.
 
     Args:
         net (Cell): MindSpore network.
         inputs (Tensor): Inputs of the `net`.
-        file_name (str): File name of model to export.
+        file_name (str): File name of the model to be exported.
         file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
 
         - AIR: Ascend Intermidiate Representation. An intermidiate representation format of Ascend model.
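A sketch of the signature in use; the network, input shape and file name are all illustrative:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.train.serialization import export

net = nn.Dense(3, 4)                               # placeholder network
dummy_input = Tensor(np.ones([1, 3], np.float32))  # shape the net expects
export(net, dummy_input, file_name="dense.mindir", file_format="MINDIR")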
@@ -510,7 +510,7 @@ def parse_print(print_file_name):
     Loads Print data from a specified file.
 
     Args:
-        print_file_name (str): The file name of save print data.
+        print_file_name (str): The file name of saved print data.
 
     Returns:
         List, element of list is Tensor.
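For example (the file name is a placeholder for data saved by the Print operator):

from mindspore.train.serialization import parse_print

tensor_list = parse_print("print.data")   # returns a list of Tensor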
@@ -64,29 +64,29 @@ class SummaryRecord:
     SummaryRecord is used to record the summary data and lineage data.
 
     The API will create a summary file and lineage files lazily in a given directory and writes data to them.
-    It writes the data to files by executing the 'record' method. In addition to record the data bubbled up from
+    It writes the data to files by executing the 'record' method. In addition to recording the data bubbled up from
     the network by defining the summary operators, SummaryRecord also supports to record extra data which
     can be added by calling add_value.
 
     Note:
-        1. Make sure to close the SummaryRecord at the end, or the process will not exit.
-           Please see the Example section below on how to properly close with two ways.
-        2. The SummaryRecord instance can only allow one at a time, otherwise it will cause problems with data writes.
+        1. Make sure to close the SummaryRecord at the end, otherwise the process will not exit.
+           Please see the Example section below to learn how to close properly in two ways.
+        2. Only one SummaryRecord instance is allowed at a time, otherwise it will cause data writing problems.
 
     Args:
         log_dir (str): The log_dir is a directory location to save the summary.
         queue_max_size (int): Deprecated. The capacity of event queue.(reserved). Default: 0.
-        flush_time (int): Deprecated. Frequency to flush the summaries to disk, the unit is second. Default: 120.
+        flush_time (int): Deprecated. Frequency of flush the summary file to disk. The unit is second. Default: 120.
         file_prefix (str): The prefix of file. Default: "events".
         file_suffix (str): The suffix of file. Default: "_MS".
         network (Cell): Obtain a pipeline through network for saving graph summary. Default: None.
-        max_file_size (Optional[int]): The maximum size in bytes each file can be written to the disk. \
+        max_file_size (Optional[int]): The maximum size of each file that can be written to disk (in bytes). \
            Unlimited by default. For example, to write not larger than 4GB, specify `max_file_size=4 * 1024**3`.
 
     Raises:
-        TypeError: If `max_file_size`, `queue_max_size` or `flush_time` is not int, \
-            or `file_prefix` and `file_suffix` is not str.
-        RuntimeError: If the log_dir can not be resolved to a canonicalized absolute pathname.
+        TypeError: If the data type of `max_file_size`, `queue_max_size` or `flush_time` is not int, \
+            or the data type of `file_prefix` and `file_suffix` is not str.
+        RuntimeError: If the log_dir is not a normalized absolute path name.
 
     Examples:
         >>> # use in with statement to auto close
@@ -171,10 +171,10 @@ class SummaryRecord:
 
     def set_mode(self, mode):
         """
-        Set the mode for the recorder to be aware. The mode is set 'train' by default.
+        Set the mode for the recorder to be aware. The mode is set to 'train' by default.
 
         Args:
-            mode (str): The mode to set, which should be 'train' or 'eval'.
+            mode (str): The mode to be set, which should be 'train' or 'eval'.
 
         Raises:
             ValueError: When the mode is not recognized.
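For example, switching the recorder before logging evaluation data (directory name illustrative):

from mindspore.train.summary import SummaryRecord

with SummaryRecord("./summary") as record:
    record.set_mode('eval')   # raises ValueError for anything but 'train'/'eval'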
@@ -190,29 +190,30 @@ class SummaryRecord:
 
     def add_value(self, plugin, name, value):
         """
-        Add value to be record later on.
+        Add value to be recorded later.
 
         When the plugin is 'tensor', 'scalar', 'image' or 'histogram',
         the name should be the tag name, and the value should be a Tensor.
 
-        When the plugin plugin is 'graph', the value should be a GraphProto.
+        When the plugin is 'graph', the value should be a GraphProto.
 
-        When the plugin 'dataset_graph', 'train_lineage', 'eval_lineage',
+        When the plugin is 'dataset_graph', 'train_lineage', 'eval_lineage',
         or 'custom_lineage_data', the value should be a proto message.
 
 
         Args:
-            plugin (str): The plugin for the value.
-            name (str): The name for the value.
+            plugin (str): The value of the plugin.
+            name (str): The value of the name.
             value (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): \
                The value to store.
 
-                - GraphProto: The 'value' should be a serialized string this type when the plugin is 'graph'.
-                - Tensor: The 'value' should be this type when the plugin is 'scalar', 'image', 'tensor' or 'histogram'.
-                - TrainLineage: The 'value' should be this type when the plugin is 'train_lineage'.
-                - EvaluationLineage: The 'value' should be this type when the plugin is 'eval_lineage'.
-                - DatasetGraph: The 'value' should be this type when the plugin is 'dataset_graph'.
-                - UserDefinedInfo: The 'value' should be this type when the plugin is 'custom_lineage_data'.
+                - The data type of value should be 'GraphProto' when the plugin is 'graph'.
+                - The data type of value should be 'Tensor' when the plugin is 'scalar', 'image', 'tensor'
+                  or 'histogram'.
+                - The data type of value should be 'TrainLineage' when the plugin is 'train_lineage'.
+                - The data type of value should be 'EvaluationLineage' when the plugin is 'eval_lineage'.
+                - The data type of value should be 'DatasetGraph' when the plugin is 'dataset_graph'.
+                - The data type of value should be 'UserDefinedInfo' when the plugin is 'custom_lineage_data'.
 
         Raises:
             ValueError: When the name is not valid.
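The plugin/name/value triple in code form; the tag and the scalar are illustrative, and the with-statement form follows the class docstring's own example:

import numpy as np
from mindspore import Tensor
from mindspore.train.summary import SummaryRecord

with SummaryRecord("./summary") as record:
    # plugin='scalar' expects a Tensor value; the tag names the curve.
    record.add_value('scalar', 'loss', Tensor(np.array(0.42, np.float32)))
    record.record(step=1)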
@@ -248,9 +249,9 @@ class SummaryRecord:
 
         Args:
             step (int): Represents training step number.
-            train_network (Cell): The network that called the callback.
+            train_network (Cell): The network to call the callback.
             plugin_filter (Optional[Callable[[str], bool]]): The filter function, \
-                which is used to filter out plugins from being written by return False.
+                which is used to filter out plugins from being written by returning False.
 
         Returns:
             bool, whether the record process is successful or not.
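plugin_filter in action, as a hedged sketch: write everything except tensor data for a given step.

from mindspore.train.summary import SummaryRecord

with SummaryRecord("./summary") as record:
    # The filter receives each plugin name and returns False to suppress it.
    record.record(step=100, plugin_filter=lambda plugin: plugin != 'tensor')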
@@ -342,7 +343,7 @@ class SummaryRecord:
 
     def close(self):
         """
-        Flush all events and close summary records. Please use with statement to autoclose.
+        Flush all events and close summary records. Please use the statement to autoclose.
 
         Examples:
             >>> try: