diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 33b0c6cbd37..309da04ba68 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -345,7 +345,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False): >>> net = Net() >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") - >>> load_param_into_net(net, param_dict) + >>> param_not_load = load_param_into_net(net, param_dict) + >>> print(param_not_load) + ['conv1.weight'] """ if not isinstance(net, nn.Cell): logger.error("Failed to combine the net and the parameters.")