forked from mindspore-Ecosystem/mindspore
!14097 map a fv cnode if it is not mapped and belong to primal_graph_
From: @xychow Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
5d0490909d
|
@ -91,6 +91,16 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
|
||||||
if (fv_adjoint == anfnode_to_adjoin_.end()) {
|
if (fv_adjoint == anfnode_to_adjoin_.end()) {
|
||||||
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
|
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
|
||||||
<< " " << fv->ToString() << ".";
|
<< " " << fv->ToString() << ".";
|
||||||
|
|
||||||
|
if (fv->func_graph() == primal_graph_) {
|
||||||
|
// If this fv is not mapped by MapMorphism because of cnode order, then map it now.
|
||||||
|
(void)MapMorphism(fv);
|
||||||
|
fv_adjoint = anfnode_to_adjoin_.find(fv);
|
||||||
|
if (fv_adjoint == anfnode_to_adjoin_.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
|
||||||
|
<< fv->ToString() << ".";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
||||||
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
|
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
|
||||||
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
|
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
|
||||||
|
@ -108,6 +118,7 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
|
||||||
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
auto fv_node = fv_adjoint->second->k();
|
auto fv_node = fv_adjoint->second->k();
|
||||||
auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
|
auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
|
||||||
CNodePtr embed_node, default_val_node;
|
CNodePtr embed_node, default_val_node;
|
||||||
|
|
|
@ -110,3 +110,31 @@ def test_second_grad_with_j_primitive():
|
||||||
second_grad = SinGradSec(first_grad)
|
second_grad = SinGradSec(first_grad)
|
||||||
x = Tensor(np.array([1.0], dtype=np.float32))
|
x = Tensor(np.array([1.0], dtype=np.float32))
|
||||||
second_grad(x)
|
second_grad(x)
|
||||||
|
|
||||||
|
|
||||||
|
# A CNode being used as FV is MapMorphism after MapMorphism of call-site CNode;
|
||||||
|
def test_ad_fv_cnode_order():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
|
||||||
|
# cnode xay is not being MapMorphism when cnode second_level() is being MapMorphism and
|
||||||
|
# BackPropagateFv as MapMorphism is started from output node and from left to right order.
|
||||||
|
def construct(self, x, y):
|
||||||
|
def first_level():
|
||||||
|
xay = x + y
|
||||||
|
|
||||||
|
def second_level():
|
||||||
|
return xay
|
||||||
|
|
||||||
|
return second_level() + xay
|
||||||
|
return first_level()
|
||||||
|
|
||||||
|
input_x = Tensor(np.array([1.0], dtype=np.float32))
|
||||||
|
input_y = Tensor(np.array([2.0], dtype=np.float32))
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
net.add_flags_recursive(defer_inline=True)
|
||||||
|
grad_net = grad_all(net)
|
||||||
|
grad_net(input_x, input_y)
|
||||||
|
|
Loading…
Reference in New Issue