fix missing grad due to indirectly dependent free morphism

panyifeng 2020-04-03 17:09:04 +08:00
parent d4e51c8f6e
commit 1fb776fe09
5 changed files with 54 additions and 17 deletions

View File

@@ -185,19 +185,32 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   return node_adjoint;
 }

+bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
+  // Do not care about non-CNode
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  // Do not care about kPrimReturn
+  if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
+    return false;
+  }
+  auto &users = primal_graph_->manager()->node_users()[node];
+  // Do not care about isolated morphisms
+  if (users.empty()) {
+    return false;
+  }
+  // Not free if it's used by some node in primal_graph
+  bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
+    auto &user = kv.first;
+    return user->func_graph() == primal_graph_;
+  });
+  return !nonfree;
+}
+
 void DFunctor::MapFreeMorphism() {
   // Handle cnodes not attached to the output, which might be referred to in other functions.
   for (auto &node : primal_graph_->nodes()) {
-    auto adjoint = FindAdjoint(node);
-    if (adjoint != nullptr) {
-      continue;
-    }
-    if (!node->isa<CNode>()) {
-      MS_LOG(DEBUG) << "MapFreeMorphism noncnode not mapped after MapMorphism " << node->ToString() << " "
-                    << node->type_name() << ".";
-      continue;
-    }
-    if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
+    if (!IsFreeMorphism(node)) {
       continue;
     }
     MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
@@ -256,9 +269,10 @@ void DFunctor::MapMorphism() {
   // Set stop_gradient before MapMorphism.
   BroadCastStopFlag();
+  // Handle free morphisms before the output, because in some cases a free morphism might depend on the output's fv tangent.
+  MapFreeMorphism();
   // Handle morphism from output.
   (void)MapMorphism(primal_graph_->output());
-  MapFreeMorphism();
   // Construct K for primal_graph_
   auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
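A hedged reading of this reordering: free morphisms are now mapped before the walk from the output, so their adjoints already exist when free-variable (fv) tangents are back-propagated during the output walk. In Python-flavoured pseudocode (method names mirror the C++ above; this is a sketch of the control flow, not MindSpore API):

class DFunctorSketch:
    def run(self):
        self.broadcast_stop_flag()          # set stop_gradient flags first
        self.map_free_morphism()            # free morphisms first: they may depend
                                            # on the output's fv tangent
        self.map_morphism(self.primal_graph.output())   # then walk from the output
        self.construct_k()                  # finally assemble K for primal_graph_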
@@ -298,9 +312,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
   const size_t param_diff = 1;
   if (bprop_graph->output()->isa<CNode>() &&
       bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) {
-    MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
-                      << primal->output()->scope()->name()
-                      << " output must be a tuple and output number should be the same with inputs.";
+    // The mismatch does not affect the final tangents; log a tip for debugging.
+    MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope "
+                  << primal->output()->scope()->name()
+                  << " output should be a tuple whose length matches the number of inputs.";
   }
   resources_->manager()->AddFuncGraph(bprop_graph);
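The demoted check concerns user-defined Cell bprops whose returned tuple does not have one tangent per input. A plausible Python reconstruction of that case (the actual class lives in the test file, outside this diff's context; the body below is an assumption):

import mindspore.nn as nn
from mindspore.ops import composite as C

class MulAddWithWrongOutputNum(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        return (2 * dout,)   # one tangent returned for two inputs

# Previously this raised via MS_LOG(EXCEPTION); after this commit the mismatch
# is only logged at DEBUG level, so the call below succeeds, which is exactly
# what the updated test_grad_mul_add_with_wrong_output_num asserts.
C.grad_all(MulAddWithWrongOutputNum())(1, 2)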

View File

@@ -61,6 +61,7 @@ class DFunctor {
  private:
   // Map one morphism.
   AdjointPtr MapMorphism(const AnfNodePtr &morph);
+  bool IsFreeMorphism(const AnfNodePtr &node);
   // Map morphism that's not attached to output.
   void MapFreeMorphism();
   void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);

View File

@@ -111,7 +111,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
       irpass.replace_applicator_,
     });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
-  opt::OptPassConfig grad = opt::OptPassConfig({irpass.inline_, irpass.expand_jprim_}, true);
+  opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
   OptPassGroupMap map_a({{"a_1", a_1},
                          {"a_2", a_2},

View File

@@ -304,5 +304,4 @@ class MulAddWithWrongOutputNum(nn.Cell):
 def test_grad_mul_add_with_wrong_output_num():
     mul_add = MulAddWithWrongOutputNum()
-    with pytest.raises(RuntimeError):
-        C.grad_all(mul_add)(1, 2)
+    C.grad_all(mul_add)(1, 2)

View File

@@ -15,6 +15,7 @@
 """ test_framstruct """
 import pytest
 import numpy as np
+import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.ops import composite as C
@@ -706,3 +707,24 @@ def grad_refactor_14(a, b):
     return inner1(b) + inner2(a) + inner3(a)
 def test_grad_refactor_14():
     assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
+
+
+class IfDeferInline(nn.Cell):
+    def __init__(self, mul_size):
+        super().__init__()
+        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
+        self.mul = P.Mul()
+
+    def construct(self, inputs):
+        x = self.mul(inputs, self.mul_weight)
+        if True:
+            x = x
+        return x
+
+def test_grad_if_defer_inline():
+    """ test_grad_if_defer_inline """
+    network = IfDeferInline([128, 96])
+    network.add_flags(defer_inline=False)
+    inp = Tensor(np.ones([128, 96]).astype(np.float32))
+    grads = C.grad_all(network)(inp)
+    assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
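The new test's assertion follows from the derivative of an elementwise multiply by a constant weight: for y = x * w, dy/dx = w, so grad_all should hand back the 0.6-filled weight itself. The dead `if True: x = x` branch exists only to introduce a control-flow subgraph so the defer_inline behaviour is exercised. A quick standalone numpy check of the same arithmetic (illustrative, independent of MindSpore):

import numpy as np

w = np.full([128, 96], 0.6, dtype=np.float32)
x = np.ones([128, 96], dtype=np.float32)
eps = 1e-3
numeric_grad = ((x + eps) * w - x * w) / eps   # finite difference of y = x * w
assert np.allclose(numeric_grad, w, atol=1e-3)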