forked from mindspore-Ecosystem/mindspore
!11009 Fix continuous calls for recompute api
From: @ginfung Reviewed-by: @zhunaipan,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
21addb331d
|
@ -31,7 +31,6 @@ namespace opt {
|
|||
namespace {
|
||||
constexpr auto kGradientsFlag = "Gradients";
|
||||
constexpr auto kAttrRecompute = "recompute";
|
||||
constexpr auto kAttrNoRecompute = "no_recompute";
|
||||
bool IsBpropNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
|
@ -46,8 +45,7 @@ bool WithRecomputedScope(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
auto full_name_with_scope = node->fullname_with_scope();
|
||||
return full_name_with_scope.find(kAttrRecompute) == 0 &&
|
||||
full_name_with_scope.find(kAttrNoRecompute) == full_name_with_scope.npos;
|
||||
return full_name_with_scope.find(kAttrRecompute) == 0;
|
||||
}
|
||||
|
||||
bool HasRecomputeCNodeAttr(const AnfNodePtr &node) {
|
||||
|
|
|
@ -913,8 +913,11 @@ class Cell(Cell_):
|
|||
"""Sets the name on the first time."""
|
||||
if self._scope is None:
|
||||
self._scope = name
|
||||
elif self._scope == 'recomputed':
|
||||
self._scope = self._scope + "_" + name
|
||||
elif self._scope == 'recompute':
|
||||
if name is None:
|
||||
self._scope = None
|
||||
elif name != 'recompute':
|
||||
self._scope = self._scope + '_' + name
|
||||
|
||||
def _children_scope_recursive(self, parent_prefix='Default'):
|
||||
"""Generates the scope of each layer of the network recursively."""
|
||||
|
@ -1102,10 +1105,11 @@ class Cell(Cell_):
|
|||
Args:
|
||||
mode (bool): Specifies whether the cell is recomputed. Default: True.
|
||||
"""
|
||||
Validator.check_bool(mode)
|
||||
if mode is True:
|
||||
self._set_scope("recompute")
|
||||
else:
|
||||
self._set_scope("no_recompute")
|
||||
self._set_scope(None)
|
||||
for cell in self.cells():
|
||||
cell.recompute(mode)
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ import copy
|
|||
from mindspore.common.api import _wrap_func
|
||||
from mindspore import context
|
||||
from .._c_expression import Primitive_, real_run_op, prim_type
|
||||
from .._checkparam import Validator
|
||||
from . import signature as sig
|
||||
|
||||
|
||||
|
@ -213,6 +214,7 @@ class Primitive(Primitive_):
|
|||
Args:
|
||||
mode (bool): Specifies whether the primitive is recomputed. Default: True.
|
||||
"""
|
||||
Validator.check_bool(mode)
|
||||
self.add_prim_attr("recompute", mode)
|
||||
return self
|
||||
|
||||
|
|
Loading…
Reference in New Issue