[模型训练] 支持r1.3分支resnet50修改部分参数

[修改人] haoripei
[审核人] chenshushu
This commit is contained in:
Atlas_hrp 2021-08-27 15:12:29 +08:00
parent 4df45e7e96
commit 1765b6e78f
1 changed files with 4 additions and 4 deletions

View File

@ -182,7 +182,7 @@ def run_train(args_opt):
model = Model(pangu_alpha_with_grads)
if args_opt.pre_trained:
load_checkpoint(args_opt, callback_size, ds, model)
load_checkpoint(args_opt, callback_size, ds, model, device_num)
add_checkpoint_callback_policy(args_opt, callback, rank)
@ -221,7 +221,7 @@ def add_checkpoint_callback_policy(args_param, callback, rank_id):
callback.append(ckpoint_cb)
def load_checkpoint(args_param, sink_size, dataset, model):
def load_checkpoint(args_param, sink_size, dataset, model, device_num):
r"""
Load checkpoint process.
"""
@ -250,7 +250,7 @@ def load_checkpoint(args_param, sink_size, dataset, model):
sink_size = ckpt_file.split("-")[-1].split("_")[-1].split(".")[0]
ckpt_file_list = [os.path.join(args_param.save_checkpoint_path,
f"{ckpt_name}{ckpt_rank}_{depulicate_num}-{step_size}_{sink_size}.ckpt")
for ckpt_rank in range(D.get_group_size())]
for ckpt_rank in range(device_num)]
# Load checkpoint files
load_distributed_checkpoint(model.train_network, ckpt_file_list, strategy)
elif len(ckpt_file_length) == 2:
@ -258,7 +258,7 @@ def load_checkpoint(args_param, sink_size, dataset, model):
sink_size = ckpt_file.split("-")[-1].split("_")[-1].split(".")[0]
ckpt_file_list = [os.path.join(args_param.save_checkpoint_path,
f"{ckpt_name}{ckpt_rank}-{step_size}_{sink_size}.ckpt")
for ckpt_rank in range(D.get_group_size())]
for ckpt_rank in range(device_num)]
# Load checkpoint files
load_distributed_checkpoint(model.train_network, ckpt_file_list, strategy)
else: