modify load_dic_ckpt

This commit is contained in:
changzherui 2021-05-10 23:14:12 +08:00
parent 428217a9bd
commit 753d2fdd9b
1 changed file with 13 additions and 1 deletion

View File

@@ -1185,15 +1185,22 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
param_total_dict = defaultdict(dict)
for file_index, file_name in enumerate(checkpoint_filenames):
ckpt_dict = load_checkpoint(file_name, dec_key, dec_mode)
ckpt_dict = load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
for param_name, param in ckpt_dict.items():
param_total_dict[param_name][file_index] = param
param_dict = {}
param_not_in_strategy = []
param_not_in_ckpt = []
for _, param in network.parameters_and_names():
sliced_params = []
if param.name not in rank_list.keys():
param_not_in_strategy.append(param.name)
continue
if param.name not in param_total_dict:
param_not_in_ckpt.append(param.name)
continue
param_rank = rank_list[param.name][0]
skip_merge_split = rank_list[param.name][1]
shard_stride = train_strategy[param.name][4]
@@ -1230,6 +1237,11 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
param_dict[param.name] = split_param
if param_not_in_strategy:
logger.warning("{} parameters in network are not in the sclice strategy.".format(param_not_in_strategy))
if param_not_in_ckpt:
logger.warning("{} parameters in sclice strategy but not in the checkpoint file.".format(param_not_in_ckpt))
load_param_into_net(network, param_dict)