fix missing grad due to indirectly dependent free morphism
parent d4e51c8f6e
commit 1fb776fe09
@@ -185,19 +185,32 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   return node_adjoint;
 }
 
+bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
+  // Do not care about non-CNode
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  // Do not care about kPrimReturn
+  if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
+    return false;
+  }
+  auto &users = primal_graph_->manager()->node_users()[node];
+  // Do not care about isolated morphisms
+  if (users.empty()) {
+    return false;
+  }
+  // Not free if it's used by some node in primal_graph
+  bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
+    auto &user = kv.first;
+    return user->func_graph() == primal_graph_;
+  });
+  return !nonfree;
+}
+
 void DFunctor::MapFreeMorphism() {
   // Handle cnodes not attached to the output that might be referred to in other functions.
   for (auto &node : primal_graph_->nodes()) {
     auto adjoint = FindAdjoint(node);
     if (adjoint != nullptr) {
       continue;
     }
     if (!node->isa<CNode>()) {
       MS_LOG(DEBUG) << "MapFreeMorphism noncnode not mapped after MapMorphism " << node->ToString() << " "
                     << node->type_name() << ".";
       continue;
     }
-    if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
+    if (!IsFreeMorphism(node)) {
       continue;
     }
     MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
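For intuition, here is a minimal Python sketch of the pattern the new IsFreeMorphism check targets; the function name, operands, and expected gradients are illustrative assumptions rather than code from this commit. A node built in the primal graph is "free" when none of its users belong to the primal graph itself (for example when it is only consumed through a nested function's closure), so the output depends on it only indirectly:

from mindspore.ops import composite as C

def free_morphism_example(a, b):
    # `prod` is a node of the outer (primal) graph, but its only user is the
    # nested function `inner`, so no user of `prod` lives in the primal graph:
    # it is a free morphism that the output depends on only indirectly.
    prod = a * b
    def inner(x):
        return x + prod
    return inner(a)  # f(a, b) = a + a * b

def test_free_morphism_example():
    # Checked by hand: df/da = 1 + b = 4 and df/db = a = 2 at (a, b) = (2, 3).
    assert C.grad_all(free_morphism_example)(2, 3) == (4, 2)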
@@ -256,9 +269,10 @@ void DFunctor::MapMorphism() {
   // Set stop_gradient before MapMorphism.
   BroadCastStopFlag();
 
+  // Handle free morphisms before the output, because in some cases a free morphism might depend on the output's fv (free variable) tangent.
+  MapFreeMorphism();
   // Handle morphism from output.
   (void)MapMorphism(primal_graph_->output());
-  MapFreeMorphism();
 
   // Construct K for primal_graph_
   auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
@@ -298,9 +312,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
   const size_t param_diff = 1;
   if (bprop_graph->output()->isa<CNode>() &&
       bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) {
-    MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
-                      << primal->output()->scope()->name()
-                      << " output must be a tuple and output number should be the same with inputs.";
+    // This does not affect the final tangents; it is just a tip for debugging.
+    MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope "
+                  << primal->output()->scope()->name()
+                  << " output must be a tuple and output number should be the same with inputs.";
   }
   resources_->manager()->AddFuncGraph(bprop_graph);
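For context, a hedged sketch of what the MulAddWithWrongOutputNum Cell used by the affected test plausibly looks like; only the class name and the EXCEPTION-to-DEBUG change above come from this commit, while the Cell body and the sample inputs are assumptions. Its custom bprop returns fewer gradients than the Cell has inputs, which previously raised a RuntimeError and now only produces a debug log:

import mindspore.nn as nn
from mindspore.ops import composite as C

class MulAddWithWrongOutputNum(nn.Cell):
    """Assumed body: two inputs, custom bprop returning a single gradient."""
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # One gradient for two inputs; after this commit the mismatch is
        # reported via MS_LOG(DEBUG) instead of raising an exception.
        return (2 * dout,)

# After the change this call no longer needs pytest.raises(RuntimeError):
C.grad_all(MulAddWithWrongOutputNum())(1, 2)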
@@ -61,6 +61,7 @@ class DFunctor {
  private:
   // Map one morphism.
   AdjointPtr MapMorphism(const AnfNodePtr &morph);
+  bool IsFreeMorphism(const AnfNodePtr &node);
   // Map morphism that's not attached to output.
   void MapFreeMorphism();
   void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
@@ -111,7 +111,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
     irpass.replace_applicator_,
   });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
-  opt::OptPassConfig grad = opt::OptPassConfig({irpass.inline_, irpass.expand_jprim_}, true);
+  opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
 
   OptPassGroupMap map_a({{"a_1", a_1},
                          {"a_2", a_2},
@@ -304,5 +304,4 @@ class MulAddWithWrongOutputNum(nn.Cell):
 
 def test_grad_mul_add_with_wrong_output_num():
     mul_add = MulAddWithWrongOutputNum()
-    with pytest.raises(RuntimeError):
-        C.grad_all(mul_add)(1, 2)
+    C.grad_all(mul_add)(1, 2)
@@ -15,6 +15,7 @@
 """ test_framstruct """
+import pytest
 import numpy as np
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.ops import composite as C
@@ -706,3 +707,24 @@ def grad_refactor_14(a, b):
     return inner1(b) + inner2(a) + inner3(a)
 def test_grad_refactor_14():
     assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
+
+
+class IfDeferInline(nn.Cell):
+    def __init__(self, mul_size):
+        super().__init__()
+        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
+        self.mul = P.Mul()
+
+    def construct(self, inputs):
+        x = self.mul(inputs, self.mul_weight)
+        if True:
+            x = x
+        return x
+
+def test_grad_if_defer_inline():
+    """ test_grad_if_defer_inline """
+    network = IfDeferInline([128, 96])
+    network.add_flags(defer_inline=False)
+    inp = Tensor(np.ones([128, 96]).astype(np.float32))
+    grads = C.grad_all(network)(inp)
+    assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)