!49438 修复CumulativeLogsumexp算子问题

Merge pull request !49438 from zong_shuai/gpu_cumu
This commit is contained in:
i-robot 2023-02-27 10:45:05 +00:00 committed by Gitee
commit f4d0bbe629
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 12 additions and 4 deletions

View File

@@ -58,9 +58,11 @@ __global__ void CumulativeLogsumexpKernelReverse(const T *input, T *output, size
size_t read_index = (dim1 - 1) * stride2 + offset;
output[read_index] = input[read_index];
T pre_sum = expT(output[read_index]);
for (int j = dim1 - 2; j >= 0; --j) {
read_index = j * stride2 + offset;
output[read_index] = logT(expT(output[read_index + stride2]) + expT(input[read_index]));
pre_sum += expT(input[read_index]);
output[read_index] = logT(pre_sum);
}
}
}
@@ -78,9 +80,11 @@ __global__ void CumulativeLogsumexpKernel(const T *input, T *output, size_t dim0
size_t read_index = offset;
output[read_index] = input[read_index];
T pre_sum = expT(output[read_index]);
for (size_t j = 1; j < dim1; ++j) {
read_index = j * stride2 + offset;
output[read_index] = logT(expT(output[read_index - stride2]) + expT(input[read_index]));
pre_sum += expT(input[read_index]);
output[read_index] = logT(pre_sum);
}
}
}
@@ -98,9 +102,11 @@ __global__ void CumulativeLogsumexpKernelExclusive(const T *input, T *output, si
size_t read_index = offset;
output[read_index] = neg_infT<T>();
T pre_sum = expT(output[read_index]);
for (size_t j = 1; j < dim1; ++j) {
read_index = j * stride2 + offset;
output[read_index] = logT(expT(output[read_index - stride2]) + expT(input[read_index - stride2]));
pre_sum += expT(input[read_index - stride2]);
output[read_index] = logT(pre_sum);
}
output[offset] = neg_maxT<T>();
}
@@ -118,9 +124,11 @@ __global__ void CumulativeLogsumexpKernelReverseExclusive(const T *input, T *out
size_t read_index = (dim1 - 1) * stride2 + offset;
output[read_index] = neg_infT<T>();
T pre_sum = expT(output[read_index]);
for (int j = dim1 - 2; j >= 0; --j) {
read_index = j * stride2 + offset;
output[read_index] = logT(expT(output[read_index + stride2]) + expT(input[read_index + stride2]));
pre_sum += expT(input[read_index + stride2]);
output[read_index] = logT(pre_sum);
}
output[(dim1 - 1) * stride2 + offset] = neg_maxT<T>();
}