forked from mindspore-Ecosystem/mindspore
checpoint support customize network
This commit is contained in:
parent
f5aa3a2eab
commit
93e909594d
|
@ -224,42 +224,60 @@ def load_param_into_net(net, parameter_dict):
|
|||
msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict)))
|
||||
raise TypeError(msg)
|
||||
|
||||
logger.info("Execute parameter into net process.")
|
||||
param_name_net_not_have = []
|
||||
logger.info("Execute load parameter into net process.")
|
||||
for name in parameter_dict:
|
||||
b_par_dict_have_par_of_net = False
|
||||
for _, param in net.parameters_and_names():
|
||||
if name == param.name:
|
||||
b_par_dict_have_par_of_net = True
|
||||
if name == param.name and param.layerwise_parallel:
|
||||
# layerwise parallel parameter data loaded from checkpoint file,
|
||||
# was a complete(merged) data, need to be splited
|
||||
if param.layerwise_parallel:
|
||||
new_param = parameter_dict[param.name]
|
||||
_load_tensor_for_layerwise(new_param, param)
|
||||
break
|
||||
if not b_par_dict_have_par_of_net:
|
||||
param_name_net_not_have.append(name)
|
||||
|
||||
param_name_param_dict_not_have = []
|
||||
param_not_load = []
|
||||
for _, param in net.parameters_and_names():
|
||||
if param.name in parameter_dict:
|
||||
new_param = parameter_dict[param.name]
|
||||
|
||||
if not isinstance(new_param, Parameter):
|
||||
logger.error("Failed to combine the net and the parameters.")
|
||||
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
|
||||
raise TypeError(msg)
|
||||
_update_param(param, new_param)
|
||||
else:
|
||||
param_name_param_dict_not_have.append(param.name)
|
||||
param_not_load.append(param.name)
|
||||
|
||||
if param_not_load:
|
||||
_load_dismatch_prefix_params(net, parameter_dict, param_not_load)
|
||||
|
||||
logger.debug("Params not matched(in net but not in parameter_dict):")
|
||||
for paramname in param_name_param_dict_not_have:
|
||||
logger.debug("%s", paramname)
|
||||
logger.debug("Params not matched(in parameter_dict but not in net):")
|
||||
for paramname in param_name_net_not_have:
|
||||
logger.debug("%s", paramname)
|
||||
logger.info("Load parameter into net process finish.")
|
||||
for param_name in param_not_load:
|
||||
logger.debug("%s", param_name)
|
||||
|
||||
logger.info("Load parameter into net finish, {} parameters has not been loaded.".format(len(param_not_load)))
|
||||
|
||||
|
||||
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
|
||||
"""When some net parameter did not load, try to continue load."""
|
||||
prefix_name = ""
|
||||
longest_name = param_not_load[0]
|
||||
while prefix_name != longest_name and param_not_load:
|
||||
logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load)))
|
||||
longest_name = sorted(param_not_load, key=len, reverse=True)[0]
|
||||
prefix_name = longest_name
|
||||
for net_param_name in param_not_load:
|
||||
for dict_name in parameter_dict:
|
||||
if dict_name.endswith(net_param_name):
|
||||
tmp_name = dict_name[:-len(net_param_name)]
|
||||
prefix_name = prefix_name if len(prefix_name) < len(tmp_name) else tmp_name
|
||||
|
||||
if prefix_name != longest_name:
|
||||
logger.info("Remove parameter prefix name: {}, continue to load.".format(prefix_name))
|
||||
for _, param in net.parameters_and_names():
|
||||
new_param_name = prefix_name + param.name
|
||||
if param.name in param_not_load and new_param_name in parameter_dict:
|
||||
new_param = parameter_dict[new_param_name]
|
||||
_update_param(param, new_param)
|
||||
param_not_load.remove(param.name)
|
||||
|
||||
|
||||
def _save_graph(network, file_name):
|
||||
|
|
Loading…
Reference in New Issue