!19088 modify serialization comment
Merge pull request !19088 from changzherui/mod_serializer_comment
commit b4343edb5c
@@ -80,14 +80,21 @@ class Callback:
    Callback function will execute some operations in the current step or epoch.

    Examples:
        >>> from mindspore.train._callback import Callback
        >>> from mindspore import Model, nn
        >>> from mindspore.train.callback._callback import Callback
        >>> class Print_info(Callback):
        >>>     def step_end(self, run_context):
        >>>         cb_params = run_context.original_args()
        >>>         print(cb_params.cur_epoch_num)
        >>>         print(cb_params.cur_step_num)
        >>>         print("step_num: ", cb_params.cur_step_num)
        >>>
        >>> print_cb = Print_info()
        >>> dataset = create_custom_dataset()
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        >>> model.train(1, dataset, callbacks=print_cb)
        step_num: 1
    """

    def __enter__(self):
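Aside (not part of the diff): a minimal sketch of the callback pattern this docstring describes, written as plain Python instead of doctest. `Net` and `create_custom_dataset` are placeholders carried over from the example above; the extra `epoch_end` hook and the callback list are only meant to show that `Model.train` accepts several of the documented hooks at once.

    import mindspore.nn as nn
    from mindspore import Model
    from mindspore.train.callback import Callback

    class StepPrinter(Callback):
        """Print progress counters at the end of each step and epoch."""
        def step_end(self, run_context):
            cb_params = run_context.original_args()
            print("step_num: ", cb_params.cur_step_num)

        def epoch_end(self, run_context):
            cb_params = run_context.original_args()
            print("epoch_num: ", cb_params.cur_epoch_num)

    net = Net()                        # placeholder network from the docstring example
    dataset = create_custom_dataset()  # placeholder dataset helper from the docstring example
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=optim)
    model.train(1, dataset, callbacks=[StepPrinter()])  # a list of callbacks is also accepted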
@@ -199,7 +199,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
        save_obj (Union[Cell, list]): The cell object or data list (each element is a dictionary, like
                                      [{"name": param_name, "data": param_data},...], the type of
                                      param_name would be string, and the type of param_data would
                                      be parameter or `Tensor`).
                                      be parameter or Tensor).
        ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
        integrated_save (bool): Whether to perform integrated save in the automatic model parallel scene. Default: True.
        async_save (bool): Whether to save the checkpoint file asynchronously. Default: False.
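Aside (not part of the diff): a short sketch of the two `save_obj` forms described above. `LeNet5` and the parameter shapes are assumed for illustration only.

    import numpy as np
    from mindspore import Tensor, save_checkpoint

    # Form 1: pass a Cell; every parameter of the network is saved.
    net = LeNet5()  # assumed example network
    save_checkpoint(net, "./lenet.ckpt")

    # Form 2: pass a data list of {"name": ..., "data": ...} dictionaries.
    params = [{"name": "conv1.weight", "data": Tensor(np.ones((6, 1, 5, 5), np.float32))},
              {"name": "fc1.bias", "data": Tensor(np.zeros((84,), np.float32))}]
    save_checkpoint(params, "./params.ckpt", integrated_save=True, async_save=False)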
@@ -337,6 +337,10 @@ def load(file_name, **kwargs):
        >>> graph = load("net.mindir")
        >>> net = nn.GraphCell(graph)
        >>> output = net(input)
        >>> print(output)
        [[[[0.03204346 0.04455566 0.03509521]
           [0.02406311 0.04125977 0.02404785]
           [0.02018738 0.0292511  0.00889587]]]]
    """
    if not isinstance(file_name, str):
        raise ValueError("The file name must be string.")
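Aside (not part of the diff): a sketch of the export/load round trip that the output above comes from; the one-layer Conv2d network and its input are assumptions for illustration.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, export, load

    net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
    inputs = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))

    # Export to MindIR, then load the graph back and run it through GraphCell.
    export(net, inputs, file_name="net", file_format="MINDIR")
    graph = load("net.mindir")
    graph_net = nn.GraphCell(graph)
    print(graph_net(inputs))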
@@ -391,6 +395,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
        >>>
        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
        >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
        >>> print(param_dict["conv2.weight"])
        Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
    """
    ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
    dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
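Aside (not part of the diff): a sketch of the `filter_prefix` behavior shown above; the checkpoint path is the one from the docstring example and is assumed to exist.

    from mindspore import load_checkpoint

    ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
    # Parameters whose names start with "conv1" are dropped from the returned dict.
    param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
    print("conv1.weight" in param_dict)   # False, filtered out
    print(param_dict["conv2.weight"])     # other parameters remain available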
@@ -496,8 +502,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
    Args:
        net (Cell): Cell network.
        parameter_dict (dict): Parameter dictionary.
        strict_load (bool): Whether to strict load the parameter into net. False: if some parameters in the net
            not loaded, it will remove some parameter's prefix name continue to load. Default: False
        strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
            in the param_dict into net with the same suffix and load
            parameter with different accuracy. Default: False.

    Returns:
        List, parameters not loaded in the network.

    Raises:
        TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
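Aside (not part of the diff): a sketch of the return value described above; `LeNet5` and the checkpoint path are assumed placeholders.

    from mindspore import load_checkpoint, load_param_into_net

    net = LeNet5()  # assumed example network
    param_dict = load_checkpoint("./checkpoint/LeNet5-1_32.ckpt")

    # With strict_load=False (the default), near-matching parameter names are still
    # loaded; the call returns the parameters that could not be loaded into net.
    param_not_load = load_param_into_net(net, param_dict, strict_load=False)
    print(param_not_load)  # [] when every network parameter was matched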
@@ -695,15 +705,15 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):

        kwargs (dict): Configuration options dictionary.

            - quant_mode: If the network is quantization aware training network, the quant_mode should
            - quant_mode (str): If the network is quantization aware training network, the quant_mode should
              be set to "QUANT", else the quant_mode should be set to "NONQUANT".
            - mean: The mean of input data after preprocessing, used for quantizing the first layer of network.
            - mean (float): The mean of input data after preprocessing, used for quantizing the first layer of network.
              Default: 127.5.
            - std_dev: The variance of input data after preprocessing, used for quantizing the first layer of network.
              Default: 127.5.
            - enc_key: Byte type key used for encryption. The valid length is 16, 24, or 32.
            - enc_mode: Specifies the encryption mode, takes effect when enc_key is set. Option: 'AES-GCM' | 'AES-CBC'.
              Default: 'AES-GCM'.
            - std_dev (float): The variance of input data after preprocessing,
              used for quantizing the first layer of network. Default: 127.5.
            - enc_key (str): Byte type key used for encryption. The valid length is 16, 24, or 32.
            - enc_mode (str): Specifies the encryption mode, takes effect when enc_key is set.
              Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.

    Examples:
        >>> import numpy as np
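Aside (not part of the diff): a sketch of the encryption-related kwargs documented above; `net` is an assumed trained network and the 16-byte key is a placeholder, not a real secret.

    import numpy as np
    from mindspore import Tensor, export

    inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
    # Export an encrypted MindIR file; enc_mode takes effect because enc_key is set.
    export(net, inputs, file_name="lenet_enc", file_format="MINDIR",
           enc_key=b"0123456789ABCDEF", enc_mode="AES-GCM")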
@@ -1130,7 +1140,7 @@ def build_searched_strategy(strategy_filename):

    Raises:
        ValueError: Strategy file is incorrect.
        TypeError: Strategy_filename is not str.
        TypeError: strategy_filename is not str.
    """
    if not isinstance(strategy_filename, str):
        raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.")
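Aside (not part of the diff): a sketch of the expected usage; the strategy file path is an assumed placeholder produced by an auto-parallel training run.

    from mindspore.train.serialization import build_searched_strategy

    strategy = build_searched_strategy("./strategy_train.ckpt")
    # The result maps parameter names to their searched parallel layouts.
    for param_name, layout in strategy.items():
        print(param_name, layout)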
@@ -1179,6 +1189,8 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
        KeyError: The parameter name is not in keys of strategy.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.common.parameter import Parameter
        >>> from mindspore.train import merge_sliced_parameter
        >>>
@@ -1192,6 +1204,8 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
        ...                      Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
        ...                                "network.embedding_table")]
        >>> merged_parameter = merge_sliced_parameter(sliced_parameters)
        >>> print(merged_parameter)
        Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
    """
    if not isinstance(sliced_parameters, list):
        raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")
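Aside (not part of the diff): the hunk above truncates the `sliced_parameters` list, so this sketch rebuilds a complete call with placeholder values for the slices that are not shown; only the last slice's values come from the diff.

    import numpy as np
    from mindspore import Tensor
    from mindspore.common.parameter import Parameter
    from mindspore.train import merge_sliced_parameter

    sliced_parameters = [
        Parameter(Tensor(np.array([0.1, 0.2, 0.3])), "network.embedding_table"),  # placeholder slice
        Parameter(Tensor(np.array([0.4, 0.5, 0.6])), "network.embedding_table"),  # placeholder slice
        Parameter(Tensor(np.array([0.7, 0.8, 0.9])), "network.embedding_table"),  # placeholder slice
        Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table"),
    ]
    # Without a strategy, the slices are merged in order into one (12,) parameter.
    merged_parameter = merge_sliced_parameter(sliced_parameters)
    print(merged_parameter)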