checpoint support customize network

This commit is contained in:
chang zherui 2020-04-17 15:50:30 +08:00
parent f5aa3a2eab
commit 93e909594d
1 changed files with 37 additions and 19 deletions

View File

@ -224,42 +224,60 @@ def load_param_into_net(net, parameter_dict):
msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict)))
raise TypeError(msg)
logger.info("Execute parameter into net process.")
param_name_net_not_have = []
logger.info("Execute load parameter into net process.")
for name in parameter_dict:
b_par_dict_have_par_of_net = False
for _, param in net.parameters_and_names():
if name == param.name:
b_par_dict_have_par_of_net = True
if name == param.name and param.layerwise_parallel:
# layerwise parallel parameter data loaded from checkpoint file,
# was a complete(merged) data, need to be splited
if param.layerwise_parallel:
new_param = parameter_dict[param.name]
_load_tensor_for_layerwise(new_param, param)
break
if not b_par_dict_have_par_of_net:
param_name_net_not_have.append(name)
param_name_param_dict_not_have = []
param_not_load = []
for _, param in net.parameters_and_names():
if param.name in parameter_dict:
new_param = parameter_dict[param.name]
if not isinstance(new_param, Parameter):
logger.error("Failed to combine the net and the parameters.")
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
raise TypeError(msg)
_update_param(param, new_param)
else:
param_name_param_dict_not_have.append(param.name)
param_not_load.append(param.name)
if param_not_load:
_load_dismatch_prefix_params(net, parameter_dict, param_not_load)
logger.debug("Params not matched(in net but not in parameter_dict):")
for paramname in param_name_param_dict_not_have:
logger.debug("%s", paramname)
logger.debug("Params not matched(in parameter_dict but not in net):")
for paramname in param_name_net_not_have:
logger.debug("%s", paramname)
logger.info("Load parameter into net process finish.")
for param_name in param_not_load:
logger.debug("%s", param_name)
logger.info("Load parameter into net finish, {} parameters has not been loaded.".format(len(param_not_load)))
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
"""When some net parameter did not load, try to continue load."""
prefix_name = ""
longest_name = param_not_load[0]
while prefix_name != longest_name and param_not_load:
logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load)))
longest_name = sorted(param_not_load, key=len, reverse=True)[0]
prefix_name = longest_name
for net_param_name in param_not_load:
for dict_name in parameter_dict:
if dict_name.endswith(net_param_name):
tmp_name = dict_name[:-len(net_param_name)]
prefix_name = prefix_name if len(prefix_name) < len(tmp_name) else tmp_name
if prefix_name != longest_name:
logger.info("Remove parameter prefix name: {}, continue to load.".format(prefix_name))
for _, param in net.parameters_and_names():
new_param_name = prefix_name + param.name
if param.name in param_not_load and new_param_name in parameter_dict:
new_param = parameter_dict[new_param_name]
_update_param(param, new_param)
param_not_load.remove(param.name)
def _save_graph(network, file_name):