!45054 auto_parallel_checkpoint_transform_fix_221103

Merge pull request !45054 from yao_yf/auto_parallel_checkpoint_transform_fix_221103
This commit is contained in:
i-robot 2022-11-04 01:19:59 +00:00 committed by Gitee
commit 5cb3116f16
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 4 additions and 2 deletions

View File

@@ -185,7 +185,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
raise TypeError("The save_checkpoint_file_name should be a str.")
if save_checkpoint_file_name[-5:] != ".ckpt":
raise ValueError("The save_checkpoint_file_name {} should end with .ckpt".format(save_checkpoint_file_name))
-if not os.path.exists(os.path.dirname(dst_strategy_file)):
+if os.path.dirname(dst_strategy_file) and not os.path.exists(os.path.dirname(dst_strategy_file)):
raise ValueError("The director of dst_strategy_file: {} is not exists.".
format(os.path.dirname(dst_strategy_file)))
for rank, local_file in checkpoint_files_map.items():
@@ -269,7 +269,9 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
continue
rank_id = int(rank_id_str)
checkpoint_file_name = os.path.join(checkpoint_dir, "*.ckpt")
-for checkpoint_file in glob.glob(checkpoint_file_name):
+rank_ckpts = glob.glob(checkpoint_file_name)
+rank_ckpts.sort()
+for checkpoint_file in rank_ckpts:
if not os.path.isfile(checkpoint_file):
ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file))
continue