diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index bbff6fbec3f..8144fd58906 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -37,7 +37,7 @@ namespace ad { std::unordered_map DFunctor::func_graph_to_functor_; std::unordered_map DFunctor::anfnode_to_adjoin_definition_; -int lift_fv_before_grad = -1; +bool lift_fv_before_grad = true; DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources) : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { @@ -76,7 +76,7 @@ void DFunctor::Clear() { void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { MS_EXCEPTION_IF_NULL(fv); - if (lift_fv_before_grad == 1) { + if (lift_fv_before_grad) { MS_EXCEPTION_IF_NULL(fv->func_graph()); MS_LOG(EXCEPTION) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv:" << fv->func_graph()->ToString() << " " << fv->ToString() << "."; @@ -446,7 +446,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) { // Add grads wrt fv. const auto &free_variables_nodes = primal_graph_->free_variables_nodes(); if (!is_top_ && free_variables_nodes.size() != 0) { - if (lift_fv_before_grad == 1) { + if (lift_fv_before_grad) { MS_LOG(EXCEPTION) << "direct fv size is: " << free_variables_nodes.size() << " in " << primal_graph_->ToString() << "."; } @@ -475,7 +475,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) { } AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) { - if (lift_fv_before_grad == 1) { + if (lift_fv_before_grad) { MS_LOG(EXCEPTION) << "Lift free variable case: AttachIndirectFvDoutToTape backprop indirect fv " << grad_fv->ToString() << " " << primal_graph_->ToString() << "."; } @@ -517,7 +517,7 @@ void DFunctor::MapMorphism() { // Set output for tape closure. AnfNodePtr grad_fv; - if (lift_fv_before_grad == 1) { + if (lift_fv_before_grad) { grad_fv = AttachFvDoutToTape(NewValueNode(newenv)); } else { grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv))); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index 1d0a111271e..7aca3e21b0c 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -44,7 +44,7 @@ class DFunctor; using DFunctorPtr = std::shared_ptr; // Flag to control if fv should be lifted before grad. If this lift_fv feature is mature, then this flag can be removed. -extern int lift_fv_before_grad; +extern bool lift_fv_before_grad; // D Functor's rules to map closure object and morphisms. class DFunctor : public std::enable_shared_from_this { diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc index f6102bf0167..7adb1f5800f 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc @@ -60,16 +60,6 @@ FuncGraphPtr LiftFv(const pipeline::ResourceBasePtr &resource, const FuncGraphPt } return opt_fg; } - -bool NeedLiftFv(const FuncGraphPtr &func_graph) { - size_t switch_num = 0; - for (auto node : func_graph->manager()->all_nodes()) { - if (IsPrimitiveCNode(node, prim::kPrimSwitch)) { - ++switch_num; - } - } - return switch_num > 1; -} } // namespace FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) { @@ -84,11 +74,8 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt manager_ptr->AddFuncGraph(func_graph); FuncGraphPtr grad_fg = func_graph; - // Only calculate lift_fv_before_grad once. - if (lift_fv_before_grad == -1) { - lift_fv_before_grad = (common::GetEnv("ENV_DONT_LIFT_FV_BEFORE_GRAD") != "1") && NeedLiftFv(func_graph) ? 1 : 0; - } - if (lift_fv_before_grad == 1 && func_graph->func_graphs_used().size() != 0) { + lift_fv_before_grad = (common::GetEnv("ENV_DONT_LIFT_FV_BEFORE_GRAD") != "1"); + if (lift_fv_before_grad && func_graph->func_graphs_used().size() != 0) { grad_fg = LiftFv(resources, func_graph); } auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {