!47467 add bp of ParameterizedTruncatedNormal

Merge pull request !47467 from zong_shuai/ptn_debug
This commit is contained in:
i-robot 2023-01-06 07:29:13 +00:00 committed by Gitee
commit 82b0e9399b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 10 additions and 0 deletions

View File

@ -50,6 +50,7 @@ from mindspore.ops.operations.array_ops import MaskedScatter
from mindspore.ops.operations.array_ops import CountNonZero
from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
from mindspore.ops.operations.random_ops import LogNormalReverse
from mindspore.ops.operations.random_ops import ParameterizedTruncatedNormal
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import functional as F
from mindspore.ops import operations as P
@ -378,6 +379,15 @@ def get_bprop_log_normal_reverse(self):
return bprop
@bprop_getters.register(ParameterizedTruncatedNormal)
def get_bprop_parameterized_truncated_normal(self):
"""Grad definition for `ParameterizedTruncatedNormal` operation."""
def bprop(shape, mean, stdevs, min_val, max_val, out, dout):
return (zeros_like(shape), zeros_like(mean), zeros_like(stdevs), zeros_like(min_val), zeros_like(max_val))
return bprop
@bprop_getters.register(P.TensorScatterMax)
def get_bprop_tensor_scatter_max(self):
"""Generate bprop for TensorScatterMax"""