forked from mindspore-Ecosystem/mindspore
!48169 solve the bp bug of cumulativelogsumexp
Merge pull request !48169 from zong_shuai/cumulative_res_error
This commit is contained in:
commit
9f12d31d49
|
@ -483,26 +483,16 @@ def get_brop_cumulative_logsumexp(self):
|
|||
|
||||
def bprop(x, axis, out, dout):
|
||||
dtype_min = 0
|
||||
fp64_flag = False
|
||||
if x.dtype == mstype.float16:
|
||||
dtype_min = -65500e+0
|
||||
elif x.dtype == mstype.float32:
|
||||
dtype_min = -3.4028235e+38
|
||||
elif x.dtype == mstype.float64:
|
||||
dout = F.cast(dout, mstype.float32)
|
||||
x = F.cast(x, mstype.float32)
|
||||
out = F.cast(out, mstype.float32)
|
||||
dtype_min = -3.4028235e+38
|
||||
fp64_flag = True
|
||||
|
||||
dtype_min = -1.7976931348623157e+308
|
||||
log_grad_positive = where_v2(greater_op(dout, 0), log_op(dout), dtype_min)
|
||||
log_grad_negative = where_v2(less_op(dout, 0), log_op(neg_op(dout)), dtype_min)
|
||||
output_pos = exp_op(cumulative_op(log_grad_positive - out, axis) + x)
|
||||
output_neg = exp_op(cumulative_op(log_grad_negative - out, axis) + x)
|
||||
if fp64_flag:
|
||||
output_pos = F.cast(output_pos, mstype.float64)
|
||||
output_neg = F.cast(output_neg, mstype.float64)
|
||||
x = F.cast(x, mstype.float64)
|
||||
return (output_pos - output_neg, zeros_like(x))
|
||||
|
||||
return bprop
|
||||
|
|
Loading…
Reference in New Issue