!135 fix grad missing due to indirect dependent free morphism

Merge pull request !135 from penn/fix_free_morphism_error
This commit is contained in:
mindspore-ci-bot 2020-04-06 10:17:13 +08:00 committed by Gitee
commit 7a367af9c6
5 changed files with 54 additions and 17 deletions

View File

@ -185,19 +185,32 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
return node_adjoint;
}
bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
// Do not care about non-CNode
if (!node->isa<CNode>()) {
return false;
}
// Do not care about kPrimReturn
if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
return false;
}
auto &users = primal_graph_->manager()->node_users()[node];
// Do not care about isolated morphisms
if (users.empty()) {
return false;
}
// Not free if it's used by some node in primal_graph
bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
auto &user = kv.first;
return user->func_graph() == primal_graph_;
});
return !nonfree;
}
void DFunctor::MapFreeMorphism() {
// Handle cnode not attached to output, that might be refered in other functions.
for (auto &node : primal_graph_->nodes()) {
auto adjoint = FindAdjoint(node);
if (adjoint != nullptr) {
continue;
}
if (!node->isa<CNode>()) {
MS_LOG(DEBUG) << "MapFreeMorphism noncnode not mapped after MapMorphism " << node->ToString() << " "
<< node->type_name() << ".";
continue;
}
if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
if (!IsFreeMorphism(node)) {
continue;
}
MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
@ -256,9 +269,10 @@ void DFunctor::MapMorphism() {
// Set stop_gradient before MapMorphism.
BroadCastStopFlag();
// Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
MapFreeMorphism();
// Handle morphism from output.
(void)MapMorphism(primal_graph_->output());
MapFreeMorphism();
// Construct K for primal_graph_
auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
@ -298,9 +312,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
const size_t param_diff = 1;
if (bprop_graph->output()->isa<CNode>() &&
bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) {
MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
<< primal->output()->scope()->name()
<< " output must be a tuple and output number should be the same with inputs.";
// It does not matter with the final tangents, just a tip for debugging
MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope "
<< primal->output()->scope()->name()
<< " output must be a tuple and output number should be the same with inputs.";
}
resources_->manager()->AddFuncGraph(bprop_graph);

View File

@ -61,6 +61,7 @@ class DFunctor {
private:
// Map one morphism.
AdjointPtr MapMorphism(const AnfNodePtr &morph);
bool IsFreeMorphism(const AnfNodePtr &node);
// Map morphism that's not attached to output.
void MapFreeMorphism();
void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);

View File

@ -111,7 +111,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
irpass.replace_applicator_,
});
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
opt::OptPassConfig grad = opt::OptPassConfig({irpass.inline_, irpass.expand_jprim_}, true);
opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
OptPassGroupMap map_a({{"a_1", a_1},
{"a_2", a_2},

View File

@ -304,5 +304,4 @@ class MulAddWithWrongOutputNum(nn.Cell):
def test_grad_mul_add_with_wrong_output_num():
mul_add = MulAddWithWrongOutputNum()
with pytest.raises(RuntimeError):
C.grad_all(mul_add)(1, 2)
C.grad_all(mul_add)(1, 2)

View File

@ -15,6 +15,7 @@
""" test_framstruct """
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import composite as C
@ -706,3 +707,24 @@ def grad_refactor_14(a, b):
return inner1(b) + inner2(a) + inner3(a)
def test_grad_refactor_14():
assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
class IfDeferInline(nn.Cell):
def __init__(self, mul_size):
super().__init__()
self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
self.mul = P.Mul()
def construct(self, inputs):
x = self.mul(inputs, self.mul_weight)
if True:
x = x
return x
def test_grad_if_defer_inline():
""" test_grad_if_defer_inline """
network = IfDeferInline([128, 96])
network.add_flags(defer_inline=False)
inp = Tensor(np.ones([128, 96]).astype(np.float32))
grads = C.grad_all(network)(inp)
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)