!49438 修复CumulativeLogsumexp算子问题
Merge pull request !49438 from zong_shuai/gpu_cumu
This commit is contained in:
commit
f4d0bbe629
|
@ -58,9 +58,11 @@ __global__ void CumulativeLogsumexpKernelReverse(const T *input, T *output, size
|
|||
|
||||
size_t read_index = (dim1 - 1) * stride2 + offset;
|
||||
output[read_index] = input[read_index];
|
||||
T pre_sum = expT(output[read_index]);
|
||||
for (int j = dim1 - 2; j >= 0; --j) {
|
||||
read_index = j * stride2 + offset;
|
||||
output[read_index] = logT(expT(output[read_index + stride2]) + expT(input[read_index]));
|
||||
pre_sum += expT(input[read_index]);
|
||||
output[read_index] = logT(pre_sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -78,9 +80,11 @@ __global__ void CumulativeLogsumexpKernel(const T *input, T *output, size_t dim0
|
|||
|
||||
size_t read_index = offset;
|
||||
output[read_index] = input[read_index];
|
||||
T pre_sum = expT(output[read_index]);
|
||||
for (size_t j = 1; j < dim1; ++j) {
|
||||
read_index = j * stride2 + offset;
|
||||
output[read_index] = logT(expT(output[read_index - stride2]) + expT(input[read_index]));
|
||||
pre_sum += expT(input[read_index]);
|
||||
output[read_index] = logT(pre_sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -98,9 +102,11 @@ __global__ void CumulativeLogsumexpKernelExclusive(const T *input, T *output, si
|
|||
|
||||
size_t read_index = offset;
|
||||
output[read_index] = neg_infT<T>();
|
||||
T pre_sum = expT(output[read_index]);
|
||||
for (size_t j = 1; j < dim1; ++j) {
|
||||
read_index = j * stride2 + offset;
|
||||
output[read_index] = logT(expT(output[read_index - stride2]) + expT(input[read_index - stride2]));
|
||||
pre_sum += expT(input[read_index - stride2]);
|
||||
output[read_index] = logT(pre_sum);
|
||||
}
|
||||
output[offset] = neg_maxT<T>();
|
||||
}
|
||||
|
@ -118,9 +124,11 @@ __global__ void CumulativeLogsumexpKernelReverseExclusive(const T *input, T *out
|
|||
|
||||
size_t read_index = (dim1 - 1) * stride2 + offset;
|
||||
output[read_index] = neg_infT<T>();
|
||||
T pre_sum = expT(output[read_index]);
|
||||
for (int j = dim1 - 2; j >= 0; --j) {
|
||||
read_index = j * stride2 + offset;
|
||||
output[read_index] = logT(expT(output[read_index + stride2]) + expT(input[read_index + stride2]));
|
||||
pre_sum += expT(input[read_index + stride2]);
|
||||
output[read_index] = logT(pre_sum);
|
||||
}
|
||||
output[(dim1 - 1) * stride2 + offset] = neg_maxT<T>();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue