!6001 raise exception when use HookBackward in graph mode

Merge pull request !6001 from zhangbuxue/check_is_graph_mode_for_HookBackward
This commit is contained in:
mindspore-ci-bot 2020-09-10 21:31:41 +08:00 committed by Gitee
commit 77e05e32a4
6 changed files with 8 additions and 5 deletions

View File

@ -120,6 +120,9 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
FuncGraphPtr bprop_fg = nullptr;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
if (MsContext::GetInstance()->get_param<int>(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) {
MS_LOG(EXCEPTION) << "HookBackward is not supported in graph mode.";
}
bprop_fg = BpropCut(value_node, resources);
} else {
auto iter = bprop_registry_.find(prim);

View File

@ -1198,7 +1198,7 @@ void ClearPrimEvaluatorMap() {
GetUniformPrimitiveToImplMap().clear();
}
bool IsInWhiteList(const PrimitivePtr primitive) {
bool IsInWhiteList(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto iter = GetPrimitiveToEvalImplMap().find(primitive);

View File

@ -111,7 +111,7 @@ class MixedPrecisionCastEvaluator : public Evaluator {
PrimitivePtr prim_;
};
bool IsInWhiteList(PrimitivePtr primitive);
bool IsInWhiteList(const PrimitivePtr &primitive);
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
using ValuePtrList = std::vector<ValuePtr>;

View File

@ -47,7 +47,7 @@ void ValidateOperation(const AnfNodePtr &node) {
}
// Primitive must in whitelist
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
auto prim = GetValueNode<PrimitivePtr>(node);
if (abstract::IsInWhiteList(prim)) {
return;
}

View File

@ -1257,7 +1257,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
if (need_replace_param) {
auto params = newfg->parameters();
auto manager = Manage({newfg}, false);
for (size_t i = 0; i < params.size(); i++) {
for (size_t i = 0; i < args.size(); i++) {
ValuePtr value = PyAttrValue(args[i]);
auto v_node = NewValueNode(value);
manager->Replace(params[i], v_node);

View File

@ -304,7 +304,7 @@ class Cell(Cell_):
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
for i, cell in enumerate(self.cells()):
cell.set_grad(origin_grad[i])
self._already_run = True
self._already_run = True
return output
def _add_attr(self, name, value):