forked from mindspore-Ecosystem/mindspore
bprop of kPrimHookBackward don't cache
prim name and attr can't unique a kPrimHookBackward
This commit is contained in:
parent
c8f9be7d72
commit
965cd35890
|
@ -92,7 +92,7 @@ void PrimBpropOptGraphLevel2Info::AnalysisNodeUsingInfo(
|
|||
for (auto &user_info : users_info) {
|
||||
auto user_node = user_info.first;
|
||||
arg_info.using_flg_ = true;
|
||||
MS_LOG(WARNING) << "param:" << param->ToString() << " used by node:" << user_node->ToString();
|
||||
MS_LOG(DEBUG) << "param:" << param->ToString() << " used by node:" << user_node->ToString();
|
||||
if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
|
||||
MS_LOG(EXCEPTION) << "tuple param:" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
|
||||
<< " unexpect used by node:" << user_node->ToString();
|
||||
|
@ -184,8 +184,8 @@ FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bpro
|
|||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
|
||||
MS_LOG(DEBUG) << "Hash of prim " << prim->ToString() << " is:" << prim->hash();
|
||||
|
||||
// kPrimBpropCut
|
||||
if (IsPrimitiveEquals(prim, prim::kPrimBpropCut)) {
|
||||
// kPrimHookBackward
|
||||
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward)) {
|
||||
return GenSpecOptBprop(bprop_fg, op_args, out, prim);
|
||||
}
|
||||
|
||||
|
@ -338,10 +338,5 @@ abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr
|
|||
return new_abs_list;
|
||||
}
|
||||
|
||||
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
|
||||
const ValuePtr &out) {
|
||||
return PrimBpropOptimizer::GetPrimBpropOptimizerInst().OptimizeBPropFuncGraph(bprop_fg, c_node, op_args, out);
|
||||
}
|
||||
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
|
@ -149,13 +149,6 @@ private:
|
|||
|
||||
};
|
||||
|
||||
// bprop_fg has the signature:
|
||||
// (sens_input1, sens_input2,...)bprop_fg(input1, input2, ..., out, d_out)
|
||||
// c_node contains the prim(input 0) and the input parameters of that prim;
|
||||
// op_args contains the arguments list of each input parameters, it maybe tensor or tuple
|
||||
// out contains the out of c_node;
|
||||
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
|
||||
const ValuePtr &out);
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue