!14097 map a fv cnode if it is not mapped and belong to primal_graph_

From: @xychow
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-03-26 09:20:42 +08:00 committed by Gitee
commit 5d0490909d
2 changed files with 52 additions and 13 deletions

View File

@ -91,6 +91,16 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
if (fv_adjoint == anfnode_to_adjoin_.end()) { if (fv_adjoint == anfnode_to_adjoin_.end()) {
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
<< " " << fv->ToString() << "."; << " " << fv->ToString() << ".";
if (fv->func_graph() == primal_graph_) {
// If this fv is not mapped by MapMorphism because of cnode order, then map it now.
(void)MapMorphism(fv);
fv_adjoint = anfnode_to_adjoin_.find(fv);
if (fv_adjoint == anfnode_to_adjoin_.end()) {
MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
<< fv->ToString() << ".";
}
} else {
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) { if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv " MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
@ -108,6 +118,7 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
} }
} }
}
auto fv_node = fv_adjoint->second->k(); auto fv_node = fv_adjoint->second->k();
auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node); auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
CNodePtr embed_node, default_val_node; CNodePtr embed_node, default_val_node;

View File

@ -110,3 +110,31 @@ def test_second_grad_with_j_primitive():
second_grad = SinGradSec(first_grad) second_grad = SinGradSec(first_grad)
x = Tensor(np.array([1.0], dtype=np.float32)) x = Tensor(np.array([1.0], dtype=np.float32))
second_grad(x) second_grad(x)
# A CNode being used as FV is MapMorphism after MapMorphism of call-site CNode;
def test_ad_fv_cnode_order():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
# cnode xay is not being MapMorphism when cnode second_level() is being MapMorphism and
# BackPropagateFv as MapMorphism is started from output node and from left to right order.
def construct(self, x, y):
def first_level():
xay = x + y
def second_level():
return xay
return second_level() + xay
return first_level()
input_x = Tensor(np.array([1.0], dtype=np.float32))
input_y = Tensor(np.array([2.0], dtype=np.float32))
net = Net()
net.add_flags_recursive(defer_inline=True)
grad_net = grad_all(net)
grad_net(input_x, input_y)