add squreasumall grad

This commit is contained in:
fangzehua 2020-08-03 14:29:36 +08:00
parent 3e7ba14e19
commit 7011508379
2 changed files with 18 additions and 1 deletions

View File

@ -397,6 +397,22 @@ def get_bprop_xlogy(self):
return bprop
@bprop_getters.register(P.SquareSumAll)
def get_bprop_square_sum_all(self):
"""Grad definition for `Square` operation."""
mul_func = P.Mul()
fill_func = P.Fill()
dtype = P.DType()
def bprop(x, y, out, dout):
temp_x = mul_func(dout[0], x)
temp_y = mul_func(dout[1], y)
dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
return (dx, dy)
return bprop
@bprop_getters.register(P.Sqrt)
def get_bprop_sqrt(self):

View File

@ -1188,7 +1188,8 @@ test_case_math_ops = [
'block': P.SquareSumAll(),
'desc_inputs': [Tensor(np.array([0, 1, 4, 5]).astype(np.float32)),
Tensor(np.array([1, 1, 3, 7]).astype(np.float32))],
'skip': ['backward']}),
'desc_bprop': [Tensor(np.array(0.1).astype(np.float32)),
Tensor(np.array(0.1).astype(np.float32))]}),
('Cos', {
'block': P.Cos(),
'desc_inputs': [[2, 3]],