!7731 Bind the fvs with single fg bprop for switch layer.

Merge pull request !7731 from 张清华/grad_opt
This commit is contained in:
mindspore-ci-bot 2020-10-26 14:49:53 +08:00 committed by Gitee
commit 9edb3abdfd
1 changed files with 9 additions and 0 deletions

View File

@ -130,6 +130,7 @@ void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNode
if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << "."; MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
} }
std::unordered_map<AnfNodePtr, FuncGraphPtr> node_to_fg;
auto tuple_graphs = input->cast<CNodePtr>(); auto tuple_graphs = input->cast<CNodePtr>();
for (size_t i = 1; i < tuple_graphs->size(); ++i) { for (size_t i = 1; i < tuple_graphs->size(); ++i) {
auto graph = tuple_graphs->input(i); auto graph = tuple_graphs->input(i);
@ -145,11 +146,19 @@ void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNode
} }
// Consider direct and indirect fvs. // Consider direct and indirect fvs.
for (auto fv : func_graph->free_variables_nodes()) { for (auto fv : func_graph->free_variables_nodes()) {
if (node_to_fg.find(fv) != node_to_fg.end()) {
continue;
}
node_to_fg[fv] = func_graph;
BackPropagateFv(fv, env); BackPropagateFv(fv, env);
} }
for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) { for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " " MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
<< indirect_fv.first->ToString() << "."; << indirect_fv.first->ToString() << ".";
if (node_to_fg.find(indirect_fv.first) != node_to_fg.end()) {
continue;
}
node_to_fg[indirect_fv.first] = func_graph;
BackPropagateFv(indirect_fv.first, env); BackPropagateFv(indirect_fv.first, env);
} }
} }