forked from mindspore-Ecosystem/mindspore
!9457 modify some examples
From: @caozhou_huawei Reviewed-by: Signed-off-by:
This commit is contained in:
commit
b9081806b8
|
@ -240,7 +240,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
||||||
ValueError: Checkpoint file is incorrect.
|
ValueError: Checkpoint file is incorrect.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> ckpt_file_name = "./checkpoint/LeNet5-2_1875.ckpt"
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
||||||
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
||||||
"""
|
"""
|
||||||
if not isinstance(ckpt_file_name, str):
|
if not isinstance(ckpt_file_name, str):
|
||||||
|
@ -341,8 +341,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
||||||
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
|
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> net = LeNet5()
|
>>> net = Net()
|
||||||
>>> param_dict = load_checkpoint("LeNet5-2_1875.ckpt")
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
||||||
|
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
||||||
>>> load_param_into_net(net, param_dict)
|
>>> load_param_into_net(net, param_dict)
|
||||||
"""
|
"""
|
||||||
if not isinstance(net, nn.Cell):
|
if not isinstance(net, nn.Cell):
|
||||||
|
@ -783,9 +784,6 @@ def build_searched_strategy(strategy_filename):
|
||||||
ValueError: Strategy file is incorrect.
|
ValueError: Strategy file is incorrect.
|
||||||
TypeError: Strategy_filename is not str.
|
TypeError: Strategy_filename is not str.
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> strategy_filename = "./strategy_train.ckpt"
|
|
||||||
>>> strategy = build_searched_strategy(strategy_filename)
|
|
||||||
"""
|
"""
|
||||||
if not isinstance(strategy_filename, str):
|
if not isinstance(strategy_filename, str):
|
||||||
raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.")
|
raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.")
|
||||||
|
@ -836,17 +834,16 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
||||||
KeyError: The parameter name is not in keys of strategy.
|
KeyError: The parameter name is not in keys of strategy.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> strategy = build_searched_strategy("./strategy_train.ckpt")
|
|
||||||
>>> sliced_parameters = [
|
>>> sliced_parameters = [
|
||||||
>>> Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
|
... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
|
||||||
>>> "network.embedding_table"),
|
... "network.embedding_table"),
|
||||||
>>> Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
|
... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
|
||||||
>>> "network.embedding_table"),
|
... "network.embedding_table"),
|
||||||
>>> Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
|
... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
|
||||||
>>> "network.embedding_table"),
|
... "network.embedding_table"),
|
||||||
>>> Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
|
... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
|
||||||
>>> "network.embedding_table")]
|
... "network.embedding_table")]
|
||||||
>>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
|
>>> merged_parameter = merge_sliced_parameter(sliced_parameters)
|
||||||
"""
|
"""
|
||||||
if not isinstance(sliced_parameters, list):
|
if not isinstance(sliced_parameters, list):
|
||||||
raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")
|
raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")
|
||||||
|
|
Loading…
Reference in New Issue