diff --git a/mindspore/python/mindspore/parallel/checkpoint_transform.py b/mindspore/python/mindspore/parallel/checkpoint_transform.py index e003e7d7dd9..324595c4e20 100644 --- a/mindspore/python/mindspore/parallel/checkpoint_transform.py +++ b/mindspore/python/mindspore/parallel/checkpoint_transform.py @@ -185,7 +185,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_ raise TypeError("The save_checkpoint_file_name should be a str.") if save_checkpoint_file_name[-5:] != ".ckpt": raise ValueError("The save_checkpoint_file_name {} should end with .ckpt".format(save_checkpoint_file_name)) - if not os.path.exists(os.path.dirname(dst_strategy_file)): + if os.path.dirname(dst_strategy_file) and not os.path.exists(os.path.dirname(dst_strategy_file)): raise ValueError("The director of dst_strategy_file: {} is not exists.". format(os.path.dirname(dst_strategy_file))) for rank, local_file in checkpoint_files_map.items(): @@ -269,7 +269,9 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, continue rank_id = int(rank_id_str) checkpoint_file_name = os.path.join(checkpoint_dir, "*.ckpt") - for checkpoint_file in glob.glob(checkpoint_file_name): + rank_ckpts = glob.glob(checkpoint_file_name) + rank_ckpts.sort() + for checkpoint_file in rank_ckpts: if not os.path.isfile(checkpoint_file): ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file)) continue