!20506 revert the workaround in pr 19949

Merge pull request !20506 from xychow/revert-workaround-in-pr-19949
This commit is contained in:
i-robot 2021-07-20 12:38:15 +00:00 committed by Gitee
commit 4207e2851c
3 changed files with 8 additions and 21 deletions

View File

@ -37,7 +37,7 @@ namespace ad {
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
int lift_fv_before_grad = -1;
bool lift_fv_before_grad = true;
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
: primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
@ -76,7 +76,7 @@ void DFunctor::Clear() {
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
MS_EXCEPTION_IF_NULL(fv);
if (lift_fv_before_grad == 1) {
if (lift_fv_before_grad) {
MS_EXCEPTION_IF_NULL(fv->func_graph());
MS_LOG(EXCEPTION) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv:"
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
@ -446,7 +446,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
// Add grads wrt fv.
const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
if (!is_top_ && free_variables_nodes.size() != 0) {
if (lift_fv_before_grad == 1) {
if (lift_fv_before_grad) {
MS_LOG(EXCEPTION) << "direct fv size is: " << free_variables_nodes.size() << " in " << primal_graph_->ToString()
<< ".";
}
@ -475,7 +475,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
}
AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
if (lift_fv_before_grad == 1) {
if (lift_fv_before_grad) {
MS_LOG(EXCEPTION) << "Lift free variable case: AttachIndirectFvDoutToTape backprop indirect fv "
<< grad_fv->ToString() << " " << primal_graph_->ToString() << ".";
}
@ -517,7 +517,7 @@ void DFunctor::MapMorphism() {
// Set output for tape closure.
AnfNodePtr grad_fv;
if (lift_fv_before_grad == 1) {
if (lift_fv_before_grad) {
grad_fv = AttachFvDoutToTape(NewValueNode(newenv));
} else {
grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));

View File

@ -44,7 +44,7 @@ class DFunctor;
using DFunctorPtr = std::shared_ptr<DFunctor>;
// Flag to control if fv should be lifted before grad. If this lift_fv feature is mature, then this flag can be removed.
extern int lift_fv_before_grad;
extern bool lift_fv_before_grad;
// D Functor's rules to map closure object and morphisms.
class DFunctor : public std::enable_shared_from_this<DFunctor> {

View File

@ -60,16 +60,6 @@ FuncGraphPtr LiftFv(const pipeline::ResourceBasePtr &resource, const FuncGraphPt
}
return opt_fg;
}
bool NeedLiftFv(const FuncGraphPtr &func_graph) {
size_t switch_num = 0;
for (auto node : func_graph->manager()->all_nodes()) {
if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
++switch_num;
}
}
return switch_num > 1;
}
} // namespace
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) {
@ -84,11 +74,8 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
manager_ptr->AddFuncGraph(func_graph);
FuncGraphPtr grad_fg = func_graph;
// Only calculate lift_fv_before_grad once.
if (lift_fv_before_grad == -1) {
lift_fv_before_grad = (common::GetEnv("ENV_DONT_LIFT_FV_BEFORE_GRAD") != "1") && NeedLiftFv(func_graph) ? 1 : 0;
}
if (lift_fv_before_grad == 1 && func_graph->func_graphs_used().size() != 0) {
lift_fv_before_grad = (common::GetEnv("ENV_DONT_LIFT_FV_BEFORE_GRAD") != "1");
if (lift_fv_before_grad && func_graph->func_graphs_used().size() != 0) {
grad_fg = LiftFv(resources, func_graph);
}
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {