fix load checkpoint bug

This commit is contained in:
chang zherui 2020-04-27 15:10:03 +08:00
parent 19bc9e785b
commit 8cef3aff7c
1 changed files with 5 additions and 4 deletions

View File

@ -258,16 +258,17 @@ def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
longest_name = param_not_load[0] longest_name = param_not_load[0]
while prefix_name != longest_name and param_not_load: while prefix_name != longest_name and param_not_load:
logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load))) logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load)))
longest_name = sorted(param_not_load, key=len, reverse=True)[0]
prefix_name = longest_name prefix_name = longest_name
for net_param_name in param_not_load: for net_param_name in param_not_load:
for dict_name in parameter_dict: for dict_name in parameter_dict:
if dict_name.endswith(net_param_name): if dict_name.endswith(net_param_name):
tmp_name = dict_name[:-len(net_param_name)] prefix_name = dict_name[:-len(net_param_name)]
prefix_name = prefix_name if len(prefix_name) < len(tmp_name) else tmp_name break
if prefix_name != longest_name:
break
if prefix_name != longest_name: if prefix_name != longest_name:
logger.info("Remove parameter prefix name: {}, continue to load.".format(prefix_name)) logger.warning("Remove parameter prefix name: {}, continue to load.".format(prefix_name))
for _, param in net.parameters_and_names(): for _, param in net.parameters_and_names():
new_param_name = prefix_name + param.name new_param_name = prefix_name + param.name
if param.name in param_not_load and new_param_name in parameter_dict: if param.name in param_not_load and new_param_name in parameter_dict: