diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 24c8ad38184..df5f1c11820 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -240,7 +240,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N ValueError: Checkpoint file is incorrect. Examples: - >>> ckpt_file_name = "./checkpoint/LeNet5-2_1875.ckpt" + >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") """ if not isinstance(ckpt_file_name, str): @@ -341,8 +341,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False): TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. Examples: - >>> net = LeNet5() - >>> param_dict = load_checkpoint("LeNet5-2_1875.ckpt") + >>> net = Net() + >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" + >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") >>> load_param_into_net(net, param_dict) """ if not isinstance(net, nn.Cell): @@ -783,9 +784,6 @@ def build_searched_strategy(strategy_filename): ValueError: Strategy file is incorrect. TypeError: Strategy_filename is not str. - Examples: - >>> strategy_filename = "./strategy_train.ckpt" - >>> strategy = build_searched_strategy(strategy_filename) """ if not isinstance(strategy_filename, str): raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.") @@ -836,17 +834,16 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): KeyError: The parameter name is not in keys of strategy. Examples: - >>> strategy = build_searched_strategy("./strategy_train.ckpt") >>> sliced_parameters = [ - >>> Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), - >>> "network.embedding_table"), - >>> Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), - >>> "network.embedding_table"), - >>> Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), - >>> "network.embedding_table"), - >>> Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), - >>> "network.embedding_table")] - >>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy) + ... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), + ... "network.embedding_table"), + ... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), + ... "network.embedding_table"), + ... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), + ... "network.embedding_table"), + ... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), + ... "network.embedding_table")] + >>> merged_parameter = merge_sliced_parameter(sliced_parameters) """ if not isinstance(sliced_parameters, list): raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")