!11009 Fix continuous calls for recompute api

From: @ginfung
Reviewed-by: @zhunaipan,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-01-11 15:13:27 +08:00 committed by Gitee
commit 21addb331d
3 changed files with 10 additions and 6 deletions

View File

@ -31,7 +31,6 @@ namespace opt {
namespace {
constexpr auto kGradientsFlag = "Gradients";
constexpr auto kAttrRecompute = "recompute";
constexpr auto kAttrNoRecompute = "no_recompute";
bool IsBpropNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
@ -46,8 +45,7 @@ bool WithRecomputedScope(const AnfNodePtr &node) {
return false;
}
auto full_name_with_scope = node->fullname_with_scope();
return full_name_with_scope.find(kAttrRecompute) == 0 &&
full_name_with_scope.find(kAttrNoRecompute) == full_name_with_scope.npos;
return full_name_with_scope.find(kAttrRecompute) == 0;
}
bool HasRecomputeCNodeAttr(const AnfNodePtr &node) {

View File

@ -913,8 +913,11 @@ class Cell(Cell_):
"""Sets the name on the first time."""
if self._scope is None:
self._scope = name
elif self._scope == 'recomputed':
self._scope = self._scope + "_" + name
elif self._scope == 'recompute':
if name is None:
self._scope = None
elif name != 'recompute':
self._scope = self._scope + '_' + name
def _children_scope_recursive(self, parent_prefix='Default'):
"""Generates the scope of each layer of the network recursively."""
@ -1102,10 +1105,11 @@ class Cell(Cell_):
Args:
mode (bool): Specifies whether the cell is recomputed. Default: True.
"""
Validator.check_bool(mode)
if mode is True:
self._set_scope("recompute")
else:
self._set_scope("no_recompute")
self._set_scope(None)
for cell in self.cells():
cell.recompute(mode)

View File

@ -19,6 +19,7 @@ import copy
from mindspore.common.api import _wrap_func
from mindspore import context
from .._c_expression import Primitive_, real_run_op, prim_type
from .._checkparam import Validator
from . import signature as sig
@ -213,6 +214,7 @@ class Primitive(Primitive_):
Args:
mode (bool): Specifies whether the primitive is recomputed. Default: True.
"""
Validator.check_bool(mode)
self.add_prim_attr("recompute", mode)
return self