From 7cac3d3a476159ad5a82b716fba7e3be41d2ad6b Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Wed, 6 Jan 2021 15:05:54 +0800 Subject: [PATCH] Fix continuous calls for recompute api --- mindspore/ccsrc/frontend/optimizer/recompute.cc | 4 +--- mindspore/nn/cell.py | 10 +++++++--- mindspore/ops/primitive.py | 2 ++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/recompute.cc b/mindspore/ccsrc/frontend/optimizer/recompute.cc index 68cc76816f7..d4f7f23f866 100644 --- a/mindspore/ccsrc/frontend/optimizer/recompute.cc +++ b/mindspore/ccsrc/frontend/optimizer/recompute.cc @@ -31,7 +31,6 @@ namespace opt { namespace { constexpr auto kGradientsFlag = "Gradients"; constexpr auto kAttrRecompute = "recompute"; -constexpr auto kAttrNoRecompute = "no_recompute"; bool IsBpropNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { @@ -46,8 +45,7 @@ bool WithRecomputedScope(const AnfNodePtr &node) { return false; } auto full_name_with_scope = node->fullname_with_scope(); - return full_name_with_scope.find(kAttrRecompute) == 0 && - full_name_with_scope.find(kAttrNoRecompute) == full_name_with_scope.npos; + return full_name_with_scope.find(kAttrRecompute) == 0; } bool HasRecomputeCNodeAttr(const AnfNodePtr &node) { diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 7d03cd13dbf..da7c5137eec 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -913,8 +913,11 @@ class Cell(Cell_): """Sets the name on the first time.""" if self._scope is None: self._scope = name - elif self._scope == 'recomputed': - self._scope = self._scope + "_" + name + elif self._scope == 'recompute': + if name is None: + self._scope = None + elif name != 'recompute': + self._scope = self._scope + '_' + name def _children_scope_recursive(self, parent_prefix='Default'): """Generates the scope of each layer of the network recursively.""" @@ -1102,10 +1105,11 @@ class Cell(Cell_): Args: mode (bool): Specifies whether the cell is recomputed. Default: True. """ + Validator.check_bool(mode) if mode is True: self._set_scope("recompute") else: - self._set_scope("no_recompute") + self._set_scope(None) for cell in self.cells(): cell.recompute(mode) diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 79c3fc2e2cc..d2008ba770f 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -19,6 +19,7 @@ import copy from mindspore.common.api import _wrap_func from mindspore import context from .._c_expression import Primitive_, real_run_op, prim_type +from .._checkparam import Validator from . import signature as sig @@ -213,6 +214,7 @@ class Primitive(Primitive_): Args: mode (bool): Specifies whether the primitive is recomputed. Default: True. """ + Validator.check_bool(mode) self.add_prim_attr("recompute", mode) return self