forked from mindspore-Ecosystem/mindspore

commit 5d0490909d

!14097 map a fv cnode if it is not mapped and belongs to primal_graph_

From: @xychow
Reviewed-by: @ginfung, @zh_qh
Signed-off-by: @zh_qh
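The essence of the change below: when BackPropagateFv fails to find an adjoint for a free-variable (FV) CNode, it now distinguishes FVs that belong to primal_graph_ (which can simply be mapped on demand via MapMorphism) from genuinely indirect FVs (which keep the old fallback of synthesizing an adjoint, possibly as a "k hole"). The following is a minimal Python sketch of that lookup order, assuming plain dicts and callables as stand-ins for anfnode_to_adjoin_, anfnode_to_adjoin_indirect_fv_, MapMorphism, and FindAdjoint; none of this is MindSpore API.

    def back_propagate_fv_lookup(fv, primal_graph, adjoin, indirect_adjoin,
                                 map_morphism, find_adjoint):
        """Sketch of the post-commit control flow; all arguments are stand-ins."""
        if fv in adjoin:
            return adjoin[fv]
        if fv.func_graph is primal_graph:
            # New branch: the fv was skipped by MapMorphism only because of
            # cnode order, so map it now and expect the adjoint to appear.
            map_morphism(fv)
            if fv not in adjoin:
                raise RuntimeError("Can not find adjoint in anfnode_to_adjoin_")
            return adjoin[fv]
        # Old fallback for indirect fvs: reuse the parent's adjoint if one
        # exists, otherwise record a placeholder ("k hole") adjoint.
        if fv not in indirect_adjoin:
            parent = find_adjoint(fv)
            indirect_adjoin[fv] = ("adjoint", parent) if parent else ("k hole", None)
        return indirect_adjoin[fv]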
@@ -91,21 +91,32 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
   if (fv_adjoint == anfnode_to_adjoin_.end()) {
     MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
                   << " " << fv->ToString() << ".";
-    fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
-    if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
-      MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
-                    << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
-      auto parent_adjoint = FindAdjoint(fv);
-      AdjointPtr adjoint = nullptr;
-      if (parent_adjoint != nullptr) {
-        adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
-      } else {
-        MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
-                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
-        adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
-      }
-      anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
-      fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
+    if (fv->func_graph() == primal_graph_) {
+      // If this fv is not mapped by MapMorphism because of cnode order, then map it now.
+      (void)MapMorphism(fv);
+      fv_adjoint = anfnode_to_adjoin_.find(fv);
+      if (fv_adjoint == anfnode_to_adjoin_.end()) {
+        MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
+                          << fv->ToString() << ".";
+      }
+    } else {
+      fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
+      if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
+        MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
+                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
+        auto parent_adjoint = FindAdjoint(fv);
+        AdjointPtr adjoint = nullptr;
+        if (parent_adjoint != nullptr) {
+          adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
+        } else {
+          MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
+                        << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
+          adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
+        }
+        anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
+        fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
+      }
     }
   }
   auto fv_node = fv_adjoint->second->k();
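To see why the new test trips the old code path: xay is a free variable of second_level, i.e. it is used there but defined one level up. In plain Python terms (no MindSpore involved), mirroring the construct below:

    def first_level(x, y):
        xay = x + y              # xay is defined here ...

        def second_level():
            return xay           # ... but used here, so it is an FV of second_level

        # The call-site CNode second_level() sits to the left of the xay CNode
        # in this expression; MapMorphism walks from the output node left to
        # right, so BackPropagateFv reaches the FV use of xay before xay
        # itself has been mapped.
        return second_level() + xay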
@@ -110,3 +110,31 @@ def test_second_grad_with_j_primitive():
     second_grad = SinGradSec(first_grad)
     x = Tensor(np.array([1.0], dtype=np.float32))
     second_grad(x)
+
+
+# A CNode used as an FV gets its MapMorphism after the MapMorphism of its call-site CNode.
+def test_ad_fv_cnode_order():
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        # CNode xay has not yet been mapped by MapMorphism when CNode second_level() is mapped
+        # and BackPropagateFv runs, as MapMorphism starts from the output node and works left to right.
+        def construct(self, x, y):
+            def first_level():
+                xay = x + y
+
+                def second_level():
+                    return xay
+
+                return second_level() + xay
+            return first_level()
+
+    input_x = Tensor(np.array([1.0], dtype=np.float32))
+    input_y = Tensor(np.array([2.0], dtype=np.float32))
+
+    net = Net()
+    net.add_flags_recursive(defer_inline=True)
+    grad_net = grad_all(net)
+    grad_net(input_x, input_y)
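Note that the new test makes no value assertion; it is a regression test that passes as long as graph compilation and differentiation no longer hit the adjoint lookup failure. The gradients are still easy to hand-check: the output is second_level() + xay = (x + y) + (x + y), so the gradient with respect to each input is 2. If one wanted to assert that (an addition of mine, assuming grad_all here is the usual GradOperation with get_all=True used in these tests), it could look like:

    dx, dy = grad_net(input_x, input_y)
    assert np.allclose(dx.asnumpy(), 2.0)
    assert np.allclose(dy.asnumpy(), 2.0)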