!30494 Supplement the docs and log to recompute api of Cell
Merge pull request !30494 from YuJianfeng/recompute
This commit is contained in:
commit
3219508ecd
|
@ -285,12 +285,13 @@
|
|||
|
||||
.. py:method:: recompute(**kwargs)
|
||||
|
||||
设置Cell重计算。Cell中的所有算子将被设置为重计算。如果一个算子的计算结果被输出到一些反向节点来进行梯度计算,且被设置成重计算,那么我们会在反向传播中重新计算它,而不去存储在前向传播中的中间激活层的计算结果。
|
||||
设置Cell重计算。Cell中输出算子以外的所有算子将被设置为重计算。如果一个算子的计算结果被输出到一些反向节点来进行梯度计算,且被设置成重计算,那么我们会在反向传播中重新计算它,而不去存储在前向传播中的中间激活层的计算结果。
|
||||
|
||||
.. note::
|
||||
- 如果计算涉及到诸如随机化或全局变量之类的操作,那么目前还不能保证等价。
|
||||
- 如果该Cell中算子的重计算API也被调用,则该算子的重计算模式以算子的重计算API的设置为准。
|
||||
- 该接口仅配置一次,即当父Cell配置了,子Cell不需再配置。
|
||||
- Cell的输出算子默认不做重计算,这一点是基于我们减少内存占用的配置经验。如果一个Cell里面只有一个算子而且想要把这个算子设置为重计算的,那么请使用算子的重计算API。
|
||||
- 当应用了重计算且内存充足时,可以配置'mp_comm_recompute=False'来提升性能。
|
||||
- 当应用了重计算但内存不足时,可以配置'parallel_optimizer_comm_recompute=True'来节省内存。有相同融合group的Cell应该配置相同的parallel_optimizer_comm_recompute。
|
||||
|
||||
|
|
|
@ -45,6 +45,12 @@ class SetCellOutputNoRecompute : public AnfVisitor {
|
|||
if (output->isa<CNode>()) {
|
||||
mindspore::HashSet<CNodePtr> real_outputs;
|
||||
GetRealOutputNodes(output, &real_outputs);
|
||||
if (OutputAllNodes(real_outputs)) {
|
||||
MS_LOG(WARNING)
|
||||
<< "All nodes in the graph " << fg->ToString()
|
||||
<< " are the output nodes, which are set to not be recomputed. If you want to set these nodes to "
|
||||
"be recomputed, use the api recompute() of Primitive.";
|
||||
}
|
||||
for (const auto &real_output : real_outputs) {
|
||||
// Set the attr of cnode in case of shared primitives.
|
||||
real_output->AddAttr(kAttrRecompute, MakeValue(false));
|
||||
|
@ -120,6 +126,22 @@ class SetCellOutputNoRecompute : public AnfVisitor {
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Reports whether the given real output nodes only consume one another.
//
// Every input of every node in `real_outputs` is inspected. Inputs that do not cast to a CNode,
// and inputs that are kPrimLoad CNodes, are ignored. As soon as one of the remaining inputs is
// found that is not itself a member of `real_outputs`, the answer is false. If no such input
// exists (which includes the case of an empty set), the answer is true.
bool OutputAllNodes(const mindspore::HashSet<CNodePtr> &real_outputs) {
  for (const auto &output_node : real_outputs) {
    for (const auto &operand : output_node->inputs()) {
      const auto operand_cnode = operand->cast<CNodePtr>();
      // The null test comes first so the primitive check is only evaluated for real CNodes.
      const bool needs_membership_check =
        operand_cnode != nullptr && !IsPrimitiveCNode(operand_cnode, prim::kPrimLoad);
      if (needs_membership_check && real_outputs.find(operand_cnode) == real_outputs.end()) {
        return false;
      }
    }
  }
  return true;
}
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
|
|
|
@ -1851,9 +1851,9 @@ class Cell(Cell_):
|
|||
@args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
|
||||
def recompute(self, **kwargs):
|
||||
"""
|
||||
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
|
||||
set recomputed feeds into some backward nodes for computing gradient, rather than storing the
|
||||
intermediate activation computed in forward pass, we will recompute it in backward pass.
|
||||
Set the cell recomputed. All the primitives in the cell except the outputs will be set recomputed.
|
||||
If a primitive set recomputed feeds into some backward nodes for computing gradient, rather than
|
||||
storing the intermediate activation computed in forward pass, we will recompute it in backward pass.
|
||||
|
||||
Note:
|
||||
|
||||
|
@ -1863,6 +1863,9 @@ class Cell(Cell_):
|
|||
primitive is subject to the recompute api of the primitive.
|
||||
- The interface can be configured only once.
|
||||
Therefore, when the parent cell is configured, the child cell should not be configured.
|
||||
- The outputs of cell are excluded from recomputation by default, which is based on our configuration
|
||||
experience to reduce memory footprint. If a cell has only one primitive and you want that primitive
|
||||
to be set recomputed, use the recompute api of the primitive.
|
||||
- When there is memory left over after applying the recomputation, configure 'mp_comm_recompute=False'
|
||||
to improve performance if necessary.
|
||||
- When the memory is still not enough after applying the recompute, configure
|
||||
|
|
Loading…
Reference in New Issue