!19088 modify serialization comment

Merge pull request !19088 from changzherui/mod_serializer_comment
This commit is contained in:
i-robot 2021-06-30 06:16:03 +00:00 committed by Gitee
commit b4343edb5c
2 changed files with 35 additions and 14 deletions

View File

@ -80,14 +80,21 @@ class Callback:
Callback function will execute some operations in the current step or epoch.
Examples:
>>> from mindspore.train._callback import Callback
>>> from mindspore import Model, nn
>>> from mindspore.train.callback._callback import Callback
>>> class Print_info(Callback):
>>> def step_end(self, run_context):
>>> cb_params = run_context.original_args()
>>> print(cb_params.cur_epoch_num)
>>> print(cb_params.cur_step_num)
>>> print("step_num: ", cb_params.cur_step_num)
>>>
>>> print_cb = Print_info()
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> model.train(1, dataset, callbacks=print_cb)
step_num: 1
"""
def __enter__(self):

View File

@ -199,7 +199,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
save_obj (Union[Cell, list]): The cell object or data list(each element is a dictionary, like
[{"name": param_name, "data": param_data},...], the type of
param_name would be string, and the type of param_data would
be parameter or `Tensor`).
be parameter or Tensor).
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
@ -337,6 +337,10 @@ def load(file_name, **kwargs):
>>> graph = load("net.mindir")
>>> net = nn.GraphCell(graph)
>>> output = net(input)
>>> print(output)
[[[[0.03204346 0.04455566 0.03509521]
[0.02406311 0.04125977 0.02404785]
[0.02018738 0.0292511 0.00889587]]]]
"""
if not isinstance(file_name, str):
raise ValueError("The file name must be string.")
@ -391,6 +395,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
>>>
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
>>> print(param_dict["conv2.weight]")
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True
"""
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
@ -496,8 +502,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
Args:
net (Cell): Cell network.
parameter_dict (dict): Parameter dictionary.
strict_load (bool): Whether to strict load the parameter into net. False: if some parameters in the net
not loaded, it will remove some parameter's prefix name continue to load. Default: False
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
in the param_dict into net with the same suffix and load
parameter with different accuracy. Default: False.
Returns:
List, parameters not loaded in the network.
Raises:
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
@ -695,15 +705,15 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
kwargs (dict): Configuration options dictionary.
- quant_mode: If the network is quantization aware training network, the quant_mode should
- quant_mode (str): If the network is quantization aware training network, the quant_mode should
be set to "QUANT", else the quant_mode should be set to "NONQUANT".
- mean: The mean of input data after preprocessing, used for quantizing the first layer of network.
- mean (float): The mean of input data after preprocessing, used for quantizing the first layer of network.
Default: 127.5.
- std_dev: The variance of input data after preprocessing, used for quantizing the first layer of network.
Default: 127.5.
- enc_key: Byte type key used for encryption. Tha valid length is 16, 24, or 32.
- enc_mode: Specifies the encryption mode, take effect when enc_key is set. Option: 'AES-GCM' | 'AES-CBC'.
Default: 'AES-GCM'.
- std_dev (float): The variance of input data after preprocessing,
used for quantizing the first layer of network. Default: 127.5.
- enc_key (str): Byte type key used for encryption. Tha valid length is 16, 24, or 32.
- enc_mode (str): Specifies the encryption mode, take effect when enc_key is set.
Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
Examples:
>>> import numpy as np
@ -1130,7 +1140,7 @@ def build_searched_strategy(strategy_filename):
Raises:
ValueError: Strategy file is incorrect.
TypeError: Strategy_filename is not str.
TypeError: strategy_filename is not str.
"""
if not isinstance(strategy_filename, str):
raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.")
@ -1179,6 +1189,8 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
KeyError: The parameter name is not in keys of strategy.
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.common.parameter import Parameter
>>> from mindspore.train import merge_sliced_parameter
>>>
@ -1192,6 +1204,8 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
... "network.embedding_table")]
>>> merged_parameter = merge_sliced_parameter(sliced_parameters)
>>> print(merged_parameter)
Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
"""
if not isinstance(sliced_parameters, list):
raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")