diff --git a/mindspore/ccsrc/optimizer/ad/kprim.cc b/mindspore/ccsrc/optimizer/ad/kprim.cc index 4576cc1ea99..2c8ddbfa82e 100644 --- a/mindspore/ccsrc/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/optimizer/ad/kprim.cc @@ -92,9 +92,11 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R return nullptr; } + bool is_faked_bprop = false; auto bprop_fg = GetBprop(prim); if (bprop_fg == nullptr) { bprop_fg = FakeBprop(value_node, resources); + is_faked_bprop = true; } auto expanded_fg = BpropToK(prim, bprop_fg); @@ -104,8 +106,11 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R << trace::GetDebugInfo(bprop_fg->debug_info()); } - // Set bprop_g graph cache - bprop_registry_[prim] = expanded_fg; + // To support primitives with variable params, do not cache faked bprop + if (!is_faked_bprop) { + // Set bprop_g graph cache + bprop_registry_[prim] = expanded_fg; + } return expanded_fg; } diff --git a/tests/ut/python/pynative_mode/test_stop_gradient.py b/tests/ut/python/pynative_mode/test_stop_gradient.py index b274b3988ad..a26d635aadb 100644 --- a/tests/ut/python/pynative_mode/test_stop_gradient.py +++ b/tests/ut/python/pynative_mode/test_stop_gradient.py @@ -366,3 +366,15 @@ def test_stop_gradient_11(): with pytest.raises(RuntimeError): bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)), Tensor(np.ones([2]).astype(np.float32))) + +def test_stop_print(): + class StopPrint(nn.Cell): + def __init__(self): + super(StopPrint, self).__init__() + self.printm = P.Print() + def construct(self, x, y): + self.printm("StopPrint", x) + self.printm(y) + return x, y + C.grad_all(StopPrint())(Tensor(np.ones([2]).astype(np.float32)), + Tensor(np.ones([2]).astype(np.float32)))