mindspore/docs/api/api_python/nn/mindspore.nn.AdaSumByDeltaW...

38 lines
1.9 KiB
ReStructuredText
Raw Normal View History

2022-02-25 11:38:02 +08:00
mindspore.nn.AdaSumByDeltaWeightWrapCell
========================================
.. py:class:: mindspore.nn.AdaSumByDeltaWeightWrapCell(optimizer)
2022-03-23 20:30:59 +08:00
Adaptive Summation (AdaSum)算法的实现根据更新前后的参数差计算。应用于semi_auto_parallel/auto_parallel模式。
2022-02-25 11:38:02 +08:00
请参阅论文 `AdaSum: Scaling Distributed Training with Adaptive Summation <https://arxiv.org/abs/2006.02924>`_
公式如下:
.. math::
\begin{array}{ll}
w_{t+1}=w_{t} - \alpha \cdot Adasum(g_{1}, g_{2}) \\
w_{t+1}=w_{t} - \alpha \cdot [(1 - \frac{g_2^{T}\cdot g_1}{2\cdot \left \| g_1 \right \|^2 })\cdot g_1 + (1 - \frac{g_1^{T}\cdot g_2}{2\cdot \left \| g_2 \right \|^2 })\cdot g_2] \\
\end{array}
在本实现中, :math:`g` 代表优化器更新前后的权重的变化量,下标代表数据并行维度下不同的设备。
.. note::
本接口推荐应用于半自动并行或者全自动并行模式。针对数据并行模式推荐使用mindspore.boost功能以使用AdaSum。
使用本接口时训练的卡的数量必须是2的幂并且至少需要16张卡。目前使用本接口时不支持优化器并行和流水线并行。
**参数:**
2022-03-29 22:00:57 +08:00
- **optimizer** (nn.optimizer) - 必须是单输入的优化器。
2022-02-25 11:38:02 +08:00
**输入:**
2022-03-08 15:28:53 +08:00
- **grads** (tuple[Tensor]) - `params` 的梯度形状shape`params` 相同,与所传优化器的输入一致。
2022-02-25 11:38:02 +08:00
**异常:**
2022-03-01 17:46:19 +08:00
- **RuntimeError** - `parallel_mode` 使用了 `stand_alone` 模式, AdaSum仅支持在分布式场景下使用。
2022-02-25 11:38:02 +08:00
- **RuntimeError** - 同时使用了优化器并行, 暂时不支持在优化器并行场景下使用AdaSum。
- **RuntimeError** - 同时使用了流水线并行, 暂时不支持在流水线并行场景下使用AdaSum。
- **RuntimeError** - `device_num` 不是2的幂或者小于16。