!9207 Fix maxgradgrad minimumgradgrad relugradgrad
From: @yuan_shen_zhou Reviewed-by: @liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit
f42adabb57
|
@ -14,29 +14,40 @@
|
|||
# ============================================================================
|
||||
|
||||
"""bprop primitives"""
|
||||
from ..operations import _grad_ops as G
|
||||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
from ..composite import multitype_ops as C
|
||||
from .grad_base import bprops
|
||||
|
||||
get_dtype = P.DType()
|
||||
# Unused parameters are placeholders.
|
||||
|
||||
|
||||
@bprops.register("MaximumGrad")
|
||||
def bprop_maximum_grad_grad(x, y, z, out, dout):
|
||||
"""Backpropagator for primitive `MaximumGrad`."""
|
||||
return F.zeros_like(x), F.zeros_like(y), F.zeros_like(z)
|
||||
out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
|
||||
out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
|
||||
dz = out0 * dout[0] + out1 * dout[1]
|
||||
return F.zeros_like(x), F.zeros_like(y), dz
|
||||
|
||||
|
||||
@bprops.register("MinimumGrad")
|
||||
def bprop_minimum_grad_grad(x, y, z, out, dout):
|
||||
"""Backpropagator for primitive `MinimumGrad`."""
|
||||
return F.zeros_like(x), F.zeros_like(y), F.zeros_like(z)
|
||||
out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
|
||||
out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
|
||||
dz = out0 * dout[0] + out1 * dout[1]
|
||||
return F.zeros_like(x), F.zeros_like(y), dz
|
||||
|
||||
|
||||
@bprops.register("ReluGrad")
|
||||
def bprop_relu_grad_grad(x, y, out, dout):
|
||||
"""Backpropagator for primitive `ReluGrad`."""
|
||||
return F.zeros_like(x), F.zeros_like(y)
|
||||
input_grad = G.ReluGrad()
|
||||
dy = input_grad(dout, y)
|
||||
return dy, F.zeros_like(y)
|
||||
|
||||
|
||||
@bprops.register("scalar_add")
|
||||
|
|
Loading…
Reference in New Issue