!45054 auto_parallel_checkpoint_transform_fix_221103
Merge pull request !45054 from yao_yf/auto_parallel_checkpoint_transform_fix_221103
This commit is contained in:
commit 5cb3116f16
|
@ -185,7 +185,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
|
|||
raise TypeError("The save_checkpoint_file_name should be a str.")
|
||||
if save_checkpoint_file_name[-5:] != ".ckpt":
|
||||
raise ValueError("The save_checkpoint_file_name {} should end with .ckpt".format(save_checkpoint_file_name))
|
||||
if not os.path.exists(os.path.dirname(dst_strategy_file)):
|
||||
if os.path.dirname(dst_strategy_file) and not os.path.exists(os.path.dirname(dst_strategy_file)):
|
||||
raise ValueError("The director of dst_strategy_file: {} is not exists.".
|
||||
format(os.path.dirname(dst_strategy_file)))
|
||||
for rank, local_file in checkpoint_files_map.items():
|
||||
|
@ -269,7 +269,9 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
|
|||
continue
|
||||
rank_id = int(rank_id_str)
|
||||
checkpoint_file_name = os.path.join(checkpoint_dir, "*.ckpt")
|
||||
for checkpoint_file in glob.glob(checkpoint_file_name):
|
||||
rank_ckpts = glob.glob(checkpoint_file_name)
|
||||
rank_ckpts.sort()
|
||||
for checkpoint_file in rank_ckpts:
|
||||
if not os.path.isfile(checkpoint_file):
|
||||
ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file))
|
||||
continue
|
||||
|
|
Loading…
Reference in New Issue