modify comment1

This commit is contained in:
changzherui 2021-12-28 20:07:42 +08:00
parent f316b1a88f
commit 832fbc1280
13 changed files with 153 additions and 79 deletions

View File

@ -1,11 +1,11 @@
.. py:class:: mindspore.train.callback.Callback
用于构建回调函数的基类。回调函数是一个上下文管理器,在运行模型时被调用。
可以使用此机制进行初始化和释放资源等操作。
用于构建Callback函数的基类。Callback函数是一个上下文管理器,在运行模型时被调用。
可以使用此机制进行一些自定义操作。
回调函数可以在step或epoch中的执行一些操作。
它保存模型相关信息。例如 `network``train_network``epoch_num``batch_num``loss_fn``optimizer``parallel_mode``device_number``list_callback``cur_epoch_num``cur_step_num``dataset_sink_mode``net_outputs`
Callback函数可以在step或epoch开始前或结束后执行一些操作。
要创建自定义Callback需要继承Callback基类并重载它相应的方法有关自定义Callback的详细信息请查看
`Callback <https://www.mindspore.cn/docs/programming_guide/zh-CN/master/custom_debugging_info.html>`_
**样例:**

View File

@ -63,53 +63,97 @@
.. py:method:: append_dict
:property:
获取checkpoint中添加字典里面的值。
获取需要额外保存到checkpoint中的字典的值。
**返回:**
Dict: 字典中的值。
.. py:method:: async_save
:property:
获取是否异步保存checkpoint。
**返回:**
Bool: 是否异步保存checkpoint。
.. py:method:: enc_key
:property:
获取加密的key值。
**返回:**
(None, bytes): 加密的key值。
.. py:method:: enc_mode
:property:
获取加密模式。
**返回:**
str: 加密模式。
.. py:method:: get_checkpoint_policy()
获取checkpoint的保存策略。
**返回:**
Dict: checkpoint的保存策略。
.. py:method:: integrated_save
:property:
获取是否合并保存拆分后的Tensor。
**返回:**
Bool: 获取是否合并保存拆分后的Tensor。
.. py:method:: keep_checkpoint_max
:property:
获取最多保存checkpoint文件的数量。
**返回:**
Int: 最多保存checkpoint文件的数量。
.. py:method:: keep_checkpoint_per_n_minutes
:property:
获取每隔多少分钟保存一个checkpoint文件。
**返回:**
Int: 每隔多少分钟保存一个checkpoint文件。
.. py:method:: saved_network
:property:
获取保存的网络。
获取需要保存的网络。
**返回:**
Cell: 需要保存的网络。
.. py:method:: save_checkpoint_seconds
:property:
获取每隔多少秒保存一次checkpoint文件。。
获取每隔多少秒保存一次checkpoint文件。
**返回:**
Int: 每隔多少秒保存一次checkpoint文件。
.. py:method:: save_checkpoint_steps
:property:
获取每隔多少个step保存一次checkpoint文件。
**返回:**
Int: 每隔多少个step保存一次checkpoint文件。

View File

@ -1,6 +1,6 @@
.. py:class:: mindspore.train.callback.LearningRateScheduler(learning_rate_function)
在训练期间更改学习率。
用于在训练期间更改学习率。
**参数:**

View File

@ -21,4 +21,4 @@
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的相关信息。

View File

@ -3,7 +3,7 @@
提供模型的相关信息。
在Model方法里提供模型的相关信息。
回调函数可以通过调用 `request_stop()` 方法来停止循环
回调函数可以调用 `request_stop()` 方法来停止迭代
**参数:**
@ -11,7 +11,7 @@
.. py:method:: get_stop_requested()
获取是否停止训练标志。
获取是否停止训练标志。
**返回:**
@ -19,11 +19,11 @@
.. py:method:: original_args()
获取模型相关信息。
获取模型相关信息的对象
**返回:**
dict模型的相关信息。
dict含有模型的相关信息的对象
.. py:method:: request_stop()

