forked from mindspore-Ecosystem/mindspore
fix grad missing due to indirect dependent free morphism
This commit is contained in:
parent
d4e51c8f6e
commit
1fb776fe09
|
@ -185,19 +185,32 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
||||||
return node_adjoint;
|
return node_adjoint;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
|
||||||
|
// Do not care about non-CNode
|
||||||
|
if (!node->isa<CNode>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Do not care about kPrimReturn
|
||||||
|
if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto &users = primal_graph_->manager()->node_users()[node];
|
||||||
|
// Do not care about isolated morphisms
|
||||||
|
if (users.empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Not free if it's used by some node in primal_graph
|
||||||
|
bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
|
||||||
|
auto &user = kv.first;
|
||||||
|
return user->func_graph() == primal_graph_;
|
||||||
|
});
|
||||||
|
return !nonfree;
|
||||||
|
}
|
||||||
|
|
||||||
void DFunctor::MapFreeMorphism() {
|
void DFunctor::MapFreeMorphism() {
|
||||||
// Handle cnode not attached to output, that might be refered in other functions.
|
// Handle cnode not attached to output, that might be refered in other functions.
|
||||||
for (auto &node : primal_graph_->nodes()) {
|
for (auto &node : primal_graph_->nodes()) {
|
||||||
auto adjoint = FindAdjoint(node);
|
if (!IsFreeMorphism(node)) {
|
||||||
if (adjoint != nullptr) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!node->isa<CNode>()) {
|
|
||||||
MS_LOG(DEBUG) << "MapFreeMorphism noncnode not mapped after MapMorphism " << node->ToString() << " "
|
|
||||||
<< node->type_name() << ".";
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
|
MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
|
||||||
|
@ -256,9 +269,10 @@ void DFunctor::MapMorphism() {
|
||||||
// Set stop_gradient before MapMorphism.
|
// Set stop_gradient before MapMorphism.
|
||||||
BroadCastStopFlag();
|
BroadCastStopFlag();
|
||||||
|
|
||||||
|
// Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
|
||||||
|
MapFreeMorphism();
|
||||||
// Handle morphism from output.
|
// Handle morphism from output.
|
||||||
(void)MapMorphism(primal_graph_->output());
|
(void)MapMorphism(primal_graph_->output());
|
||||||
MapFreeMorphism();
|
|
||||||
|
|
||||||
// Construct K for primal_graph_
|
// Construct K for primal_graph_
|
||||||
auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
|
auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
|
||||||
|
@ -298,9 +312,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
|
||||||
const size_t param_diff = 1;
|
const size_t param_diff = 1;
|
||||||
if (bprop_graph->output()->isa<CNode>() &&
|
if (bprop_graph->output()->isa<CNode>() &&
|
||||||
bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) {
|
bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) {
|
||||||
MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
|
// It does not matter with the final tangents, just a tip for debugging
|
||||||
<< primal->output()->scope()->name()
|
MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope "
|
||||||
<< " output must be a tuple and output number should be the same with inputs.";
|
<< primal->output()->scope()->name()
|
||||||
|
<< " output must be a tuple and output number should be the same with inputs.";
|
||||||
}
|
}
|
||||||
resources_->manager()->AddFuncGraph(bprop_graph);
|
resources_->manager()->AddFuncGraph(bprop_graph);
|
||||||
|
|
||||||
|
|
|
@ -61,6 +61,7 @@ class DFunctor {
|
||||||
private:
|
private:
|
||||||
// Map one morphism.
|
// Map one morphism.
|
||||||
AdjointPtr MapMorphism(const AnfNodePtr &morph);
|
AdjointPtr MapMorphism(const AnfNodePtr &morph);
|
||||||
|
bool IsFreeMorphism(const AnfNodePtr &node);
|
||||||
// Map morphism that's not attached to output.
|
// Map morphism that's not attached to output.
|
||||||
void MapFreeMorphism();
|
void MapFreeMorphism();
|
||||||
void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
|
void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
|
||||||
|
|
|
@ -111,7 +111,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
|
||||||
irpass.replace_applicator_,
|
irpass.replace_applicator_,
|
||||||
});
|
});
|
||||||
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
||||||
opt::OptPassConfig grad = opt::OptPassConfig({irpass.inline_, irpass.expand_jprim_}, true);
|
opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
|
||||||
|
|
||||||
OptPassGroupMap map_a({{"a_1", a_1},
|
OptPassGroupMap map_a({{"a_1", a_1},
|
||||||
{"a_2", a_2},
|
{"a_2", a_2},
|
||||||
|
|
|
@ -304,5 +304,4 @@ class MulAddWithWrongOutputNum(nn.Cell):
|
||||||
|
|
||||||
def test_grad_mul_add_with_wrong_output_num():
|
def test_grad_mul_add_with_wrong_output_num():
|
||||||
mul_add = MulAddWithWrongOutputNum()
|
mul_add = MulAddWithWrongOutputNum()
|
||||||
with pytest.raises(RuntimeError):
|
C.grad_all(mul_add)(1, 2)
|
||||||
C.grad_all(mul_add)(1, 2)
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
""" test_framstruct """
|
""" test_framstruct """
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
|
@ -706,3 +707,24 @@ def grad_refactor_14(a, b):
|
||||||
return inner1(b) + inner2(a) + inner3(a)
|
return inner1(b) + inner2(a) + inner3(a)
|
||||||
def test_grad_refactor_14():
|
def test_grad_refactor_14():
|
||||||
assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
|
assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
|
||||||
|
|
||||||
|
|
||||||
|
class IfDeferInline(nn.Cell):
|
||||||
|
def __init__(self, mul_size):
|
||||||
|
super().__init__()
|
||||||
|
self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
|
||||||
|
self.mul = P.Mul()
|
||||||
|
|
||||||
|
def construct(self, inputs):
|
||||||
|
x = self.mul(inputs, self.mul_weight)
|
||||||
|
if True:
|
||||||
|
x = x
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test_grad_if_defer_inline():
|
||||||
|
""" test_grad_if_defer_inline """
|
||||||
|
network = IfDeferInline([128, 96])
|
||||||
|
network.add_flags(defer_inline=False)
|
||||||
|
inp = Tensor(np.ones([128, 96]).astype(np.float32))
|
||||||
|
grads = C.grad_all(network)(inp)
|
||||||
|
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
|
||||||
|
|
Loading…
Reference in New Issue