forked from mindspore-Ecosystem/mindspore
!30494 Supplement the docs and log to recompute api of Cell
Merge pull request !30494 from YuJianfeng/recompute
This commit is contained in:
commit
3219508ecd
|
@ -285,12 +285,13 @@
|
||||||
|
|
||||||
.. py:method:: recompute(**kwargs)
|
.. py:method:: recompute(**kwargs)
|
||||||
|
|
||||||
设置Cell重计算。Cell中的所有算子将被设置为重计算。如果一个算子的计算结果被输出到一些反向节点来进行梯度计算,且被设置成重计算,那么我们会在反向传播中重新计算它,而不去存储在前向传播中的中间激活层的计算结果。
|
设置Cell重计算。Cell中输出算子以外的所有算子将被设置为重计算。如果一个算子的计算结果被输出到一些反向节点来进行梯度计算,且被设置成重计算,那么我们会在反向传播中重新计算它,而不去存储在前向传播中的中间激活层的计算结果。
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
- 如果计算涉及到诸如随机化或全局变量之类的操作,那么目前还不能保证等价。
|
- 如果计算涉及到诸如随机化或全局变量之类的操作,那么目前还不能保证等价。
|
||||||
- 如果该Cell中算子的重计算API也被调用,则该算子的重计算模式以算子的重计算API的设置为准。
|
- 如果该Cell中算子的重计算API也被调用,则该算子的重计算模式以算子的重计算API的设置为准。
|
||||||
- 该接口仅配置一次,即当父Cell配置了,子Cell不需再配置。
|
- 该接口仅配置一次,即当父Cell配置了,子Cell不需再配置。
|
||||||
|
- Cell的输出算子默认不做重计算,这一点是基于我们减少内存占用的配置经验。如果一个Cell里面只有一个算子而且想要把这个算子设置为重计算的,那么请使用算子的重计算API。
|
||||||
- 当应用了重计算且内存充足时,可以配置'mp_comm_recompute=False'来提升性能。
|
- 当应用了重计算且内存充足时,可以配置'mp_comm_recompute=False'来提升性能。
|
||||||
- 当应用了重计算但内存不足时,可以配置'parallel_optimizer_comm_recompute=True'来节省内存。有相同融合group的Cell应该配置相同的parallel_optimizer_comm_recompute。
|
- 当应用了重计算但内存不足时,可以配置'parallel_optimizer_comm_recompute=True'来节省内存。有相同融合group的Cell应该配置相同的parallel_optimizer_comm_recompute。
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,12 @@ class SetCellOutputNoRecompute : public AnfVisitor {
|
||||||
if (output->isa<CNode>()) {
|
if (output->isa<CNode>()) {
|
||||||
mindspore::HashSet<CNodePtr> real_outputs;
|
mindspore::HashSet<CNodePtr> real_outputs;
|
||||||
GetRealOutputNodes(output, &real_outputs);
|
GetRealOutputNodes(output, &real_outputs);
|
||||||
|
if (OutputAllNodes(real_outputs)) {
|
||||||
|
MS_LOG(WARNING)
|
||||||
|
<< "All nodes in the graph " << fg->ToString()
|
||||||
|
<< " are the output nodes, which are set to not be recomputed. If you want to set these nodes to "
|
||||||
|
"be recomputed, use the api recompute() of Primitive.";
|
||||||
|
}
|
||||||
for (const auto &real_output : real_outputs) {
|
for (const auto &real_output : real_outputs) {
|
||||||
// Set the attr of cnode in case of shared primitives.
|
// Set the attr of cnode in case of shared primitives.
|
||||||
real_output->AddAttr(kAttrRecompute, MakeValue(false));
|
real_output->AddAttr(kAttrRecompute, MakeValue(false));
|
||||||
|
@ -120,6 +126,22 @@ class SetCellOutputNoRecompute : public AnfVisitor {
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool OutputAllNodes(const mindspore::HashSet<CNodePtr> &real_outputs) {
|
||||||
|
for (const auto &cnode : real_outputs) {
|
||||||
|
const auto &inputs = cnode->inputs();
|
||||||
|
for (const auto &input : inputs) {
|
||||||
|
auto input_cnode = input->cast<CNodePtr>();
|
||||||
|
if (input_cnode == nullptr || IsPrimitiveCNode(input_cnode, prim::kPrimLoad)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (real_outputs.find(input_cnode) == real_outputs.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace irpass
|
} // namespace irpass
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -1851,9 +1851,9 @@ class Cell(Cell_):
|
||||||
@args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
|
@args_type_check(mp_comm_recompute=bool, parallel_optimizer_comm_recompute=bool)
|
||||||
def recompute(self, **kwargs):
|
def recompute(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
|
Set the cell recomputed. All the primitive in the cell except the outputs will be set recomputed.
|
||||||
set recomputed feeds into some backward nodes for computing gradient, rather than storing the
|
If a primitive set recomputed feeds into some backward nodes for computing gradient, rather than
|
||||||
intermediate activation computed in forward pass, we will recompute it in backward pass.
|
storing the intermediate activation computed in forward pass, we will recompute it in backward pass.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
|
|
||||||
|
@ -1863,6 +1863,9 @@ class Cell(Cell_):
|
||||||
primitive is subject to the recompute api of the primitive.
|
primitive is subject to the recompute api of the primitive.
|
||||||
- The interface can be configured only once.
|
- The interface can be configured only once.
|
||||||
Therefore, when the parent cell is configured, the child cell should not be configured.
|
Therefore, when the parent cell is configured, the child cell should not be configured.
|
||||||
|
- The outputs of cell are excluded from recomputation by default, which is based on our configuration
|
||||||
|
experience to reduce memory footprint. If a cell has only one primitive and the primitive is wanted
|
||||||
|
to be set recomputed, use the recompute api of the primtive.
|
||||||
- When the memory remains after applying the recomputation, configuring 'mp_comm_recompute=False'
|
- When the memory remains after applying the recomputation, configuring 'mp_comm_recompute=False'
|
||||||
to improve performance if necessary.
|
to improve performance if necessary.
|
||||||
- When the memory still not enough after applying the recompute, configuring
|
- When the memory still not enough after applying the recompute, configuring
|
||||||
|
|
Loading…
Reference in New Issue