!48169 solve the bp bug of cumulativelogsumexp

Merge pull request !48169 from zong_shuai/cumulative_res_error
This commit is contained in:
i-robot 2023-01-30 01:42:41 +00:00 committed by Gitee
commit 9f12d31d49
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 1 additions and 11 deletions

View File

@ -483,26 +483,16 @@ def get_brop_cumulative_logsumexp(self):
def bprop(x, axis, out, dout):
dtype_min = 0
fp64_flag = False
if x.dtype == mstype.float16:
dtype_min = -65500e+0
elif x.dtype == mstype.float32:
dtype_min = -3.4028235e+38
elif x.dtype == mstype.float64:
dout = F.cast(dout, mstype.float32)
x = F.cast(x, mstype.float32)
out = F.cast(out, mstype.float32)
dtype_min = -3.4028235e+38
fp64_flag = True
dtype_min = -1.7976931348623157e+308
log_grad_positive = where_v2(greater_op(dout, 0), log_op(dout), dtype_min)
log_grad_negative = where_v2(less_op(dout, 0), log_op(neg_op(dout)), dtype_min)
output_pos = exp_op(cumulative_op(log_grad_positive - out, axis) + x)
output_neg = exp_op(cumulative_op(log_grad_negative - out, axis) + x)
if fp64_flag:
output_pos = F.cast(output_pos, mstype.float64)
output_neg = F.cast(output_neg, mstype.float64)
x = F.cast(x, mstype.float64)
return (output_pos - output_neg, zeros_like(x))
return bprop