From 74d500a75634f21c4ada04148915d417b64df798 Mon Sep 17 00:00:00 2001
From: zhousiyi
Date: Thu, 25 Mar 2021 09:32:44 +0000
Subject: [PATCH] If an fv cnode is not mapped but belongs to primal_graph_,
 then don't propagate this fv to the caller; it should be mapped instead

MapMorphism starts from the output node of the primal graph and maps each
CNode's inputs from left to right, so a CNode that is used as a free variable
by an inner func_graph can reach BackPropagateFv through its call site before
it has been mapped itself. Previously such a node was treated as an indirect
fv of a parent graph and received a k hole; now, if the fv belongs to
primal_graph_, it is mapped on demand and looked up in anfnode_to_adjoin_
again.
---
 .../ccsrc/frontend/optimizer/ad/dfunctor.cc | 39 +++++++++++++-------
 tests/ut/python/optimizer/test_auto_grad.py | 28 ++++++++++++++
 2 files changed, 54 insertions(+), 13 deletions(-)

diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
index ca825519f97..6e67f3f9a8b 100644
--- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
+++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
@@ -91,21 +91,34 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
   if (fv_adjoint == anfnode_to_adjoin_.end()) {
     MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
                   << " " << fv->ToString() << ".";
-    fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
-    if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
-      MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
-                    << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
-      auto parent_adjoint = FindAdjoint(fv);
-      AdjointPtr adjoint = nullptr;
-      if (parent_adjoint != nullptr) {
-        adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
-      } else {
-        MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
-                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
-        adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
+
+    if (fv->func_graph() == primal_graph_) {
+      // If this fv is not mapped by MapMorphism because of cnode order, then map it now.
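+      // MapMorphism starts from the output node and maps inputs left to right, so a CNode
+      // used as a free variable by an inner func_graph can reach here before being mapped.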
+      (void)MapMorphism(fv);
+      fv_adjoint = anfnode_to_adjoin_.find(fv);
+      if (fv_adjoint == anfnode_to_adjoin_.end()) {
+        MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
+                          << fv->ToString() << ".";
       }
-      anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
+    } else {
+      fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
+      if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
+        MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
+                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
+        auto parent_adjoint = FindAdjoint(fv);
+        AdjointPtr adjoint = nullptr;
+        if (parent_adjoint != nullptr) {
+          adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
+        } else {
+          MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
+                        << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
+          adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
+        }
+        anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
+        fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
+      }
     }
   }
   auto fv_node = fv_adjoint->second->k();
diff --git a/tests/ut/python/optimizer/test_auto_grad.py b/tests/ut/python/optimizer/test_auto_grad.py
index d9d3f5b176e..5c84838478a 100644
--- a/tests/ut/python/optimizer/test_auto_grad.py
+++ b/tests/ut/python/optimizer/test_auto_grad.py
@@ -110,3 +110,31 @@ def test_second_grad_with_j_primitive():
     second_grad = SinGradSec(first_grad)
     x = Tensor(np.array([1.0], dtype=np.float32))
     second_grad(x)
+
+
+# A CNode used as an FV is mapped by MapMorphism only after its call-site CNode has been mapped.
+def test_ad_fv_cnode_order():
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        # CNode xay has not been mapped yet when CNode second_level() reaches MapMorphism and
+        # BackPropagateFv, because MapMorphism starts from the output node and maps inputs left to right.
+        def construct(self, x, y):
+            def first_level():
+                xay = x + y
+
+                def second_level():
+                    return xay
+
+                return second_level() + xay
+            return first_level()
+
+    input_x = Tensor(np.array([1.0], dtype=np.float32))
+    input_y = Tensor(np.array([2.0], dtype=np.float32))
+
+    net = Net()
+    net.add_flags_recursive(defer_inline=True)
+    grad_net = grad_all(net)
+    grad_net(input_x, input_y)
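
Note for reviewers: the new regression test relies on helpers defined earlier
in test_auto_grad.py that this hunk does not show. The sketch below is a
minimal self-contained version of the same scenario; the imports and the
grad_all helper are assumptions based on common MindSpore test conventions
(grad_all is taken here to be C.GradOperation(get_all=True)), not part of the
patch itself.

    # Self-contained sketch of the regression test above. The imports and the
    # grad_all definition are assumed, since the test file header is not shown.
    import numpy as np

    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import composite as C

    grad_all = C.GradOperation(get_all=True)  # assumed helper definition


    def test_ad_fv_cnode_order():
        context.set_context(mode=context.GRAPH_MODE)

        class Net(nn.Cell):
            # xay is a free variable of second_level(); the call site
            # second_level() is mapped before xay itself, which used to
            # mis-handle the fv in BackPropagateFv.
            def construct(self, x, y):
                def first_level():
                    xay = x + y

                    def second_level():
                        return xay

                    return second_level() + xay
                return first_level()

        net = Net()
        net.add_flags_recursive(defer_inline=True)
        grad_net = grad_all(net)
        grad_net(Tensor(np.array([1.0], dtype=np.float32)),
                 Tensor(np.array([2.0], dtype=np.float32)))

The defer_inline flag is presumably what keeps first_level and second_level as
separate func_graphs during compilation, so that xay really crosses a graph
boundary as a free variable instead of being inlined away before AD runs.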