parallel checkpoint transform api fix

This commit is contained in:
yao_yf 2022-10-31 19:55:22 +08:00
parent d9e88bd13b
commit 5fa24ee733
6 changed files with 39 additions and 35 deletions

View File

@ -3,7 +3,7 @@ mindspore.merge_pipeline_strategys
.. py:function:: mindspore.merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file)
流水线并行模式下汇聚所有流水线并行子图的切分策略文件。关于更多分布式Checkpoint转换的细节请参考[分布式弹性训练与推理](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html)
流水线并行模式下汇聚所有流水线并行子图的切分策略文件。关于更多分布式Checkpoint转换的细节请参考`分布式弹性训练与推理 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html>`_
.. note::
src_strategy_dirs必须包含所有流水线并行的子图的切分策略文件。
@ -13,4 +13,4 @@ mindspore.merge_pipeline_strategys
- **dst_strategy_file** (str) - 保存汇聚后的切分策略的文件路径。
异常:
- **NotADirectoryError** - src_strategy_dirs不是一个目录。
- **NotADirectoryError** - `src_strategy_dirs` 不是一个目录。

View File

@ -3,7 +3,7 @@ mindspore.rank_list_for_transform
.. py:function:: mindspore.rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None)
在对分布式Checkpoint转换的过程中获取为了得到目标rank的Checkpoint文件所需的源Checkpoint文件rank列表。关于更多分布式Checkpoint转换的细节请参考[分布式弹性训练与推理](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html)
在对分布式Checkpoint转换的过程中获取为了得到目标rank的Checkpoint文件所需的源Checkpoint文件rank列表。关于更多分布式Checkpoint转换的细节请参考`分布式弹性训练与推理 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html>`_
参数:
- **rank_id** (int) - 待转换得到的Checkpoint的rank号。
@ -14,6 +14,6 @@ mindspore.rank_list_for_transform
转换得到rank_id的分布式Checkpoint所需要的卡号列表。
异常:
- **ValueError** - src_strategy_file或者dst_strategy_file不是正确的切分策略proto文件。
- **TypeError** - src_strategy_file或者dst_strategy_file不是字符串。
- **TypeError** - rank_id不是一个整数。
- **ValueError** - `src_strategy_file` 或者 `dst_strategy_file` 不是正确的切分策略proto文件。
- **TypeError** - `src_strategy_file` 或者 `dst_strategy_file` 不是字符串。
- **TypeError** - `rank_id` 不是一个整数。

View File

@ -3,7 +3,7 @@ mindspore.transform_checkpoint_by_rank
.. py:function:: mindspore.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, src_strategy_file=None, dst_strategy_file=None)
将一个分布式网络的Checkpoint由源切分策略转换到目标切分策略对特定一个rank进行转换。关于更多分布式Checkpoint转换的细节请参考[分布式弹性训练与推理](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html)
将一个分布式网络的Checkpoint由源切分策略转换到目标切分策略对特定一个rank进行转换。关于更多分布式Checkpoint转换的细节请参考`分布式弹性训练与推理 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html>`_
参数:
- **rank_id** (int) - 待转换得到的Checkpoint的rank号。
@ -13,10 +13,10 @@ mindspore.transform_checkpoint_by_rank
- **dst_strategy_file** (str) - 目标切分策略proto文件名由mindspore.set_auto_parallel_context(strategy_ckpt_save_file)接口存储下来的文件。当其为None时表示切分策略为不切分。默认值None。
异常:
- **ValueError** - src_strategy_file或者dst_strategy_file不是正确的切分策略proto文件。
- **ValueError** - checkpoint_files_map内的元素不是正确的Checkpoint文件。
- **ValueError** - save_checkpoint_file_name不以“.ckpt”结尾。
- **TypeError** - checkpoint_files_map不是一个字典。
- **TypeError** - src_strategy_file或者dst_strategy_file不是字符串。
- **TypeError** - rank_id不是一个整数。
- **TypeError** - save_checkpoint_file_name不是字符串。
- **ValueError** - `src_strategy_file` 或者 `dst_strategy_file` 不是正确的切分策略proto文件。
- **ValueError** - `checkpoint_files_map` 内的元素不是正确的Checkpoint文件。
- **ValueError** - `save_checkpoint_file_name` 不以“.ckpt”结尾。
- **TypeError** - `checkpoint_files_map` 不是一个字典。
- **TypeError** - `src_strategy_file` 或者 `dst_strategy_file` 不是字符串。
- **TypeError** - `rank_id` 不是一个整数。
- **TypeError** - `save_checkpoint_file_name` 不是字符串。

View File

