If a fv cnode is not mapped but belong to primal_graph_, then don't propagate this fv to caller, it should mapped instead
This commit is contained in:
parent
5b59277158
commit
74d500a756
|
@ -91,21 +91,32 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
|
|||
if (fv_adjoint == anfnode_to_adjoin_.end()) {
|
||||
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
|
||||
<< " " << fv->ToString() << ".";
|
||||
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
||||
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
|
||||
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
|
||||
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
||||
auto parent_adjoint = FindAdjoint(fv);
|
||||
AdjointPtr adjoint = nullptr;
|
||||
if (parent_adjoint != nullptr) {
|
||||
adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
|
||||
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
||||
adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
|
||||
|
||||
if (fv->func_graph() == primal_graph_) {
|
||||
// If this fv is not mapped by MapMorphism because of cnode order, then map it now.
|
||||
(void)MapMorphism(fv);
|
||||
fv_adjoint = anfnode_to_adjoin_.find(fv);
|
||||
if (fv_adjoint == anfnode_to_adjoin_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
|
||||
<< fv->ToString() << ".";
|
||||
}
|
||||
anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
|
||||
} else {
|
||||
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
||||
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
|
||||
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
|
||||
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
||||
auto parent_adjoint = FindAdjoint(fv);
|
||||
AdjointPtr adjoint = nullptr;
|
||||
if (parent_adjoint != nullptr) {
|
||||
adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
|
||||
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
||||
adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
|
||||
}
|
||||
anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
|
||||
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto fv_node = fv_adjoint->second->k();
|
||||
|
|
|
@ -110,3 +110,31 @@ def test_second_grad_with_j_primitive():
|
|||
second_grad = SinGradSec(first_grad)
|
||||
x = Tensor(np.array([1.0], dtype=np.float32))
|
||||
second_grad(x)
|
||||
|
||||
|
||||
# A CNode being used as FV is MapMorphism after MapMorphism of call-site CNode;
|
||||
def test_ad_fv_cnode_order():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
# cnode xay is not being MapMorphism when cnode second_level() is being MapMorphism and
|
||||
# BackPropagateFv as MapMorphism is started from output node and from left to right order.
|
||||
def construct(self, x, y):
|
||||
def first_level():
|
||||
xay = x + y
|
||||
|
||||
def second_level():
|
||||
return xay
|
||||
|
||||
return second_level() + xay
|
||||
return first_level()
|
||||
|
||||
input_x = Tensor(np.array([1.0], dtype=np.float32))
|
||||
input_y = Tensor(np.array([2.0], dtype=np.float32))
|
||||
|
||||
net = Net()
|
||||
net.add_flags_recursive(defer_inline=True)
|
||||
grad_net = grad_all(net)
|
||||
grad_net(input_x, input_y)
|
||||
|
|
Loading…
Reference in New Issue