!30494 Supplement the docs and log to recompute api of Cell

Merge pull request !30494 from YuJianfeng/recompute
This commit is contained in:
i-robot 2022-02-24 13:04:57 +00:00 committed by Gitee
commit 3219508ecd
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 30 additions and 4 deletions

View File

@ -285,12 +285,13 @@
.. py:method:: recompute(**kwargs)
设置Cell重计算。Cell中的所有算子将被设置为重计算。如果一个算子的计算结果被输出到一些反向节点来进行梯度计算且被设置成重计算那么我们会在反向传播中重新计算它而不去存储在前向传播中的中间激活层的计算结果。
设置Cell重计算。Cell中输出算子以外的所有算子将被设置为重计算。如果一个算子的计算结果被输出到一些反向节点来进行梯度计算,且被设置成重计算,那么我们会在反向传播中重新计算它,而不去存储在前向传播中的中间激活层的计算结果。
.. note::
- 如果计算涉及到诸如随机化或全局变量之类的操作,那么目前还不能保证等价。
- 如果该Cell中算子的重计算API也被调用则该算子的重计算模式以算子的重计算API的设置为准。
- 该接口仅配置一次即当父Cell配置了子Cell不需再配置。
- Cell的输出算子默认不做重计算这一点是基于我们减少内存占用的配置经验。如果一个Cell里面只有一个算子而且想要把这个算子设置为重计算的那么请使用算子的重计算API。
- 当应用了重计算且内存充足时,可以配置'mp_comm_recompute=False'来提升性能。
- 当应用了重计算但内存不足时,可以配置'parallel_optimizer_comm_recompute=True'来节省内存。有相同融合group的Cell应该配置相同的parallel_optimizer_comm_recompute。

View File

@ -45,6 +45,12 @@ class SetCellOutputNoRecompute : public AnfVisitor {
if (output->isa<CNode>()) {
mindspore::HashSet<CNodePtr> real_outputs;
GetRealOutputNodes(output, &real_outputs);
if (OutputAllNodes(real_outputs)) {
MS_LOG(WARNING)
<< "All nodes in the graph " << fg->ToString()
<< " are the output nodes, which are set to not be recomputed. If you want to set these nodes to "
"be recomputed, use the api recompute() of Primitive.";
}
for (const auto &real_output : real_outputs) {
// Set the attr of cnode in case of shared primitives.
real_output->AddAttr(kAttrRecompute, MakeValue(false));
@ -120,6 +126,22 @@ class SetCellOutputNoRecompute : public AnfVisitor {
}
return nullptr;
}
bool OutputAllNodes(const mindspore::HashSet<CNodePtr> &real_outputs) {
for (const auto &cnode : real_outputs) {
const auto &inputs = cnode->inputs();
for (const auto &input : inputs) {
auto input_cnode = input->cast<CNodePtr>();
if (input_cnode == nullptr || IsPrimitiveCNode(input_cnode, prim::kPrimLoad)) {
continue;
}
if (real_outputs.find(input_cnode) == real_outputs.end()) {
return false;
}
}
}
return true;
}
};
} // namespace irpass
} // namespace opt

View File

@ -1851,9 +1851,9 @@ class Cell(Cell_):
@args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
def recompute(self, **kwargs):
"""
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
set recomputed feeds into some backward nodes for computing gradient, rather than storing the
intermediate activation computed in forward pass, we will recompute it in backward pass.
Set the cell recomputed. All the primitive in the cell except the outputs will be set recomputed.
If a primitive set recomputed feeds into some backward nodes for computing gradient, rather than
storing the intermediate activation computed in forward pass, we will recompute it in backward pass.
Note:
@ -1863,6 +1863,9 @@ class Cell(Cell_):
primitive is subject to the recompute api of the primitive.
- The interface can be configured only once.
Therefore, when the parent cell is configured, the child cell should not be configured.
- The outputs of cell are excluded from recomputation by default, which is based on our configuration
experience to reduce memory footprint. If a cell has only one primitive and the primitive is wanted
to be set recomputed, use the recompute api of the primtive.
- When the memory remains after applying the recomputation, configuring 'mp_comm_recompute=False'
to improve performance if necessary.
- When the memory still not enough after applying the recompute, configuring