@ -3,7 +3,7 @@ mindspore.transform_checkpoints
.. py:function:: mindspore.transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None, dst_strategy_file=None)
将一个分布式网络的Checkpoint由源切分策略转换到目标切分策略。关于更多分布式Checkpoint转换的细节请参考[分布式弹性训练与推理](https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html)
将一个分布式网络的Checkpoint由源切分策略转换到目标切分策略。关于更多分布式Checkpoint转换的细节请参考`分布式弹性训练与推理 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/resilience_train_and_predict.html>`_
.. note::
src_checkpoints_dir目录必须组织为“src_checkpoints_dir/rank_0/a.ckpt”这样的目录结构rank号必须作为子目录并且该rank的Checkpoint必须放置于该子目录内。如果多个文件存在于一个rank目录下将会选名字的字典序最高的文件。
@ -16,7 +16,7 @@ mindspore.transform_checkpoints
- **dst_strategy_file** (str) - 目标切分策略proto文件名由mindspore.set_auto_parallel_context(strategy_ckpt_save_file)接口存储下来的文件。当其为None时表示切分策略为不切分。默认值None。
异常:
- **ValueError** - src_strategy_file或者dst_strategy_file不是正确的切分策略proto文件。
- **NotADirectoryError** - src_checkpoints_dir或者dst_checkpoints_dir不是一个目录。
- **ValueError** - src_checkpoints_dir中缺失了Checkpoint文件。
- **TypeError** - src_strategy_file或者dst_strategy_file 不是字符串。
- **ValueError** - `src_strategy_file` 或者 `dst_strategy_file` 不是正确的切分策略proto文件。
- **NotADirectoryError** - `src_checkpoints_dir` 或者 `dst_checkpoints_dir` 不是一个目录。
- **ValueError** - `src_checkpoints_dir` 中缺失了Checkpoint文件。
- **TypeError** - `src_strategy_file` 或者 `dst_strategy_file` 不是字符串。

View File

@ -1725,7 +1725,8 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const
StrategyPtr strategyPtr;
std::shared_ptr<Strategies> strategy_v_ptr = operator_->GenerateBatchStrategiesWithCheck();
MS_EXCEPTION_IF_NULL(strategy_v_ptr);
strategyPtr = NewStrategy(0, *strategy_v_ptr);
auto stage_id = g_device_manager->stage_id();
strategyPtr = NewStrategy(stage_id, *strategy_v_ptr);
std::vector<ValuePtr> elements;
for (size_t i = 0; i < strategy_v_ptr->size(); i++) {
elements.push_back(MakeValue((*strategy_v_ptr)[i]));

View File

@ -43,7 +43,7 @@ def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
dst_strategy_file (str): The file merged strategy to save.
Raises:
NotADirectoryError: src_strategy_dirs is not a directory.
NotADirectoryError: `src_strategy_dirs` is not a directory.
Examples:
>>> # src_strategy_dir/stra0.ckpt, src_strategy_dir/stra1.ckpt ... src_strategy_dir/stra127.ckpt
@ -97,9 +97,9 @@ def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=N
List, the rank list required for converting the distributed checkpoint of rank_id.
Raises:
ValueError: src_strategy_file or dst_strategy_file is incorrect.
TypeError: src_strategy_file or dst_strategy_file is not a string.
TypeError: rank_id is not a int.
ValueError: `src_strategy_file` or dst_strategy_file is incorrect.
TypeError: `src_strategy_file` or dst_strategy_file is not a string.
TypeError: `rank_id` is not a int.
Examples:
>>> rank_id = 0
@ -157,13 +157,13 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
is without any sharing for each parameter. Default:None.
Raises:
ValueError: src_strategy_file or dst_strategy_file is incorrect.
ValueError: item in checkpoint_files_map is incorrect.
ValueError: save_checkpoint_file_name is not end with ".ckpt".
TypeError: checkpoint_files_map is not a dict.
TypeError: src_strategy_file or dst_strategy_file is not a string.
TypeError: rank_id is not a int.
TypeError: save_checkpoint_file_name is not a string.
ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
ValueError: item in `checkpoint_files_map` is incorrect.
ValueError: `save_checkpoint_file_name` is not end with ".ckpt".
TypeError: `checkpoint_files_map` is not a dict.
TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
TypeError: `rank_id` is not a int.
TypeError: `save_checkpoint_file_name` is not a string.
Examples:
>>> dst_device_num = 8
@ -185,6 +185,9 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
raise TypeError("The save_checkpoint_file_name should be a str.")
if save_checkpoint_file_name[-5:] != ".ckpt":
raise ValueError("The save_checkpoint_file_name {} should end with .ckpt".format(save_checkpoint_file_name))
if not os.path.exists(os.path.dirname(dst_strategy_file)):
raise ValueError("The director of dst_strategy_file: {} is not exists.".
format(os.path.dirname(dst_strategy_file)))
for rank, local_file in checkpoint_files_map.items():
if not os.path.exists(local_file):
raise ValueError("Checkpoint file {} in rank {} not exits: ".format(local_file, rank))
@ -238,10 +241,10 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
is without any sharing for each parameter. Default:None.
Raises:
ValueError: src_strategy_file or dst_strategy_file is incorrect.
NotADirectoryError: src_checkpoints_dir or dst_checkpoints_dir is not a directory.
ValueError: The checkpoint file is missing in src_checkpoints_dir.
TypeError: src_strategy_file or dst_strategy_file is not a string.
ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
NotADirectoryError: `src_checkpoints_dir` or `dst_checkpoints_dir` is not a directory.
ValueError: The checkpoint file is missing in `src_checkpoints_dir`.
TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
Examples:
>>> transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, "dst_checkpoint",