diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py index 67629df99b5..bc03fc551ef 100644 --- a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py +++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py @@ -50,6 +50,7 @@ from mindspore.ops.operations.array_ops import MaskedScatter from mindspore.ops.operations.array_ops import CountNonZero from mindspore.ops.operations._grad_ops import StridedSliceV2Grad from mindspore.ops.operations.random_ops import LogNormalReverse +from mindspore.ops.operations.random_ops import ParameterizedTruncatedNormal from mindspore.ops.operations import _inner_ops as inner from mindspore.ops import functional as F from mindspore.ops import operations as P @@ -378,6 +379,15 @@ def get_bprop_log_normal_reverse(self): return bprop +@bprop_getters.register(ParameterizedTruncatedNormal) +def get_bprop_parameterized_truncated_normal(self): + """Grad definition for `ParameterizedTruncatedNormal` operation.""" + def bprop(shape, mean, stdevs, min_val, max_val, out, dout): + return (zeros_like(shape), zeros_like(mean), zeros_like(stdevs), zeros_like(min_val), zeros_like(max_val)) + + return bprop + + @bprop_getters.register(P.TensorScatterMax) def get_bprop_tensor_scatter_max(self): """Generate bprop for TensorScatterMax"""