forked from mindspore-Ecosystem/mindspore
fix bprop cache caused error with variable params
This commit is contained in:
parent
268d358a1d
commit
2bef22d8a3
|
@ -92,9 +92,11 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_faked_bprop = false;
|
||||||
auto bprop_fg = GetBprop(prim);
|
auto bprop_fg = GetBprop(prim);
|
||||||
if (bprop_fg == nullptr) {
|
if (bprop_fg == nullptr) {
|
||||||
bprop_fg = FakeBprop(value_node, resources);
|
bprop_fg = FakeBprop(value_node, resources);
|
||||||
|
is_faked_bprop = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto expanded_fg = BpropToK(prim, bprop_fg);
|
auto expanded_fg = BpropToK(prim, bprop_fg);
|
||||||
|
@ -104,8 +106,11 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
||||||
<< trace::GetDebugInfo(bprop_fg->debug_info());
|
<< trace::GetDebugInfo(bprop_fg->debug_info());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set bprop_g graph cache
|
// To support primitives with variable params, do not cache faked bprop
|
||||||
bprop_registry_[prim] = expanded_fg;
|
if (!is_faked_bprop) {
|
||||||
|
// Set bprop_g graph cache
|
||||||
|
bprop_registry_[prim] = expanded_fg;
|
||||||
|
}
|
||||||
return expanded_fg;
|
return expanded_fg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -366,3 +366,15 @@ def test_stop_gradient_11():
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
|
bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
|
||||||
Tensor(np.ones([2]).astype(np.float32)))
|
Tensor(np.ones([2]).astype(np.float32)))
|
||||||
|
|
||||||
|
def test_stop_print():
|
||||||
|
class StopPrint(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(StopPrint, self).__init__()
|
||||||
|
self.printm = P.Print()
|
||||||
|
def construct(self, x, y):
|
||||||
|
self.printm("StopPrint", x)
|
||||||
|
self.printm(y)
|
||||||
|
return x, y
|
||||||
|
C.grad_all(StopPrint())(Tensor(np.ones([2]).astype(np.float32)),
|
||||||
|
Tensor(np.ones([2]).astype(np.float32)))
|
||||||
|
|
Loading…
Reference in New Issue