!6001 raise exception when use HookBackward in graph mode
Merge pull request !6001 from zhangbuxue/check_is_graph_mode_for_HookBackward
This commit is contained in:
commit
77e05e32a4
|
@ -120,6 +120,9 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
|||
|
||||
FuncGraphPtr bprop_fg = nullptr;
|
||||
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
|
||||
if (MsContext::GetInstance()->get_param<int>(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||
MS_LOG(EXCEPTION) << "HookBackward is not supported in graph mode.";
|
||||
}
|
||||
bprop_fg = BpropCut(value_node, resources);
|
||||
} else {
|
||||
auto iter = bprop_registry_.find(prim);
|
||||
|
|
|
@ -1198,7 +1198,7 @@ void ClearPrimEvaluatorMap() {
|
|||
GetUniformPrimitiveToImplMap().clear();
|
||||
}
|
||||
|
||||
bool IsInWhiteList(const PrimitivePtr primitive) {
|
||||
bool IsInWhiteList(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
||||
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
|
||||
|
|
|
@ -111,7 +111,7 @@ class MixedPrecisionCastEvaluator : public Evaluator {
|
|||
PrimitivePtr prim_;
|
||||
};
|
||||
|
||||
bool IsInWhiteList(PrimitivePtr primitive);
|
||||
bool IsInWhiteList(const PrimitivePtr &primitive);
|
||||
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
|
||||
|
||||
using ValuePtrList = std::vector<ValuePtr>;
|
||||
|
|
|
@ -47,7 +47,7 @@ void ValidateOperation(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
// Primitive must in whitelist
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
|
||||
auto prim = GetValueNode<PrimitivePtr>(node);
|
||||
if (abstract::IsInWhiteList(prim)) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1257,7 +1257,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
|
|||
if (need_replace_param) {
|
||||
auto params = newfg->parameters();
|
||||
auto manager = Manage({newfg}, false);
|
||||
for (size_t i = 0; i < params.size(); i++) {
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
ValuePtr value = PyAttrValue(args[i]);
|
||||
auto v_node = NewValueNode(value);
|
||||
manager->Replace(params[i], v_node);
|
||||
|
|
|
@ -304,7 +304,7 @@ class Cell(Cell_):
|
|||
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
|
||||
for i, cell in enumerate(self.cells()):
|
||||
cell.set_grad(origin_grad[i])
|
||||
self._already_run = True
|
||||
self._already_run = True
|
||||
return output
|
||||
|
||||
def _add_attr(self, name, value):
|
||||
|
|
Loading…
Reference in New Issue