!20506 revert the workaround in pr 19949
Merge pull request !20506 from xychow/revert-workaround-in-pr-19949
This commit is contained in:
commit
4207e2851c
|
@ -37,7 +37,7 @@ namespace ad {
|
|||
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
|
||||
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
|
||||
|
||||
int lift_fv_before_grad = -1;
|
||||
bool lift_fv_before_grad = true;
|
||||
|
||||
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
|
||||
: primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
|
||||
|
@ -76,7 +76,7 @@ void DFunctor::Clear() {
|
|||
|
||||
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
|
||||
MS_EXCEPTION_IF_NULL(fv);
|
||||
if (lift_fv_before_grad == 1) {
|
||||
if (lift_fv_before_grad) {
|
||||
MS_EXCEPTION_IF_NULL(fv->func_graph());
|
||||
MS_LOG(EXCEPTION) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv:"
|
||||
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
||||
|
@ -446,7 +446,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
|
|||
// Add grads wrt fv.
|
||||
const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
|
||||
if (!is_top_ && free_variables_nodes.size() != 0) {
|
||||
if (lift_fv_before_grad == 1) {
|
||||
if (lift_fv_before_grad) {
|
||||
MS_LOG(EXCEPTION) << "direct fv size is: " << free_variables_nodes.size() << " in " << primal_graph_->ToString()
|
||||
<< ".";
|
||||
}
|
||||
|
@ -475,7 +475,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
|
|||
}
|
||||
|
||||
AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
|
||||
if (lift_fv_before_grad == 1) {
|
||||
if (lift_fv_before_grad) {
|
||||
MS_LOG(EXCEPTION) << "Lift free variable case: AttachIndirectFvDoutToTape backprop indirect fv "
|
||||
<< grad_fv->ToString() << " " << primal_graph_->ToString() << ".";
|
||||
}
|
||||
|
@ -517,7 +517,7 @@ void DFunctor::MapMorphism() {
|
|||
|
||||
// Set output for tape closure.
|
||||
AnfNodePtr grad_fv;
|
||||
if (lift_fv_before_grad == 1) {
|
||||
if (lift_fv_before_grad) {
|
||||
grad_fv = AttachFvDoutToTape(NewValueNode(newenv));
|
||||
} else {
|
||||
grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
|
||||
|
|
|
@ -44,7 +44,7 @@ class DFunctor;
|
|||
using DFunctorPtr = std::shared_ptr<DFunctor>;
|
||||
|
||||
// Flag to control if fv should be lifted before grad. If this lift_fv feature is mature, then this flag can be removed.
|
||||
extern int lift_fv_before_grad;
|
||||
extern bool lift_fv_before_grad;
|
||||
|
||||
// D Functor's rules to map closure object and morphisms.
|
||||
class DFunctor : public std::enable_shared_from_this<DFunctor> {
|
||||
|
|
|
@ -60,16 +60,6 @@ FuncGraphPtr LiftFv(const pipeline::ResourceBasePtr &resource, const FuncGraphPt
|
|||
}
|
||||
return opt_fg;
|
||||
}
|
||||
|
||||
bool NeedLiftFv(const FuncGraphPtr &func_graph) {
|
||||
size_t switch_num = 0;
|
||||
for (auto node : func_graph->manager()->all_nodes()) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
|
||||
++switch_num;
|
||||
}
|
||||
}
|
||||
return switch_num > 1;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) {
|
||||
|
@ -84,11 +74,8 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
|
|||
manager_ptr->AddFuncGraph(func_graph);
|
||||
|
||||
FuncGraphPtr grad_fg = func_graph;
|
||||
// Only calculate lift_fv_before_grad once.
|
||||
if (lift_fv_before_grad == -1) {
|
||||
lift_fv_before_grad = (common::GetEnv("ENV_DONT_LIFT_FV_BEFORE_GRAD") != "1") && NeedLiftFv(func_graph) ? 1 : 0;
|
||||
}
|
||||
if (lift_fv_before_grad == 1 && func_graph->func_graphs_used().size() != 0) {
|
||||
lift_fv_before_grad = (common::GetEnv("ENV_DONT_LIFT_FV_BEFORE_GRAD") != "1");
|
||||
if (lift_fv_before_grad && func_graph->func_graphs_used().size() != 0) {
|
||||
grad_fg = LiftFv(resources, func_graph);
|
||||
}
|
||||
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
|
||||
|
|
Loading…
Reference in New Issue