forked from mindspore-Ecosystem/mindspore
!47467 add bp of ParameterizedTruncatedNormal
Merge pull request !47467 from zong_shuai/ptn_debug
This commit is contained in:
commit
82b0e9399b
|
@ -50,6 +50,7 @@ from mindspore.ops.operations.array_ops import MaskedScatter
|
|||
from mindspore.ops.operations.array_ops import CountNonZero
|
||||
from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
|
||||
from mindspore.ops.operations.random_ops import LogNormalReverse
|
||||
from mindspore.ops.operations.random_ops import ParameterizedTruncatedNormal
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -378,6 +379,15 @@ def get_bprop_log_normal_reverse(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(ParameterizedTruncatedNormal)
|
||||
def get_bprop_parameterized_truncated_normal(self):
|
||||
"""Grad definition for `ParameterizedTruncatedNormal` operation."""
|
||||
def bprop(shape, mean, stdevs, min_val, max_val, out, dout):
|
||||
return (zeros_like(shape), zeros_like(mean), zeros_like(stdevs), zeros_like(min_val), zeros_like(max_val))
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.TensorScatterMax)
|
||||
def get_bprop_tensor_scatter_max(self):
|
||||
"""Generate bprop for TensorScatterMax"""
|
||||
|
|
Loading…
Reference in New Issue