If a fv cnode is not mapped but belong to primal_graph_, then don't propagate this fv to caller, it should mapped instead

This commit is contained in:
zhousiyi 2021-03-25 09:32:44 +00:00
parent 5b59277158
commit 74d500a756
2 changed files with 52 additions and 13 deletions

View File

@ -91,21 +91,32 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
if (fv_adjoint == anfnode_to_adjoin_.end()) {
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
<< " " << fv->ToString() << ".";
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
auto parent_adjoint = FindAdjoint(fv);
AdjointPtr adjoint = nullptr;
if (parent_adjoint != nullptr) {
adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
} else {
MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
if (fv->func_graph() == primal_graph_) {
// If this fv is not mapped by MapMorphism because of cnode order, then map it now.
(void)MapMorphism(fv);
fv_adjoint = anfnode_to_adjoin_.find(fv);
if (fv_adjoint == anfnode_to_adjoin_.end()) {
MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
<< fv->ToString() << ".";
}
anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
} else {
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
auto parent_adjoint = FindAdjoint(fv);
AdjointPtr adjoint = nullptr;
if (parent_adjoint != nullptr) {
adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
} else {
MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
}
anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
}
}
}
auto fv_node = fv_adjoint->second->k();

View File

@ -110,3 +110,31 @@ def test_second_grad_with_j_primitive():
second_grad = SinGradSec(first_grad)
x = Tensor(np.array([1.0], dtype=np.float32))
second_grad(x)
# A CNode being used as FV is MapMorphism after MapMorphism of call-site CNode;
def test_ad_fv_cnode_order():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
# cnode xay is not being MapMorphism when cnode second_level() is being MapMorphism and
# BackPropagateFv as MapMorphism is started from output node and from left to right order.
def construct(self, x, y):
def first_level():
xay = x + y
def second_level():
return xay
return second_level() + xay
return first_level()
input_x = Tensor(np.array([1.0], dtype=np.float32))
input_y = Tensor(np.array([2.0], dtype=np.float32))
net = Net()
net.add_flags_recursive(defer_inline=True)
grad_net = grad_all(net)
grad_net(input_x, input_y)