View File

@ -5,9 +5,8 @@ mindspore.async_ckpt_thread_status
获取异步保存checkpoint文件线程的状态。
在执行异步保存checkpoint时可以通过该函数获取线程状态以确保写入checkpoint文件已完成
在执行异步保存checkpoint时判断异步线程是否执行完毕
**返回:**
True异步保存checkpoint线程正在运行。
False异步保存checkpoint线程未运行。
Bool: True异步保存checkpoint线程正在运行。False异步保存checkpoint线程未运行。

View File

@ -5,22 +5,22 @@ mindspore.load
加载MindIR文件。
返回的对象可以由 `GraphCell` 执行,更多细节参见类 :class:`mindspore.nn.GraphCell`
返回一个可以由 `GraphCell` 执行的对象,更多细节参见类 :class:`mindspore.nn.GraphCell`
**参数:**
- **file_name** (str) MindIR文件名。
- **file_name** (str) MindIR文件的全路径名。
- **kwargs** (dict) 配置项字典。
- **dec_key** (bytes) - 用于解密的字节类型密钥。 有效长度为 16、24 或 32。
- **dec_mode** - 指定解密模式设置dec_key时生效。可选项'AES-GCM' | 'AES-CBC'。 默认值:“AES-GCM”
- **dec_mode** (str) - 指定解密模式设置dec_key时生效。可选项'AES-GCM' | 'AES-CBC'。 默认值:"AES-GCM"
**返回:**
Object,一个可以由 `GraphCell` 构成的可执行的编译图。
GraphCell,一个可以由 `GraphCell` 构成的可执行的编译图。
**异常:**
- **ValueError** MindIR 文件名不正确
- **ValueError** MindIR 文件名不存在或`file_name`不是string类型
- **RuntimeError** - 解析MindIR文件失败。
**样例:**

View File

@ -20,7 +20,7 @@ mindspore.load_checkpoint
**异常:**
- **ValueError** checkpoint文件格式正确。
- **ValueError** checkpoint文件格式正确。
**样例:**

View File

@ -3,7 +3,7 @@ mindspore.load_param_into_net
.. py:class:: mindspore.load_param_into_net(net, parameter_dict, strict_load=False)
将参数加载到网络中。
将参数加载到网络中,返回网络中没有被加载的参数列表
**参数:**

View File

