diff --git a/docs/api/api_python/mindspore/mindspore.load_param_into_net.rst b/docs/api/api_python/mindspore/mindspore.load_param_into_net.rst index 2e7f19bf816..077e47aec35 100644 --- a/docs/api/api_python/mindspore/mindspore.load_param_into_net.rst +++ b/docs/api/api_python/mindspore/mindspore.load_param_into_net.rst @@ -12,6 +12,7 @@ mindspore.load_param_into_net 返回: List,网络中没有被加载的参数。 + List,checkpoint文件中没有被加载的参数。 异常: - **TypeError** - 如果参数不是Cell,或者 `parameter_dict` 不是Parameter类型的字典。 diff --git a/mindspore/python/mindspore/train/serialization.py b/mindspore/python/mindspore/train/serialization.py index 3f693b8d149..0ead308e568 100644 --- a/mindspore/python/mindspore/train/serialization.py +++ b/mindspore/python/mindspore/train/serialization.py @@ -958,7 +958,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False): on the parameters of the same type, such as float32 to float16. Default: False. Returns: - List, the parameter name which are not loaded into the network. + List, the parameter name in model which are not loaded into the network. + List, the parameter name in checkpoint file which are not loaded into the network. Raises: TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. @@ -994,10 +995,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False): logger.info("Execute the process of loading parameters into net.") net.init_parameters_data() param_not_load = [] + ckpt_not_load = list(parameter_dict.keys()) for _, param in net.parameters_and_names(): if param.name in parameter_dict: new_param = copy.deepcopy(parameter_dict[param.name]) _update_param(param, new_param, strict_load) + ckpt_not_load.remove(param.name) else: param_not_load.append(param.name) @@ -1016,7 +1019,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False): "when training and loading checkpoint.".format(len(param_not_load))) for param_name in param_not_load: logger.warning("{} is not loaded.".format(param_name)) - return param_not_load + return param_not_load, ckpt_not_load def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):