checkpoint transform api

yao_yf 2022-09-20 16:15:35 +08:00
parent 75c526019f
commit 534f6ec36a
6 changed files with 111 additions and 27 deletions


@ -88,8 +88,11 @@ mindspore
mindspore.load_param_into_net
mindspore.merge_sliced_parameter
mindspore.parse_print
mindspore.rank_list_for_transform
mindspore.restore_group_info_list
mindspore.save_checkpoint
mindspore.transform_checkpoint_by_rank
mindspore.transform_checkpoints
Debugging and Tuning
--------------------


@ -0,0 +1,19 @@
mindspore.rank_list_for_transform
======================================

.. py:function:: mindspore.rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None)

    During distributed checkpoint transformation, obtain the list of source checkpoint ranks needed to produce the checkpoint file of the target rank. For more details on distributed checkpoint transformation, please refer to [Distributed Resilience Training and Inference](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html).

    .. note::
        Transforming the pipeline parallel dimension is not currently supported.

    Parameters:
        - **rank_id** (int) - The rank whose checkpoint is to be obtained after the transformation.
        - **src_strategy_file** (str) - Name of the source sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the source strategy is treated as unsharded for every parameter. Default: None.
        - **dst_strategy_file** (str) - Name of the destination sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the destination strategy is treated as unsharded for every parameter. Default: None.

    Raises:
        - **ValueError** - src_strategy_file or dst_strategy_file is not a valid sharding strategy proto file.
        - **TypeError** - src_strategy_file or dst_strategy_file is not a string.
        - **TypeError** - rank_id is not an integer.
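
    A minimal usage sketch, assuming strategy files were saved during training via context.set_auto_parallel_context(strategy_ckpt_save_file); all paths and checkpoint names below are placeholders:

    >>> rank_id = 0
    >>> rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
    >>> checkpoint_files_map = {}
    >>> for rank in rank_list:
    ...     # map each required source rank to its checkpoint file
    ...     checkpoint_files_map[rank] = "./checkpoint_rank{}/net-100_2.ckpt".format(rank)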


@ -0,0 +1,25 @@
mindspore.transform_checkpoint_by_rank
======================================

.. py:function:: mindspore.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, src_strategy_file=None, dst_strategy_file=None)

    Transform the distributed checkpoint of a network from a source sharding strategy to a destination sharding strategy for a specific rank. For more details on distributed checkpoint transformation, please refer to [Distributed Resilience Training and Inference](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html).

    .. note::
        Transforming the pipeline parallel dimension is not currently supported.

    Parameters:
        - **rank_id** (int) - The rank whose checkpoint is to be obtained after the transformation.
        - **checkpoint_files_map** (dict) - The source checkpoint dictionary whose keys are rank numbers and whose values are the checkpoint file paths for those ranks.
        - **save_checkpoint_file_name** (str) - The path and name of the destination checkpoint file.
        - **src_strategy_file** (str) - Name of the source sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the source strategy is treated as unsharded for every parameter. Default: None.
        - **dst_strategy_file** (str) - Name of the destination sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the destination strategy is treated as unsharded for every parameter. Default: None.

    Raises:
        - **ValueError** - src_strategy_file or dst_strategy_file is not a valid sharding strategy proto file.
        - **ValueError** - An item in checkpoint_files_map is not a valid checkpoint file.
        - **ValueError** - save_checkpoint_file_name does not end with ".ckpt".
        - **TypeError** - checkpoint_files_map is not a dictionary.
        - **TypeError** - src_strategy_file or dst_strategy_file is not a string.
        - **TypeError** - rank_id is not an integer.
        - **TypeError** - save_checkpoint_file_name is not a string.
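
    A minimal sketch for a single rank, assuming checkpoint_files_map was built with rank_list_for_transform as shown above; all paths are placeholders:

    >>> save_checkpoint_file_name = "./new_checkpoint_rank0/net0-100_2.ckpt"
    >>> transform_checkpoint_by_rank(0, checkpoint_files_map, save_checkpoint_file_name,
    ...                              "./src_strategy.ckpt", "./dst_strategy.ckpt")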


@ -0,0 +1,22 @@
mindspore.transform_checkpoints
======================================

.. py:function:: mindspore.transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None, dst_strategy_file=None)

    Transform the distributed checkpoints of a network from a source sharding strategy to a destination sharding strategy. For more details on distributed checkpoint transformation, please refer to [Distributed Resilience Training and Inference](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html).

    .. note::
        The src_checkpoints_dir directory must be organized as "src_checkpoints_dir/rank_0/a.ckpt": each rank number must appear as a subdirectory, and that rank's checkpoint files must be placed inside it. If multiple files exist in one rank directory, the file whose name is highest in lexicographical order is selected. Transforming the pipeline parallel dimension is not currently supported.

    Parameters:
        - **src_checkpoints_dir** (str) - The directory containing the source checkpoint files.
        - **dst_checkpoints_dir** (str) - The directory in which the destination checkpoint files are stored.
        - **ckpt_prefix** (str) - The prefix of the destination checkpoint file names.
        - **src_strategy_file** (str) - Name of the source sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the source strategy is treated as unsharded for every parameter. Default: None.
        - **dst_strategy_file** (str) - Name of the destination sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the destination strategy is treated as unsharded for every parameter. Default: None.

    Raises:
        - **ValueError** - src_strategy_file or dst_strategy_file is not a valid sharding strategy proto file.
        - **NotADirectoryError** - src_checkpoints_dir or dst_checkpoints_dir is not a directory.
        - **ValueError** - Checkpoint files are missing from src_checkpoints_dir.
        - **TypeError** - src_strategy_file or dst_strategy_file is not a string.
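
    A minimal sketch; the directory names are placeholders, and src_checkpoints_dir must follow the "rank_x" subdirectory layout described in the note above:

    >>> transform_checkpoints("./src_checkpoints", "./dst_checkpoints", "dst_checkpoint",
    ...                       "./src_strategy.ckpt", "./dst_strategy.ckpt")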


@ -203,8 +203,11 @@ Serialization
mindspore.load_param_into_net
mindspore.merge_sliced_parameter
mindspore.parse_print
mindspore.rank_list_for_transform
mindspore.restore_group_info_list
mindspore.save_checkpoint
mindspore.transform_checkpoint_by_rank
mindspore.transform_checkpoints
JIT
---


@ -37,12 +37,14 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=N
Args:
rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
src_strategy_file (str): Name of the source sharding strategy file, saved by
'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
When 'src_strategy_file' is None, the source sharding strategy is treated as
unsharded for every parameter. Default: None.
dst_strategy_file (str): Name of the destination sharding strategy file, saved by
'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
When 'dst_strategy_file' is None, the destination sharding strategy is treated as
unsharded for every parameter. Default: None.
Returns:
List, the rank list required for converting the distributed checkpoint of rank_id.
@ -54,8 +56,7 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=N
Examples:
>>> rank_id = 0
>>> rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
>>> checkpoint_files_map = {}
>>> for rank in rank_list:
...     checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)
@ -69,7 +70,8 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=N
def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
src_strategy_file=None, dst_strategy_file=None):
"""
Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank
for a network.
Note:
Cannot transform pipeline parallel dimensions currently.
@ -79,37 +81,44 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
checkpoint_files_map (dict): The checkpoint files map whose key is the rank id and the value is
the checkpoint file name.
save_checkpoint_file_name (str): The file name to save the converted checkpoint.
src_strategy_file (str): Name of the source sharding strategy file, saved by
'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
When 'src_strategy_file' is None, the source sharding strategy is treated as
unsharded for every parameter. Default: None.
dst_strategy_file (str): Name of the destination sharding strategy file, saved by
'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
When 'dst_strategy_file' is None, the destination sharding strategy is treated as
unsharded for every parameter. Default: None.
Raises:
ValueError: src_strategy_file or dst_strategy_file is incorrect.
ValueError: item in checkpoint_files_map is incorrect.
ValueError: save_checkpoint_file_name does not end with ".ckpt".
TypeError: checkpoint_files_map is not a dict.
TypeError: src_strategy_file or dst_strategy_file is not a string.
TypeError: rank_id is not an int.
TypeError: save_checkpoint_file_name is not a string.
Examples:
>>> dst_device_num = 8
>>> for rank_id in range(dst_device_num):
...     rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
...     checkpoint_files_map = {}
...     for rank in rank_list:
...         checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank, rank)
...     save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id, rank_id)
...     transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
...                                  "./src_strategy.ckpt", "./dst_strategy.ckpt")
"""
if not isinstance(checkpoint_files_map, dict):
raise TypeError("The checkpoint_files_map should be a dict.")
if not isinstance(rank_id, int):
raise TypeError("The rank_id should be a int.")
if not isinstance(save_checkpoint_file_name, str):
raise TypeError("The save_checkpoint_file_name should be a str.")
if save_checkpoint_file_name[-5:] != ".ckpt":
raise ValueError("The save_checkpoint_file_name {} should end with .ckpt".format(save_checkpoint_file_name))
for rank, local_file in checkpoint_files_map.items():
if not os.path.exists(local_file):
raise ValueError("Checkpoint file {} in rank {} not exits: ".format(local_file, rank))
@ -128,7 +137,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None,
dst_strategy_file=None):
"""
Transform distributed checkpoints from source sharding strategy to destination sharding strategy for a network.
Note:
The src_checkpoints_dir directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt", the
@ -139,12 +148,15 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
Args:
src_checkpoints_dir (str): The source checkpoints directory.
dst_checkpoints_dir (str): The destination checkpoints directory to save the converted checkpoints.
ckpt_prefix (str): The destination checkpoint name prefix.
src_strategy_file (str): Name of the source sharding strategy file, saved by
'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
When 'src_strategy_file' is None, the source sharding strategy is treated as
unsharded for every parameter. Default: None.
dst_strategy_file (str): Name of the destination sharding strategy file, saved by
'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
When 'dst_strategy_file' is None, the destination sharding strategy is treated as
unsharded for every parameter. Default: None.
Raises:
ValueError: src_strategy_file or dst_strategy_file is incorrect.
@ -153,8 +165,8 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
TypeError: src_strategy_file or dst_strategy_file is not a string.
Examples:
>>> transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, "dst_checkpoint",
...                       "./src_strategy.ckpt", "./dst_strategy.ckpt")
"""
if not os.path.isdir(src_checkpoints_dir):