add_TransformerRecomputeConfig_API

This commit is contained in:
wangshengnan12@huawei.com 2022-02-25 10:59:59 +08:00
parent c554d4a8b1
commit bb761a47be
2 changed files with 12 additions and 0 deletions

View File

@ -19,6 +19,7 @@ mindspore.nn.transformer
.. include:: transformer/mindspore.nn.OpParallelConfig.rst
.. include:: transformer/mindspore.nn.FixedSparseAttention.rst
.. include:: transformer/mindspore.nn.MoEConfig.rst
.. include:: transformer/mindspore.nn.TransformerRecomputeConfig.rst
.. automodule:: mindspore.nn.transformer
:members:

View File

@ -0,0 +1,11 @@
.. py:class:: mindspore.nn.transformer.TransformerRecomputeConfig(recompute=False, parallel_optimizer_comm_recompute=False,
mp_comm_recompute=True, recompute_slice_activation=False)
Transformer的重计算配置接口。
**参数:**
- **recompute** (bool) - 是否使能重计算。默认值为False。
- **parallel_optimizer_comm_recompute** (bool) - 指定由优化器切分产生的AllGather算子是否进行重计算。默认值为False。
- **mp_comm_recompute** (bool) - 指定由模型并行成分产生的通信算子是否进行重计算。默认值为False。
- **recompute_slice_activation** (bool) - 指定激活层是否切片保存。默认值为False。