forked from mindspore-Ecosystem/mindspore
[模型训练] 支持r1.3分支resnet50修改部分参数
[修改人] haoripei [审核人] chenshushu
This commit is contained in:
parent
4df45e7e96
commit
1765b6e78f
|
@ -182,7 +182,7 @@ def run_train(args_opt):
|
|||
model = Model(pangu_alpha_with_grads)
|
||||
|
||||
if args_opt.pre_trained:
|
||||
load_checkpoint(args_opt, callback_size, ds, model)
|
||||
load_checkpoint(args_opt, callback_size, ds, model, device_num)
|
||||
|
||||
add_checkpoint_callback_policy(args_opt, callback, rank)
|
||||
|
||||
|
@ -221,7 +221,7 @@ def add_checkpoint_callback_policy(args_param, callback, rank_id):
|
|||
callback.append(ckpoint_cb)
|
||||
|
||||
|
||||
def load_checkpoint(args_param, sink_size, dataset, model):
|
||||
def load_checkpoint(args_param, sink_size, dataset, model, device_num):
|
||||
r"""
|
||||
Load checkpoint process.
|
||||
"""
|
||||
|
@ -250,7 +250,7 @@ def load_checkpoint(args_param, sink_size, dataset, model):
|
|||
sink_size = ckpt_file.split("-")[-1].split("_")[-1].split(".")[0]
|
||||
ckpt_file_list = [os.path.join(args_param.save_checkpoint_path,
|
||||
f"{ckpt_name}{ckpt_rank}_{depulicate_num}-{step_size}_{sink_size}.ckpt")
|
||||
for ckpt_rank in range(D.get_group_size())]
|
||||
for ckpt_rank in range(device_num)]
|
||||
# Load checkpoint files
|
||||
load_distributed_checkpoint(model.train_network, ckpt_file_list, strategy)
|
||||
elif len(ckpt_file_length) == 2:
|
||||
|
@ -258,7 +258,7 @@ def load_checkpoint(args_param, sink_size, dataset, model):
|
|||
sink_size = ckpt_file.split("-")[-1].split("_")[-1].split(".")[0]
|
||||
ckpt_file_list = [os.path.join(args_param.save_checkpoint_path,
|
||||
f"{ckpt_name}{ckpt_rank}-{step_size}_{sink_size}.ckpt")
|
||||
for ckpt_rank in range(D.get_group_size())]
|
||||
for ckpt_rank in range(device_num)]
|
||||
# Load checkpoint files
|
||||
load_distributed_checkpoint(model.train_network, ckpt_file_list, strategy)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue