forked from mindspore-Ecosystem/mindspore

checkpoint transform api

This commit is contained in:
parent 75c526019f
commit 534f6ec36a
@@ -88,8 +88,11 @@ mindspore
     mindspore.load_param_into_net
     mindspore.merge_sliced_parameter
     mindspore.parse_print
+    mindspore.rank_list_for_transform
     mindspore.restore_group_info_list
     mindspore.save_checkpoint
+    mindspore.transform_checkpoint_by_rank
+    mindspore.transform_checkpoints
 
 Debugging and Tuning
 --------------------
@@ -0,0 +1,19 @@
+mindspore.rank_list_for_transform
+======================================
+
+.. py:function:: mindspore.rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None)
+
+    During distributed checkpoint conversion, get the list of source checkpoint ranks needed to obtain the checkpoint file of the target rank. For more details on distributed checkpoint conversion, see [Distributed Resilience Training and Inference](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html).
+
+    .. note::
+        Conversion across the pipeline parallel dimension is not supported yet.
+
+    Parameters:
+        - **rank_id** (int) - The rank id of the checkpoint to be obtained after conversion.
+        - **src_strategy_file** (str) - Name of the source sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the parameters are not sharded. Default: None.
+        - **dst_strategy_file** (str) - Name of the destination sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the parameters are not sharded. Default: None.
+
+    Raises:
+        - **ValueError** - src_strategy_file or dst_strategy_file is not a valid sharding strategy proto file.
+        - **TypeError** - src_strategy_file or dst_strategy_file is not a string.
+        - **TypeError** - rank_id is not an integer.
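The rank list returned by this API is typically used to build the checkpoint-file map that is later passed to mindspore.transform_checkpoint_by_rank. A minimal pure-Python sketch of that bookkeeping follows; `build_checkpoint_files_map` is a hypothetical helper, and the rank list and file-name pattern are made-up stand-ins for real MindSpore output:

```python
# Hypothetical helper: map each source rank id to its checkpoint file path.
# The rank list below stands in for what mindspore.rank_list_for_transform
# would return; real use requires MindSpore and actual strategy files.
def build_checkpoint_files_map(rank_list, pattern="./pangu{0}-100_2.ckpt"):
    """Return {rank_id: checkpoint_path} for every source rank."""
    return {rank: pattern.format(rank) for rank in rank_list}

rank_list = [0, 4]  # assumed result for some target rank_id
files_map = build_checkpoint_files_map(rank_list)
print(files_map)  # {0: './pangu0-100_2.ckpt', 4: './pangu4-100_2.ckpt'}
```

The resulting dict has the exact shape the checkpoint_files_map parameter documented below expects: rank id as key, file path as value.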
@@ -0,0 +1,25 @@
+mindspore.transform_checkpoint_by_rank
+======================================
+
+.. py:function:: mindspore.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, src_strategy_file=None, dst_strategy_file=None)
+
+    Transform the checkpoint of a distributed network from the source sharding strategy to the destination sharding strategy for one specific rank. For more details on distributed checkpoint conversion, see [Distributed Resilience Training and Inference](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html).
+
+    .. note::
+        Conversion across the pipeline parallel dimension is not supported yet.
+
+    Parameters:
+        - **rank_id** (int) - The rank id of the checkpoint to be obtained after conversion.
+        - **checkpoint_files_map** (dict) - The source checkpoint dictionary whose key is the rank id and whose value is the checkpoint file path of that rank.
+        - **save_checkpoint_file_name** (str) - The path and name of the destination checkpoint.
+        - **src_strategy_file** (str) - Name of the source sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the parameters are not sharded. Default: None.
+        - **dst_strategy_file** (str) - Name of the destination sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the parameters are not sharded. Default: None.
+
+    Raises:
+        - **ValueError** - src_strategy_file or dst_strategy_file is not a valid sharding strategy proto file.
+        - **ValueError** - An item in checkpoint_files_map is not a valid checkpoint file.
+        - **ValueError** - save_checkpoint_file_name does not end with ".ckpt".
+        - **TypeError** - checkpoint_files_map is not a dict.
+        - **TypeError** - src_strategy_file or dst_strategy_file is not a string.
+        - **TypeError** - rank_id is not an integer.
+        - **TypeError** - save_checkpoint_file_name is not a string.
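The TypeError/ValueError cases documented above can be illustrated with a small stand-alone sketch. `check_transform_args` is a hypothetical helper that mirrors the documented argument checks; it is not the MindSpore implementation:

```python
# Hypothetical mirror of the documented argument validation for
# mindspore.transform_checkpoint_by_rank (types and ".ckpt" suffix only;
# the real API also validates the checkpoint and strategy files themselves).
def check_transform_args(rank_id, checkpoint_files_map, save_checkpoint_file_name):
    if not isinstance(checkpoint_files_map, dict):
        raise TypeError("checkpoint_files_map should be a dict.")
    if not isinstance(rank_id, int):
        raise TypeError("rank_id should be an int.")
    if not isinstance(save_checkpoint_file_name, str):
        raise TypeError("save_checkpoint_file_name should be a str.")
    if not save_checkpoint_file_name.endswith(".ckpt"):
        raise ValueError("save_checkpoint_file_name should end with .ckpt")

# Passes silently: correct types and a ".ckpt" destination name.
check_transform_args(0, {0: "./pangu0-100_2.ckpt"}, "./new/pangu0-100_2.ckpt")
```

Validating early like this turns a confusing failure deep inside the conversion into an immediate, descriptive exception.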
@@ -0,0 +1,22 @@
+mindspore.transform_checkpoints
+======================================
+
+.. py:function:: mindspore.transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None, dst_strategy_file=None)
+
+    Transform the checkpoint of a distributed network from the source sharding strategy to the destination sharding strategy. For more details on distributed checkpoint conversion, see [Distributed Resilience Training and Inference](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html).
+
+    .. note::
+        The src_checkpoints_dir directory must be organized like "src_checkpoints_dir/rank_0/a.ckpt": the rank id must appear as a subdirectory and that rank's checkpoint must be placed inside it. If multiple files exist in one rank directory, the file whose name is lexicographically greatest is chosen. Conversion across the pipeline parallel dimension is not supported yet.
+
+    Parameters:
+        - **src_checkpoints_dir** (str) - The directory containing the source checkpoint files.
+        - **dst_checkpoints_dir** (str) - The directory where the destination checkpoint files are stored.
+        - **ckpt_prefix** (str) - The destination checkpoint name prefix.
+        - **src_strategy_file** (str) - Name of the source sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the parameters are not sharded. Default: None.
+        - **dst_strategy_file** (str) - Name of the destination sharding strategy proto file, saved by the context.set_auto_parallel_context(strategy_ckpt_save_file) interface. When it is None, the parameters are not sharded. Default: None.
+
+    Raises:
+        - **ValueError** - src_strategy_file or dst_strategy_file is not a valid sharding strategy proto file.
+        - **NotADirectoryError** - src_checkpoints_dir or dst_checkpoints_dir is not a directory.
+        - **ValueError** - Checkpoint files are missing from src_checkpoints_dir.
+        - **TypeError** - src_strategy_file or dst_strategy_file is not a string.
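The required "src_checkpoints_dir/rank_0/a.ckpt" layout and the lexicographic-selection rule in the note above can be sketched in pure Python. `pick_checkpoint` is a hypothetical helper and the file names are made up; this only demonstrates the documented selection behavior, not the MindSpore internals:

```python
import os
import tempfile

def pick_checkpoint(rank_dir):
    """Return the lexicographically greatest .ckpt file name in a rank directory,
    matching the documented rule for multiple files per rank, or None if empty."""
    ckpts = [f for f in os.listdir(rank_dir) if f.endswith(".ckpt")]
    return max(ckpts) if ckpts else None

# Build the expected layout: src_checkpoints_dir/rank_0/<file>.ckpt
src = tempfile.mkdtemp()
rank0 = os.path.join(src, "rank_0")
os.makedirs(rank0)
for name in ("net-1_2.ckpt", "net-3_2.ckpt"):
    open(os.path.join(rank0, name), "w").close()

print(pick_checkpoint(rank0))  # net-3_2.ckpt
```

With two checkpoints in rank_0, "net-3_2.ckpt" sorts after "net-1_2.ckpt", so it is the one selected.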
@@ -203,8 +203,11 @@ Serialization
     mindspore.load_param_into_net
     mindspore.merge_sliced_parameter
     mindspore.parse_print
+    mindspore.rank_list_for_transform
     mindspore.restore_group_info_list
     mindspore.save_checkpoint
+    mindspore.transform_checkpoint_by_rank
+    mindspore.transform_checkpoints
 
 JIT
 ---
@@ -37,12 +37,14 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None)
 
     Args:
         rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
-        src_strategy_file (str): Name of source sharding strategy file, when the 'src_strategy_file' is None,
-            it means that the source sharding strategy is without any sharing for each parameter.
-            Default:None.
-        dst_strategy_file (str): Name of destination sharding strategy file. when the 'dst_strategy_file' is None,
-            it means that the source sharding strategy is without any sharing for each parameter.
-            Default:None.
+        src_strategy_file (str): Name of the source sharding strategy file, saved by
+            'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
+            When the 'src_strategy_file' is None, it means that the source sharding strategy is
+            without any sharding for each parameter. Default: None.
+        dst_strategy_file (str): Name of the destination sharding strategy file, saved by
+            'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
+            When the 'dst_strategy_file' is None, it means that the destination sharding strategy
+            is without any sharding for each parameter. Default: None.
 
     Returns:
         List, the rank list required for converting the distributed checkpoint of rank_id.
@@ -54,8 +56,7 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None)
 
     Examples:
         >>> rank_id = 0
-        >>> rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt",
-        >>>                                     "./dst_strategy.ckpt")
+        >>> rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
         >>> checkpoint_files_map = {}
         >>> for rank in rank_list:
         ...     checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)
@@ -69,7 +70,8 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None)
 def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
                                  src_strategy_file=None, dst_strategy_file=None):
     """
-    Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank.
+    Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank
+    for a network.
 
     Note:
         Cannot transform pipeline parallel dimensions currently.
@@ -79,37 +81,44 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
         checkpoint_files_map (dict): The checkpoint files map whose key is the rank id and the value is
             the checkpoint file name.
         save_checkpoint_file_name (str): The file name to save the converted checkpoint.
-        src_strategy_file (str): Name of source sharding strategy file, when the 'src_strategy_file' is None,
-            it means that the source sharding strategy is without any sharding for each parameter.
-            Default:None.
-        dst_strategy_file (str): Name of destination sharding strategy file. when the 'dst_strategy_file' is None,
-            it means that the source sharding strategy is without any sharding for each parameter.
-            Default:None.
+        src_strategy_file (str): Name of the source sharding strategy file, saved by
+            'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
+            When the 'src_strategy_file' is None, it means that the source sharding strategy is
+            without any sharding for each parameter. Default: None.
+        dst_strategy_file (str): Name of the destination sharding strategy file, saved by
+            'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
+            When the 'dst_strategy_file' is None, it means that the destination sharding strategy
+            is without any sharding for each parameter. Default: None.
 
     Raises:
         ValueError: src_strategy_file or dst_strategy_file is incorrect.
         ValueError: An item in checkpoint_files_map is incorrect.
         ValueError: save_checkpoint_file_name does not end with ".ckpt".
         TypeError: checkpoint_files_map is not a dict.
         TypeError: src_strategy_file or dst_strategy_file is not a string.
         TypeError: rank_id is not an int.
         TypeError: save_checkpoint_file_name is not a string.
 
     Examples:
         >>> dst_device_num = 8
         >>> for rank_id in range(dst_device_num):
-        >>>     rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt",
-        >>>                                         "./dst_strategy.ckpt")
+        ...     rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
         ...     checkpoint_files_map = {}
         ...     for rank in rank_list:
         ...         checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank, rank)
         ...     save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id, rank_id)
         ...     transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
-        >>>                                  "./src_strategy.ckpt", "./dst_strategy.ckpt")
+        ...                                  "./src_strategy.ckpt", "./dst_strategy.ckpt")
 
     """
     if not isinstance(checkpoint_files_map, dict):
         raise TypeError("The checkpoint_files_map should be a dict.")
     if not isinstance(rank_id, int):
         raise TypeError("The rank_id should be an int.")
     if not isinstance(save_checkpoint_file_name, str):
         raise TypeError("The save_checkpoint_file_name should be a str.")
     if not save_checkpoint_file_name.endswith(".ckpt"):
         raise ValueError("The save_checkpoint_file_name {} should end with .ckpt".format(save_checkpoint_file_name))
     for rank, local_file in checkpoint_files_map.items():
         if not os.path.exists(local_file):
             raise ValueError("Checkpoint file {} in rank {} does not exist.".format(local_file, rank))
@@ -128,7 +137,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
 def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None,
                           dst_strategy_file=None):
     """
-    Transform distributed checkpoint from source sharding strategy to destination sharding strategy.
+    Transform distributed checkpoint from source sharding strategy to destination sharding strategy for a rank.
 
     Note:
         The src_checkpoints_dir directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt", the
@@ -139,12 +148,15 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
     Args:
         src_checkpoints_dir (str): The source checkpoints directory.
         dst_checkpoints_dir (str): The destination checkpoints directory to save the converted checkpoints.
-        src_strategy_file (str): Name of source sharding strategy file, when the 'src_strategy_file' is None,
-            it means that the source sharding strategy is without any sharding for each parameter.
-            Default:None.
-        dst_strategy_file (str): Name of destination sharding strategy file. when the 'dst_strategy_file' is None,
-            it means that the source sharding strategy is without any sharding for each parameter.
-            Default:None.
+        ckpt_prefix (str): The destination checkpoint name prefix.
+        src_strategy_file (str): Name of the source sharding strategy file, saved by
+            'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
+            When the 'src_strategy_file' is None, it means that the source sharding strategy is
+            without any sharding for each parameter. Default: None.
+        dst_strategy_file (str): Name of the destination sharding strategy file, saved by
+            'context.set_auto_parallel_context(strategy_ckpt_save_file)'.
+            When the 'dst_strategy_file' is None, it means that the destination sharding strategy
+            is without any sharding for each parameter. Default: None.
 
     Raises:
         ValueError: src_strategy_file or dst_strategy_file is incorrect.
@@ -153,8 +165,8 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
         TypeError: src_strategy_file or dst_strategy_file is not a string.
 
     Examples:
-        >>> transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir,
-        >>>                       "./src_strategy.ckpt", "./dst_strategy.ckpt")
+        >>> transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, "dst_checkpoint",
+        ...                       "./src_strategy.ckpt", "./dst_strategy.ckpt")
 
     """
     if not os.path.isdir(src_checkpoints_dir):