forked from mindspore-Ecosystem/mindspore
modify comment1
This commit is contained in:
parent
f316b1a88f
commit
832fbc1280
|
@ -1,11 +1,11 @@
|
|||
.. py:class:: mindspore.train.callback.Callback
|
||||
|
||||
用于构建回调函数的基类。回调函数是一个上下文管理器,在运行模型时被调用。
|
||||
可以使用此机制进行初始化和释放资源等操作。
|
||||
用于构建Callback函数的基类。Callback函数是一个上下文管理器,在运行模型时被调用。
|
||||
可以使用此机制进行一些自定义操作。
|
||||
|
||||
回调函数可以在step或epoch中的执行一些操作。
|
||||
|
||||
它保存模型相关信息。例如 `network` 、 `train_network` 、 `epoch_num` 、 `batch_num` 、 `loss_fn` 、 `optimizer` 、 `parallel_mode` 、 `device_number` 、 `list_callback` 、 `cur_epoch_num` 、 `cur_step_num` 、 `dataset_sink_mode` 、 `net_outputs` 等。
|
||||
Callback函数可以在step或epoch开始前或结束后执行一些操作。
|
||||
要创建自定义Callback,需要继承Callback基类并重载它相应的方法,有关自定义Callback的详细信息,请查看
|
||||
`Callback <https://www.mindspore.cn/docs/programming_guide/zh-CN/master/custom_debugging_info.html>`_。
|
||||
|
||||
**样例:**
|
||||
|
||||
|
|
|
@ -63,53 +63,97 @@
|
|||
.. py:method:: append_dict
|
||||
:property:
|
||||
|
||||
获取checkpoint中添加字典里面的值。
|
||||
获取需要额外保存到checkpoint中的字典的值。
|
||||
|
||||
**返回:**
|
||||
|
||||
Dict: 字典中的值。
|
||||
|
||||
.. py:method:: async_save
|
||||
:property:
|
||||
|
||||
获取是否异步保存checkpoint。
|
||||
|
||||
**返回:**
|
||||
|
||||
Bool: 是否异步保存checkpoint。
|
||||
|
||||
.. py:method:: enc_key
|
||||
:property:
|
||||
|
||||
获取加密的key值。
|
||||
|
||||
**返回:**
|
||||
|
||||
(None, bytes): 加密的key值。
|
||||
|
||||
.. py:method:: enc_mode
|
||||
:property:
|
||||
|
||||
获取加密模式。
|
||||
|
||||
**返回:**
|
||||
|
||||
str: 加密模式。
|
||||
|
||||
.. py:method:: get_checkpoint_policy()
|
||||
|
||||
获取checkpoint的保存策略。
|
||||
|
||||
**返回:**
|
||||
|
||||
Dict: checkpoint的保存策略。
|
||||
|
||||
.. py:method:: integrated_save
|
||||
:property:
|
||||
|
||||
获取是否合并保存拆分后的Tensor。
|
||||
|
||||
**返回:**
|
||||
|
||||
Bool: 获取是否合并保存拆分后的Tensor。
|
||||
|
||||
.. py:method:: keep_checkpoint_max
|
||||
:property:
|
||||
|
||||
获取最多保存checkpoint文件的数量。
|
||||
|
||||
**返回:**
|
||||
|
||||
Int: 最多保存checkpoint文件的数量。
|
||||
|
||||
.. py:method:: keep_checkpoint_per_n_minutes
|
||||
:property:
|
||||
|
||||
获取每隔多少分钟保存一个checkpoint文件。
|
||||
|
||||
**返回:**
|
||||
|
||||
Int: 每隔多少分钟保存一个checkpoint文件。
|
||||
|
||||
.. py:method:: saved_network
|
||||
:property:
|
||||
|
||||
获取保存的网络。
|
||||
获取需要保存的网络。
|
||||
|
||||
**返回:**
|
||||
|
||||
Cell: 需要保存的网络。
|
||||
|
||||
.. py:method:: save_checkpoint_seconds
|
||||
:property:
|
||||
|
||||
获取每隔多少秒保存一次checkpoint文件。。
|
||||
获取每隔多少秒保存一次checkpoint文件。
|
||||
|
||||
**返回:**
|
||||
|
||||
Int: 每隔多少秒保存一次checkpoint文件。
|
||||
|
||||
.. py:method:: save_checkpoint_steps
|
||||
:property:
|
||||
|
||||
获取每隔多少个step保存一次checkpoint文件。
|
||||
|
||||
**返回:**
|
||||
|
||||
Int: 每隔多少个step保存一次checkpoint文件。
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
.. py:class:: mindspore.train.callback.LearningRateScheduler(learning_rate_function)
|
||||
|
||||
在训练期间更改学习率。
|
||||
用于在训练期间更改学习率。
|
||||
|
||||
**参数:**
|
||||
|
||||
|
|
|
@ -21,4 +21,4 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
提供模型的相关信息。
|
||||
|
||||
在Model方法里提供模型的相关信息。
|
||||
回调函数可以通过调用 `request_stop()` 方法来停止循环。
|
||||
回调函数可以调用 `request_stop()` 方法来停止迭代。
|
||||
|
||||
**参数:**
|
||||
|
||||
|
@ -11,7 +11,7 @@
|
|||
|
||||
.. py:method:: get_stop_requested()
|
||||
|
||||
获取是否停止训练标志。
|
||||
获取是否停止训练的标志。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -19,11 +19,11 @@
|
|||
|
||||
.. py:method:: original_args()
|
||||
|
||||
获取模型的相关信息。
|
||||
获取模型相关信息的对象。
|
||||
|
||||
**返回:**
|
||||
|
||||
dict,模型的相关信息。
|
||||
dict,含有模型的相关信息的对象。
|
||||
|
||||
.. py:method:: request_stop()
|
||||
|
||||
|
|
|
@ -5,9 +5,8 @@ mindspore.async_ckpt_thread_status
|
|||
|
||||
获取异步保存checkpoint文件线程的状态。
|
||||
|
||||
在执行异步保存checkpoint时,可以通过该函数获取线程状态以确保写入checkpoint文件已完成。
|
||||
在执行异步保存checkpoint时,判断异步线程是否执行完毕。
|
||||
|
||||
**返回:**
|
||||
|
||||
True,异步保存checkpoint线程正在运行。
|
||||
False,异步保存checkpoint线程未运行。
|
||||
Bool: True,异步保存checkpoint线程正在运行。False,异步保存checkpoint线程未运行。
|
||||
|
|
|
@ -5,22 +5,22 @@ mindspore.load
|
|||
|
||||
加载MindIR文件。
|
||||
|
||||
返回的对象可以由 `GraphCell` 执行,更多细节参见类 :class:`mindspore.nn.GraphCell` 。
|
||||
返回一个可以由 `GraphCell` 执行的对象,更多细节参见类 :class:`mindspore.nn.GraphCell`。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **file_name** (str) – MindIR文件名。
|
||||
- **file_name** (str) – MindIR文件的全路径名。
|
||||
- **kwargs** (dict) – 配置项字典。
|
||||
- **dec_key** (bytes) - 用于解密的字节类型密钥。 有效长度为 16、24 或 32。
|
||||
- **dec_mode** - 指定解密模式,设置dec_key时生效。可选项:'AES-GCM' | 'AES-CBC'。 默认值:“AES-GCM”。
|
||||
- **dec_mode** (str) - 指定解密模式,设置dec_key时生效。可选项:'AES-GCM' | 'AES-CBC'。 默认值:"AES-GCM"。
|
||||
|
||||
**返回:**
|
||||
|
||||
Object,一个可以由 `GraphCell` 构成的可执行的编译图。
|
||||
GraphCell,一个可以由 `GraphCell` 构成的可执行的编译图。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** – MindIR 文件名不正确。
|
||||
- **ValueError** – MindIR 文件名不存在或`file_name`不是string类型。
|
||||
- **RuntimeError** - 解析MindIR文件失败。
|
||||
|
||||
**样例:**
|
||||
|
|
|
@ -20,7 +20,7 @@ mindspore.load_checkpoint
|
|||
|
||||
**异常:**
|
||||
|
||||
- **ValueError** – checkpoint文件格式正确。
|
||||
- **ValueError** – checkpoint文件格式不正确。
|
||||
|
||||
**样例:**
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ mindspore.load_param_into_net
|
|||
|
||||
.. py:class:: mindspore.load_param_into_net(net, parameter_dict, strict_load=False)
|
||||
|
||||
将参数加载到网络中。
|
||||
将参数加载到网络中,返回网络中没有被加载的参数列表。
|
||||
|
||||
**参数:**
|
||||
|
||||
|
|
|
@ -3,13 +3,11 @@ mindspore.parse_print
|
|||
|
||||
.. py:class:: mindspore.parse_print(print_file_name)
|
||||
|
||||
解析由 mindspore.ops.Print 生成的保存数据。
|
||||
|
||||
将数据打印到屏幕上。也可以通过设置 `context` 中的参数 `print_file_path` 来关闭,数据会保存在 `print_file_path` 指定的文件中。 parse_print 用于解析保存的文件。 更多信息请参考 :func:`mindspore.context.set_context` 和 :class:`mindspore.ops.Print` 。
|
||||
解析由 mindspore.ops.Print 生成的数据文件。
|
||||
|
||||
**参数:**
|
||||
|
||||
**print_file_name** (str) – 保存打印数据的文件名。
|
||||
**print_file_name** (str) – 需要解析的文件名。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -17,12 +15,12 @@ mindspore.parse_print
|
|||
|
||||
**异常:**
|
||||
|
||||
**ValueError** – 指定的文件名可能为空,请确保输入正确的文件名。
|
||||
**ValueError** – 指定的文件不存在或为空。
|
||||
**RuntimeError** - 解析文件失败。
|
||||
|
||||
**样例:**
|
||||
|
||||
>>> import numpy as np
|
||||
>>> import mindspore
|
||||
>>> import mindspore.ops as ops
|
||||
>>> from mindspore.nn as nn
|
||||
>>> from mindspore import Tensor, context
|
||||
|
@ -40,8 +38,10 @@ mindspore.parse_print
|
|||
>>> input_pra = Tensor(x)
|
||||
>>> net = PrintInputTensor()
|
||||
>>> net(input_pra)
|
||||
|
||||
>>> import mindspore
|
||||
>>> data = mindspore.parse_print('./log.data')
|
||||
>>> print(data)
|
||||
['print:', Tensor(shape=[2, 4], dtype=Float32, value=
|
||||
[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
|
||||
[ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
|
||||
[ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
|
||||
|
|
|
@ -77,17 +77,13 @@ class Callback:
|
|||
"""
|
||||
Abstract base class used to build a callback class. Callbacks are context managers
|
||||
which will be entered and exited when passing into the Model.
|
||||
You can use this mechanism to initialize and release resources automatically.
|
||||
You can use this mechanism to do some custom operations.
|
||||
|
||||
Callback function will execute some operations in the current step or epoch.
|
||||
Callback function can perform some operations before and after step or epoch.
|
||||
To create a custom callback, subclass Callback and override the method associated
|
||||
with the stage of interest. For details of Callback fusion, please check
|
||||
`Callback <https://www.mindspore.cn/docs/programming_guide/zh-CN/master/custom_debugging_info.html>`_.
|
||||
|
||||
It holds the information of the model. Such as `network`, `train_network`, `epoch_num`, `batch_num`,
|
||||
`loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`,
|
||||
`cur_step_num`, `dataset_sink_mode`, `net_outputs` and so on.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Model, nn
|
||||
>>> from mindspore.train.callback import Callback
|
||||
|
|
|
@ -198,7 +198,12 @@ class CheckpointConfig:
|
|||
|
||||
@property
|
||||
def save_checkpoint_steps(self):
|
||||
"""Get the value of _save_checkpoint_steps."""
|
||||
"""
|
||||
Get the value of steps to save checkpoint.
|
||||
|
||||
Returns:
|
||||
Int, steps to save checkpoint.
|
||||
"""
|
||||
return self._save_checkpoint_steps
|
||||
|
||||
@property
|
||||
|
@ -208,46 +213,91 @@ class CheckpointConfig:
|
|||
|
||||
@property
|
||||
def keep_checkpoint_max(self):
|
||||
"""Get the value of _keep_checkpoint_max."""
|
||||
"""
|
||||
Get the value of maximum number of checkpoint files can be saved.
|
||||
|
||||
Returns:
|
||||
Int, Maximum number of checkpoint files can be saved.
|
||||
"""
|
||||
return self._keep_checkpoint_max
|
||||
|
||||
@property
|
||||
def keep_checkpoint_per_n_minutes(self):
|
||||
"""Get the value of _keep_checkpoint_per_n_minutes."""
|
||||
"""
|
||||
Get the value of save the checkpoint file every n minutes.
|
||||
|
||||
Returns:
|
||||
Int, save the checkpoint file every n minutes.
|
||||
"""
|
||||
return self._keep_checkpoint_per_n_minutes
|
||||
|
||||
@property
|
||||
def integrated_save(self):
|
||||
"""Get the value of _integrated_save."""
|
||||
"""
|
||||
Get the value of whether to merge and save the split Tensor in the automatic parallel scenario.
|
||||
|
||||
Returns:
|
||||
Bool, whether to merge and save the split Tensor in the automatic parallel scenario.
|
||||
"""
|
||||
return self._integrated_save
|
||||
|
||||
@property
|
||||
def async_save(self):
|
||||
"""Get the value of _async_save."""
|
||||
"""
|
||||
Get the value of whether asynchronous execution saves the checkpoint to a file.
|
||||
|
||||
Returns:
|
||||
Bool, whether asynchronous execution saves the checkpoint to a file.
|
||||
"""
|
||||
return self._async_save
|
||||
|
||||
@property
|
||||
def saved_network(self):
|
||||
"""Get the value of _saved_network"""
|
||||
"""
|
||||
Get the value of network to be saved in checkpoint file.
|
||||
|
||||
Returns:
|
||||
Cell, network to be saved in checkpoint file.
|
||||
"""
|
||||
return self._saved_network
|
||||
|
||||
@property
|
||||
def enc_key(self):
|
||||
"""Get the value of _enc_key"""
|
||||
"""
|
||||
Get the value of byte type key used for encryption.
|
||||
|
||||
Returns:
|
||||
(None, bytes), byte type key used for encryption.
|
||||
"""
|
||||
return self._enc_key
|
||||
|
||||
@property
|
||||
def enc_mode(self):
|
||||
"""Get the value of _enc_mode"""
|
||||
"""
|
||||
Get the value of the encryption mode.
|
||||
|
||||
Returns:
|
||||
str, encryption mode.
|
||||
"""
|
||||
return self._enc_mode
|
||||
|
||||
@property
|
||||
def append_dict(self):
|
||||
"""Get the value of append_dict."""
|
||||
"""
|
||||
Get the value of information dict saved to checkpoint file.
|
||||
|
||||
Returns:
|
||||
Dict, the information saved to checkpoint file.
|
||||
"""
|
||||
return self._append_dict
|
||||
|
||||
def get_checkpoint_policy(self):
|
||||
"""Get the policy of checkpoint."""
|
||||
"""
|
||||
Get the policy of checkpoint.
|
||||
|
||||
Returns:
|
||||
Dict, the information of checkpoint policy.
|
||||
"""
|
||||
checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
|
||||
'save_checkpoint_seconds': self.save_checkpoint_seconds,
|
||||
'keep_checkpoint_max': self.keep_checkpoint_max,
|
||||
|
|
|
@ -343,10 +343,10 @@ def load(file_name, **kwargs):
|
|||
- dec_mode (str): Specifies the decryption mode, to take effect when dec_key is set.
|
||||
Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
|
||||
Returns:
|
||||
Object, a compiled graph that can executed by `GraphCell`.
|
||||
GraphCell, a compiled graph that can executed by `GraphCell`.
|
||||
|
||||
Raises:
|
||||
ValueError: MindIR file name is incorrect.
|
||||
ValueError: MindIR file does not exist or `file_name` is not a string.
|
||||
RuntimeError: Failed to parse MindIR file.
|
||||
|
||||
Examples:
|
||||
|
@ -417,7 +417,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
Dict, key is parameter name, value is a Parameter.
|
||||
|
||||
Raises:
|
||||
ValueError: Checkpoint file is incorrect.
|
||||
ValueError: Checkpoint file's format is incorrect.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import load_checkpoint
|
||||
|
@ -535,7 +535,7 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
|
|||
|
||||
def load_param_into_net(net, parameter_dict, strict_load=False):
|
||||
"""
|
||||
Load parameters into network.
|
||||
Load parameters into network, return parameter list that are not loaded in the network.
|
||||
|
||||
Args:
|
||||
net (Cell): The network where the parameters will be loaded.
|
||||
|
@ -546,7 +546,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|||
on the parameters of the same type, such as float32 to float16. Default: False.
|
||||
|
||||
Returns:
|
||||
List, parameter name not loaded into the network
|
||||
List, the parameter name which are not loaded into the network.
|
||||
|
||||
Raises:
|
||||
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
|
||||
|
@ -878,9 +878,7 @@ def _generate_front_info_for_param_data_file(is_encrypt, kwargs):
|
|||
|
||||
|
||||
def _change_file(f, dirname, external_local, is_encrypt, kwargs):
|
||||
'''
|
||||
Change to another file to write parameter data
|
||||
'''
|
||||
"""Change to another file to write parameter data."""
|
||||
# The parameter has been not written in the file
|
||||
front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs)
|
||||
f.seek(0, 0)
|
||||
|
@ -895,9 +893,7 @@ def _change_file(f, dirname, external_local, is_encrypt, kwargs):
|
|||
|
||||
|
||||
def _get_data_file(is_encrypt, kwargs, data_file_name):
|
||||
'''
|
||||
Get Data File to write parameter data
|
||||
'''
|
||||
"""Get Data File to write parameter data."""
|
||||
# Reserves 64 bytes as spare information such as check data
|
||||
offset = 64
|
||||
if os.path.exists(data_file_name):
|
||||
|
@ -913,9 +909,7 @@ def _get_data_file(is_encrypt, kwargs, data_file_name):
|
|||
|
||||
|
||||
def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
||||
'''
|
||||
The function to save parameter data
|
||||
'''
|
||||
"""The function to save parameter data."""
|
||||
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
|
||||
# save parameter
|
||||
file_prefix = file_name.split("/")[-1]
|
||||
|
@ -1070,10 +1064,7 @@ def _save_dataset_to_mindir(model, dataset):
|
|||
|
||||
|
||||
def quant_mode_manage(func):
|
||||
"""
|
||||
Inherit the quant_mode in old version.
|
||||
"""
|
||||
|
||||
"""Inherit the quant_mode in old version."""
|
||||
def warpper(network, *inputs, file_format, **kwargs):
|
||||
if 'quant_mode' not in kwargs:
|
||||
return network
|
||||
|
@ -1090,9 +1081,7 @@ def quant_mode_manage(func):
|
|||
|
||||
@quant_mode_manage
|
||||
def _quant_export(network, *inputs, file_format, **kwargs):
|
||||
"""
|
||||
Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
|
||||
"""
|
||||
"""Exports MindSpore quantization predict model to deploy with AIR and MINDIR."""
|
||||
supported_device = ["Ascend", "GPU"]
|
||||
supported_formats = ['AIR', 'MINDIR']
|
||||
quant_mode_formats = ['QUANT', 'NONQUANT']
|
||||
|
@ -1130,23 +1119,20 @@ def _quant_export(network, *inputs, file_format, **kwargs):
|
|||
|
||||
def parse_print(print_file_name):
|
||||
"""
|
||||
Parse saved data generated by mindspore.ops.Print. Print is used to print data to screen in graph mode.
|
||||
It can also been turned off by setting the parameter `print_file_path` in `context`, and the data will be saved
|
||||
in a file specified by print_file_path. parse_print is used to parse the saved file. For more information
|
||||
please refer to :func:`mindspore.context.set_context` and :class:`mindspore.ops.Print`.
|
||||
Parse data file generated by mindspore.ops.Print.
|
||||
|
||||
Args:
|
||||
print_file_name (str): The file name of saved print data.
|
||||
print_file_name (str): The file name needs to be parsed.
|
||||
|
||||
Returns:
|
||||
List, element of list is Tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: The print file may be empty, please make sure enter the correct file name.
|
||||
ValueError: The print file does not exist or is empty.
|
||||
RuntimeError: Failed to parse the file.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore
|
||||
>>> import mindspore.ops as ops
|
||||
>>> from mindspore import nn
|
||||
>>> from mindspore import Tensor, context
|
||||
|
@ -1632,12 +1618,11 @@ def async_ckpt_thread_status():
|
|||
"""
|
||||
Get the status of asynchronous save checkpoint thread.
|
||||
|
||||
When performing asynchronous save checkpoint, you can get the thread state through this function
|
||||
to ensure that write checkpoint file is completed.
|
||||
When performing asynchronous save checkpoint, you can determine whether the asynchronous thread is completed.
|
||||
|
||||
Returns:
|
||||
True, Asynchronous save checkpoint thread is running.
|
||||
False, Asynchronous save checkpoint thread is not executing.
|
||||
bool, True, Asynchronous save checkpoint thread is running.
|
||||
False, Asynchronous save checkpoint thread is not executing.
|
||||
"""
|
||||
thr_list = threading.enumerate()
|
||||
return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list]
|
||||
|
|
Loading…
Reference in New Issue