!49370 ckpt参数更新接口load_param_into_net新增返回值:ckpt_not_load

Merge pull request !49370 from GuoZhibin/add_return_value_ckpt_not_load
This commit is contained in:
i-robot 2023-03-01 06:22:08 +00:00 committed by Gitee
commit 53714e4a21
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 6 additions and 2 deletions

View File

@ -12,6 +12,7 @@ mindspore.load_param_into_net
返回:
List网络中没有被加载的参数。
Listcheckpoint文件中没有被加载的参数。
异常:
- **TypeError** - 如果参数不是Cell或者 `parameter_dict` 不是Parameter类型的字典。

View File

@ -963,7 +963,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
on the parameters of the same type, such as float32 to float16. Default: False.
Returns:
List, the parameter name which are not loaded into the network.
List, the parameter name in model which are not loaded into the network.
List, the parameter name in checkpoint file which are not loaded into the network.
Raises:
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
@ -999,10 +1000,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
logger.info("Execute the process of loading parameters into net.")
net.init_parameters_data()
param_not_load = []
ckpt_not_load = list(parameter_dict.keys())
for _, param in net.parameters_and_names():
if param.name in parameter_dict:
new_param = copy.deepcopy(parameter_dict[param.name])
_update_param(param, new_param, strict_load)
ckpt_not_load.remove(param.name)
else:
param_not_load.append(param.name)
@ -1021,7 +1024,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
"when training and loading checkpoint.".format(len(param_not_load)))
for param_name in param_not_load:
logger.warning("{} is not loaded.".format(param_name))
return param_not_load
return param_not_load, ckpt_not_load
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):