diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumulativelogsumexp_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumulativelogsumexp_impl.cu index 9503ce04fc4..d54745bddd0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumulativelogsumexp_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumulativelogsumexp_impl.cu @@ -58,9 +58,11 @@ __global__ void CumulativeLogsumexpKernelReverse(const T *input, T *output, size size_t read_index = (dim1 - 1) * stride2 + offset; output[read_index] = input[read_index]; + T pre_sum = expT(output[read_index]); for (int j = dim1 - 2; j >= 0; --j) { read_index = j * stride2 + offset; - output[read_index] = logT(expT(output[read_index + stride2]) + expT(input[read_index])); + pre_sum += expT(input[read_index]); + output[read_index] = logT(pre_sum); } } } @@ -78,9 +80,11 @@ __global__ void CumulativeLogsumexpKernel(const T *input, T *output, size_t dim0 size_t read_index = offset; output[read_index] = input[read_index]; + T pre_sum = expT(output[read_index]); for (size_t j = 1; j < dim1; ++j) { read_index = j * stride2 + offset; - output[read_index] = logT(expT(output[read_index - stride2]) + expT(input[read_index])); + pre_sum += expT(input[read_index]); + output[read_index] = logT(pre_sum); } } } @@ -98,9 +102,11 @@ __global__ void CumulativeLogsumexpKernelExclusive(const T *input, T *output, si size_t read_index = offset; output[read_index] = neg_infT(); + T pre_sum = expT(output[read_index]); for (size_t j = 1; j < dim1; ++j) { read_index = j * stride2 + offset; - output[read_index] = logT(expT(output[read_index - stride2]) + expT(input[read_index - stride2])); + pre_sum += expT(input[read_index - stride2]); + output[read_index] = logT(pre_sum); } output[offset] = neg_maxT(); } @@ -118,9 +124,11 @@ __global__ void CumulativeLogsumexpKernelReverseExclusive(const T *input, T *out size_t read_index = (dim1 - 1) * stride2 + offset; output[read_index] = neg_infT(); + T pre_sum = expT(output[read_index]); for (int j = dim1 - 2; j >= 0; --j) { read_index = j * stride2 + offset; - output[read_index] = logT(expT(output[read_index + stride2]) + expT(input[read_index + stride2])); + pre_sum += expT(input[read_index + stride2]); + output[read_index] = logT(pre_sum); } output[(dim1 - 1) * stride2 + offset] = neg_maxT(); }