From e3d62123e53442580e8ef96ae69620d51c75bc72 Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Sat, 24 Oct 2020 21:42:21 +0800 Subject: [PATCH] Bind the fvs with single fg bprop for switch layer. --- mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index c34cc8a4825..4230da26a0a 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -130,6 +130,7 @@ void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNode if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << "."; } + std::unordered_map node_to_fg; auto tuple_graphs = input->cast(); for (size_t i = 1; i < tuple_graphs->size(); ++i) { auto graph = tuple_graphs->input(i); @@ -145,11 +146,19 @@ void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNode } // Consider direct and indirect fvs. for (auto fv : func_graph->free_variables_nodes()) { + if (node_to_fg.find(fv) != node_to_fg.end()) { + continue; + } + node_to_fg[fv] = func_graph; BackPropagateFv(fv, env); } for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) { MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " " << indirect_fv.first->ToString() << "."; + if (node_to_fg.find(indirect_fv.first) != node_to_fg.end()) { + continue; + } + node_to_fg[indirect_fv.first] = func_graph; BackPropagateFv(indirect_fv.first, env); } }