@ -3,13 +3,11 @@ mindspore.parse_print
.. py:class:: mindspore.parse_print(print_file_name)
解析由 mindspore.ops.Print 生成的保存数据。
将数据打印到屏幕上。也可以通过设置 `context` 中的参数 `print_file_path` 来关闭,数据会保存在 `print_file_path` 指定的文件中。 parse_print 用于解析保存的文件。 更多信息请参考 :func:`mindspore.context.set_context`:class:`mindspore.ops.Print`
解析由 mindspore.ops.Print 生成的数据文件。
**参数:**
**print_file_name** (str) 保存打印数据的文件名。
**print_file_name** (str) 需要解析的文件名。
**返回:**
@ -17,12 +15,12 @@ mindspore.parse_print
**异常:**
**ValueError** 指定的文件名可能为空,请确保输入正确的文件名。
**ValueError** 指定的文件不存在或为空。
**RuntimeError** - 解析文件失败。
**样例:**
>>> import numpy as np
>>> import mindspore
>>> import mindspore.ops as ops
>>> from mindspore.nn as nn
>>> from mindspore import Tensor, context
@ -40,8 +38,10 @@ mindspore.parse_print
>>> input_pra = Tensor(x)
>>> net = PrintInputTensor()
>>> net(input_pra)
>>> import mindspore
>>> data = mindspore.parse_print('./log.data')
>>> print(data)
['print:', Tensor(shape=[2, 4], dtype=Float32, value=
[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
[ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
[ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]

View File

@ -77,17 +77,13 @@ class Callback:
"""
Abstract base class used to build a callback class. Callbacks are context managers
which will be entered and exited when passing into the Model.
You can use this mechanism to initialize and release resources automatically.
You can use this mechanism to do some custom operations.
Callback function will execute some operations in the current step or epoch.
Callback function can perform some operations before and after step or epoch.
To create a custom callback, subclass Callback and override the method associated
with the stage of interest. For details of Callback fusion, please check
`Callback <https://www.mindspore.cn/docs/programming_guide/zh-CN/master/custom_debugging_info.html>`_.
It holds the information of the model. Such as `network`, `train_network`, `epoch_num`, `batch_num`,
`loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`,
`cur_step_num`, `dataset_sink_mode`, `net_outputs` and so on.
Examples:
>>> from mindspore import Model, nn
>>> from mindspore.train.callback import Callback

View File

@ -198,7 +198,12 @@ class CheckpointConfig:
@property
def save_checkpoint_steps(self):
"""Get the value of _save_checkpoint_steps."""
"""
Get the value of steps to save checkpoint.
Returns:
Int, steps to save checkpoint.
"""
return self._save_checkpoint_steps
@property
@ -208,46 +213,91 @@ class CheckpointConfig:
@property
def keep_checkpoint_max(self):
"""Get the value of _keep_checkpoint_max."""
"""
Get the value of maximum number of checkpoint files can be saved.
Returns:
Int, Maximum number of checkpoint files can be saved.
"""
return self._keep_checkpoint_max
@property
def keep_checkpoint_per_n_minutes(self):
"""Get the value of _keep_checkpoint_per_n_minutes."""
"""
Get the value of save the checkpoint file every n minutes.
Returns:
Int, save the checkpoint file every n minutes.
"""
return self._keep_checkpoint_per_n_minutes
@property
def integrated_save(self):
"""Get the value of _integrated_save."""
"""
Get the value of whether to merge and save the split Tensor in the automatic parallel scenario.
Returns:
Bool, whether to merge and save the split Tensor in the automatic parallel scenario.
"""
return self._integrated_save
@property
def async_save(self):
"""Get the value of _async_save."""
"""
Get the value of whether asynchronous execution saves the checkpoint to a file.
Returns:
Bool, whether asynchronous execution saves the checkpoint to a file.
"""
return self._async_save
@property
def saved_network(self):
"""Get the value of _saved_network"""
"""
Get the value of network to be saved in checkpoint file.
Returns:
Cell, network to be saved in checkpoint file.
"""
return self._saved_network
@property
def enc_key(self):
"""Get the value of _enc_key"""
"""
Get the value of byte type key used for encryption.
Returns:
(None, bytes), byte type key used for encryption.
"""
return self._enc_key
@property
def enc_mode(self):
"""Get the value of _enc_mode"""
"""
Get the value of the encryption mode.
Returns:
str, encryption mode.
"""
return self._enc_mode
@property
def append_dict(self):
"""Get the value of append_dict."""
"""
Get the value of information dict saved to checkpoint file.
Returns:
Dict, the information saved to checkpoint file.
"""
return self._append_dict
def get_checkpoint_policy(self):
"""Get the policy of checkpoint."""
"""
Get the policy of checkpoint.
Returns:
Dict, the information of checkpoint policy.
"""
checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
'save_checkpoint_seconds': self.save_checkpoint_seconds,
'keep_checkpoint_max': self.keep_checkpoint_max,

View File

@ -343,10 +343,10 @@ def load(file_name, **kwargs):
- dec_mode (str): Specifies the decryption mode, to take effect when dec_key is set.
Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
Returns:
Object, a compiled graph that can executed by `GraphCell`.
GraphCell, a compiled graph that can executed by `GraphCell`.
Raises:
ValueError: MindIR file name is incorrect.
ValueError: MindIR file does not exist or `file_name` is not a string.
RuntimeError: Failed to parse MindIR file.
Examples:
@ -417,7 +417,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
Dict, key is parameter name, value is a Parameter.
Raises:
ValueError: Checkpoint file is incorrect.
ValueError: Checkpoint file's format is incorrect.
Examples:
>>> from mindspore import load_checkpoint
@ -535,7 +535,7 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
def load_param_into_net(net, parameter_dict, strict_load=False):
"""
Load parameters into network.
Load parameters into network, return parameter list that are not loaded in the network.
Args:
net (Cell): The network where the parameters will be loaded.
@ -546,7 +546,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
on the parameters of the same type, such as float32 to float16. Default: False.
Returns:
List, parameter name not loaded into the network
List, the parameter name which are not loaded into the network.
Raises:
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
@ -878,9 +878,7 @@ def _generate_front_info_for_param_data_file(is_encrypt, kwargs):
def _change_file(f, dirname, external_local, is_encrypt, kwargs):
'''
Change to another file to write parameter data
'''
"""Change to another file to write parameter data."""
# The parameter has been not written in the file
front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs)
f.seek(0, 0)
@ -895,9 +893,7 @@ def _change_file(f, dirname, external_local, is_encrypt, kwargs):
def _get_data_file(is_encrypt, kwargs, data_file_name):
'''
Get Data File to write parameter data
'''
"""Get Data File to write parameter data."""
# Reserves 64 bytes as spare information such as check data
offset = 64
if os.path.exists(data_file_name):
@ -913,9 +909,7 @@ def _get_data_file(is_encrypt, kwargs, data_file_name):
def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
'''
The function to save parameter data
'''
"""The function to save parameter data."""
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
# save parameter
file_prefix = file_name.split("/")[-1]
@ -1070,10 +1064,7 @@ def _save_dataset_to_mindir(model, dataset):
def quant_mode_manage(func):
"""
Inherit the quant_mode in old version.
"""
"""Inherit the quant_mode in old version."""
def warpper(network, *inputs, file_format, **kwargs):
if 'quant_mode' not in kwargs:
return network
@ -1090,9 +1081,7 @@ def quant_mode_manage(func):
@quant_mode_manage
def _quant_export(network, *inputs, file_format, **kwargs):
"""
Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
"""
"""Exports MindSpore quantization predict model to deploy with AIR and MINDIR."""
supported_device = ["Ascend", "GPU"]
supported_formats = ['AIR', 'MINDIR']
quant_mode_formats = ['QUANT', 'NONQUANT']
@ -1130,23 +1119,20 @@ def _quant_export(network, *inputs, file_format, **kwargs):
def parse_print(print_file_name):
"""
Parse saved data generated by mindspore.ops.Print. Print is used to print data to screen in graph mode.
It can also been turned off by setting the parameter `print_file_path` in `context`, and the data will be saved
in a file specified by print_file_path. parse_print is used to parse the saved file. For more information
please refer to :func:`mindspore.context.set_context` and :class:`mindspore.ops.Print`.
Parse data file generated by mindspore.ops.Print.
Args:
print_file_name (str): The file name of saved print data.
print_file_name (str): The file name needs to be parsed.
Returns:
List, element of list is Tensor.
Raises:
ValueError: The print file may be empty, please make sure enter the correct file name.
ValueError: The print file does not exist or is empty.
RuntimeError: Failed to parse the file.
Examples:
>>> import numpy as np
>>> import mindspore
>>> import mindspore.ops as ops
>>> from mindspore import nn
>>> from mindspore import Tensor, context
@ -1632,12 +1618,11 @@ def async_ckpt_thread_status():
"""
Get the status of asynchronous save checkpoint thread.
When performing asynchronous save checkpoint, you can get the thread state through this function
to ensure that write checkpoint file is completed.
When performing asynchronous save checkpoint, you can determine whether the asynchronous thread is completed.
Returns:
True, Asynchronous save checkpoint thread is running.
False, Asynchronous save checkpoint thread is not executing.
bool, True, Asynchronous save checkpoint thread is running.
False, Asynchronous save checkpoint thread is not executing.
"""
thr_list = threading.enumerate()
return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list]