diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index 028a972433..a1a28bf57a 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -120,6 +120,9 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R FuncGraphPtr bprop_fg = nullptr; if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { + if (MsContext::GetInstance()->get_param(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) { + MS_LOG(EXCEPTION) << "HookBackward is not supported in graph mode."; + } bprop_fg = BpropCut(value_node, resources); } else { auto iter = bprop_registry_.find(prim); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index ba767052ce..1173fde886 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -1198,7 +1198,7 @@ void ClearPrimEvaluatorMap() { GetUniformPrimitiveToImplMap().clear(); } -bool IsInWhiteList(const PrimitivePtr primitive) { +bool IsInWhiteList(const PrimitivePtr &primitive) { MS_EXCEPTION_IF_NULL(primitive); auto iter = GetPrimitiveToEvalImplMap().find(primitive); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h index 477d9ba861..1dd8468372 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h @@ -111,7 +111,7 @@ class MixedPrecisionCastEvaluator : public Evaluator { PrimitivePtr prim_; }; -bool IsInWhiteList(PrimitivePtr primitive); +bool IsInWhiteList(const PrimitivePtr &primitive); StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); using ValuePtrList = std::vector; diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index 6046666add..ad2b127fdd 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -47,7 +47,7 @@ void ValidateOperation(const AnfNodePtr &node) { } // Primitive must in whitelist - PrimitivePtr prim = GetValueNode(node); + auto prim = GetValueNode(node); if (abstract::IsInWhiteList(prim)) { return; } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 0a0ac388dc..199a083fcc 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1257,7 +1257,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje if (need_replace_param) { auto params = newfg->parameters(); auto manager = Manage({newfg}, false); - for (size_t i = 0; i < params.size(); i++) { + for (size_t i = 0; i < args.size(); i++) { ValuePtr value = PyAttrValue(args[i]); auto v_node = NewValueNode(value); manager->Replace(params[i], v_node); diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index f2098ac3db..9c3c95636a 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -304,7 +304,7 @@ class Cell(Cell_): _pynative_exec.end_graph(self, output, *inputs, **kwargs) for i, cell in enumerate(self.cells()): cell.set_grad(origin_grad[i]) - self._already_run = True + self._already_run = True return output def _add_attr(self, name, value):