Update recompute python api

This commit is contained in:
yujianfeng 2021-02-24 15:06:00 +08:00
parent a38c996c9c
commit 047e006aab
2 changed files with 11 additions and 4 deletions

View File

@ -1144,8 +1144,14 @@ class Cell(Cell_):
def recompute(self, mode=True):
"""
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive feeds into a grad
node and is set recomputed, we will compute it again for the grad node after the forward computation.
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
set recomputed feeds into a gradient node, we will compute it again for the gradient node
after the forward computation.
Note:
If the recompute api of a primtive in this cell is also called, the recompute mode of this
primitive is subject to the recompute api of the primitive.
Args:
mode (bool): Specifies whether the cell is recomputed. Default: True.
"""

View File

@ -222,8 +222,9 @@ class Primitive(Primitive_):
def recompute(self, mode):
"""
Set the primitive recomputed. If a primitive feeds into a grad node and is set recomputed,
we will compute it again for the grad node after the forward computation.
Set the primitive recomputed. If a primitive set recomputed feeds into a gradient node,
we will compute it again for the gradient node after the forward computation.
Args:
mode (bool): Specifies whether the primitive is recomputed. Default: True.
"""