forked from mindspore-Ecosystem/mindspore
add squreasumall grad
This commit is contained in:
parent
3e7ba14e19
commit
7011508379
|
@ -397,6 +397,22 @@ def get_bprop_xlogy(self):
|
|||
|
||||
return bprop
|
||||
|
||||
@bprop_getters.register(P.SquareSumAll)
|
||||
def get_bprop_square_sum_all(self):
|
||||
"""Grad definition for `Square` operation."""
|
||||
mul_func = P.Mul()
|
||||
fill_func = P.Fill()
|
||||
dtype = P.DType()
|
||||
|
||||
def bprop(x, y, out, dout):
|
||||
temp_x = mul_func(dout[0], x)
|
||||
temp_y = mul_func(dout[1], y)
|
||||
dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
|
||||
dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
|
||||
return (dx, dy)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Sqrt)
|
||||
def get_bprop_sqrt(self):
|
||||
|
|
|
@ -1188,7 +1188,8 @@ test_case_math_ops = [
|
|||
'block': P.SquareSumAll(),
|
||||
'desc_inputs': [Tensor(np.array([0, 1, 4, 5]).astype(np.float32)),
|
||||
Tensor(np.array([1, 1, 3, 7]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
'desc_bprop': [Tensor(np.array(0.1).astype(np.float32)),
|
||||
Tensor(np.array(0.1).astype(np.float32))]}),
|
||||
('Cos', {
|
||||
'block': P.Cos(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
|
|
Loading…
Reference in New Issue