Performance optimization of the LogSoftmax operator on GPU: 50+ times faster.

happylittleqiang 2023-01-02 23:11:48 +08:00
parent 44cfcca6f4
commit 777bec3cc1
2 changed files with 92 additions and 120 deletions
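
Both changed files below thread a compile-time is_log_softmax flag through the existing softmax kernels, so log-softmax shares the max and sum reductions and the epilogue writes (x - max) - log(sum) directly instead of taking a log of exp(x - max) / sum afterwards. As a reference for that formula only, here is a minimal host-side C++ sketch, assuming a double accumulator; the function name is illustrative and this is not code from the commit.

// Minimal host-side sketch (illustrative names, not code from this commit) of the
// numerically stable log-softmax that the new is_log_softmax epilogue computes:
//   log_softmax(x_i) = (x_i - max_x) - log(sum_j exp(x_j - max_x))
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> LogSoftmaxReference(const std::vector<float> &x) {
  const float max_x = *std::max_element(x.begin(), x.end());
  double sum = 0.0;  // accumulate in a wider type, as accumulate_t does in the kernels
  for (float v : x) sum += std::exp(static_cast<double>(v) - max_x);
  const double log_sum = std::log(sum);  // computed once, reused for every element
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = static_cast<float>((x[i] - max_x) - log_sum);
  }
  return out;
}

int main() {
  for (float v : LogSoftmaxReference({1.0f, 2.0f, 3.0f})) std::printf("%f\n", v);
  return 0;
}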


@@ -36,43 +36,48 @@ constexpr int ALIGN_BYTES = 16;
   } \
 }

-inline dim3 SpatialSoftMax_getGridSize(dim3 *block, uint32_t max_active_blocks, uint64_t outer_size, uint64_t dim_size,
-                                       uint64_t inner_size) {
-  // First, tile as many blocks as we can over the y axis
-  uint32_t inner_blocks = (inner_size + block->y - 1) / block->y;
-  if (inner_blocks > max_active_blocks) inner_blocks = max_active_blocks;
-  // Fill the x axis with as many blocks as we can fit (a little more is ok too)
-  uint32_t outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
-  if (outer_blocks > outer_size) outer_blocks = outer_size;
-  return dim3(outer_blocks, inner_blocks);
+inline dim3 SpatialSoftMaxGetGridSize(dim3 *block, uint32_t activate_block, uint64_t outer_size, uint64_t dim_size,
+                                      uint64_t inner_size) {
+  uint32_t inner = (inner_size + block->y - 1) / block->y;
+  if (inner > activate_block) {
+    inner = activate_block;
+  }
+  uint32_t outer = (activate_block + inner - 1) / inner;
+  if (outer > outer_size) {
+    outer = outer_size;
+  }
+  return dim3(outer, inner);
 }

-inline dim3 SpatialSoftMax_getBlockSize(uint64_t outer_size, uint64_t dim_size, uint64_t inner_size) {
-  uint32_t inner_threads = inner_size;
-  inner_threads = std::min(inner_threads, static_cast<uint32_t>(max_threads));
+inline dim3 SpatialSoftMaxGetBlockSize(uint64_t outer_size, uint64_t dim_size, uint64_t inner_size) {
+  uint32_t inner_ths = inner_size;
+  inner_ths = std::min(inner_ths, static_cast<uint32_t>(max_threads));
   uint32_t dim_threads = 1;
-  if (inner_threads <= 64 && dim_size >= 64) {
-    while (inner_threads * dim_threads <= max_threads && dim_threads <= dim_size) dim_threads *= 2;
+  if (inner_ths <= 64 && dim_size >= 64) {
+    while ((inner_ths * dim_threads <= max_threads) && (dim_threads <= dim_size)) {
+      dim_threads *= 2;
+    }
     dim_threads /= 2;
   }
-  return dim3(dim_threads, inner_threads);
+  return dim3(dim_threads, inner_ths);
 }

 template <typename accumulate_t, typename Kernel>
-void SpatialSoftMax_getLaunchSizes(Kernel k, uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, dim3 *grid,
-                                   dim3 *block, uint32_t *smem_size, uint32_t device_id) {
-  *block = SpatialSoftMax_getBlockSize(outer_size, dim_size, inner_size);
-  uint32_t block_threads = block->x * block->y;
-  *smem_size = block->x == 1 ? 0 : block_threads * sizeof(accumulate_t);
-  int max_active_blocks;
-  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, k, block_threads, *smem_size);
+void SpatialSoftMaxGetLaunchSizes(Kernel k, uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, dim3 *grid,
+                                  dim3 *block, uint32_t *smem_size, uint32_t device_id) {
+  *block = SpatialSoftMaxGetBlockSize(outer_size, dim_size, inner_size);
+  uint32_t block_ths = block->x * block->y;
+  if (block->x == 1) {
+    *smem_size = 0;
+  } else {
+    *smem_size = block_ths * sizeof(accumulate_t);
+  }
+  int activate_size;
+  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&activate_size, k, block_ths, *smem_size);
   cudaDeviceProp prop;
   (void)cudaGetDeviceProperties(&prop, device_id);
-  max_active_blocks *= prop.multiProcessorCount;
-  *grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
+  activate_size *= prop.multiProcessorCount;
+  *grid = SpatialSoftMaxGetGridSize(block, activate_size, outer_size, dim_size, inner_size);
 }

 int log2_ceil(int val) {
@@ -135,63 +140,58 @@ __forceinline__ __device__ T SpatialBlockReduceX(T *memsha, T val) {
   return memsha[0];
 }

-template <typename input_t, typename accumulate_t, typename output_t>
+template <typename input_t, typename accumulate_t, typename output_t, bool is_log_softmax>
 __global__ void SpatialSoftMaxForward(output_t *output, input_t *input, uint32_t outer_size, uint32_t dim_size,
                                       uint32_t inner_size) {
   extern __shared__ unsigned char smem[];
   auto sdata = reinterpret_cast<accumulate_t *>(smem);
   const uint32_t outer_stride = inner_size * dim_size;
   const uint32_t dim_stride = inner_size;
   for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
     const uint32_t outer_offset = outer_index * outer_stride;
     for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size;
          inner_index += blockDim.y * gridDim.y) {
-      const uint32_t data_offset = outer_offset + inner_index;
+      const uint32_t offset = outer_offset + inner_index;
       if (blockDim.x > 1) {
-        accumulate_t max_input = std::numeric_limits<accumulate_t>::lowest();
+        accumulate_t max_data_input = std::numeric_limits<accumulate_t>::lowest();
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
-          const accumulate_t value = static_cast<accumulate_t>(input[data_offset + d * dim_stride]);
-          max_input = atomic::Max()(max_input, value);
+          const accumulate_t value = static_cast<accumulate_t>(input[offset + d * dim_stride]);
+          max_data_input = atomic::Max()(max_data_input, value);
         }
-        max_input = SpatialBlockReduceX<accumulate_t, atomic::Max>(sdata, max_input);
+        max_data_input = SpatialBlockReduceX<accumulate_t, atomic::Max>(sdata, max_data_input);
         accumulate_t sum = 0;
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          sum += std::exp(static_cast<accumulate_t>(input[data_offset + d * dim_stride]) - max_input);
+          sum += std::exp(static_cast<accumulate_t>(input[offset + d * dim_stride]) - max_data_input);
         sum = SpatialBlockReduceX<accumulate_t, atomic::Add>(sdata, sum);
-        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t> epilogue(max_input, sum);
+        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t, is_log_softmax> epilogue(max_data_input, sum);
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
+          output[offset + d * dim_stride] = epilogue(input[offset + d * dim_stride]);
       } else {
-        accumulate_t max_input = std::numeric_limits<accumulate_t>::lowest();
+        accumulate_t max_data_input = std::numeric_limits<accumulate_t>::lowest();
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
-          const accumulate_t value = static_cast<accumulate_t>(input[data_offset + d * dim_stride]);
-          max_input = atomic::Max()(max_input, value);
+          const accumulate_t value = static_cast<accumulate_t>(input[offset + d * dim_stride]);
+          max_data_input = atomic::Max()(max_data_input, value);
         }
         accumulate_t sum = 0;
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          sum += std::exp(static_cast<accumulate_t>(input[data_offset + d * dim_stride]) - max_input);
-        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t> epilogue(max_input, sum);
+          sum += std::exp(static_cast<accumulate_t>(input[offset + d * dim_stride]) - max_data_input);
+        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t, is_log_softmax> epilogue(max_data_input, sum);
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
+          output[offset + d * dim_stride] = epilogue(input[offset + d * dim_stride]);
       }
     }
   }
 }

-template <int InsP, typename input_t, typename accum_t, typename output_t>
+template <int InsP, typename input_t, typename accum_t, typename output_t, bool is_log_softmax>
 __device__ __forceinline__ void WriteResults(int clas, input_t *input, output_t *output, input_t max_k,
                                              input_t sum_all) {
-  SoftMaxForwardEpilogue<input_t, accum_t, output_t> epilogue(max_k, sum_all);
+  SoftMaxForwardEpilogue<input_t, accum_t, output_t, is_log_softmax> epilogue(max_k, sum_all);
   int offset = threadIdx.x;
   int last = clas % (InsP * blockDim.x);
   for (; offset < clas - last; offset += blockDim.x * InsP) {
     input_t tmp[InsP];
 #pragma unroll
     for (int j = 0; j < InsP; ++j) {
       tmp[j] = input[offset + j * blockDim.x];
@@ -201,54 +201,42 @@ __device__ __forceinline__ void WriteResults(int clas, input_t *input, output_t
       output[offset + j * blockDim.x] = epilogue(tmp[j]);
     }
   }
   for (; offset < clas; offset += blockDim.x) {
     output[offset] = epilogue(input[offset]);
   }
 }

-template <int InsP, typename input_t, typename accum_t, typename output_t>
-__device__ __forceinline__ void WriteResultsVectorized(int size, const int shift, input_t *input, output_t *output,
+template <int InsP, typename input_t, typename accum_t, typename output_t, bool is_log_softmax>
+__device__ __forceinline__ void WriteResultsVectorized(int size, const int deviate, input_t *input, output_t *output,
                                                        input_t max_k, input_t sum_all) {
-  SoftMaxForwardEpilogue<input_t, accum_t, output_t> epilogue(max_k, sum_all);
-  using LoadT = aligned_vector<input_t>;
-  using StoreT = aligned_vector<output_t>;
+  SoftMaxForwardEpilogue<input_t, accum_t, output_t, is_log_softmax> epilogue(max_k, sum_all);
+  using loadT = aligned_vector<input_t>;
+  using storeT = aligned_vector<output_t>;
   int offset = threadIdx.x;
-  if (shift > 0) {
-    input -= shift;
-    output -= shift;
-    size += shift;
-    if (threadIdx.x >= shift) {
+  if (deviate > 0) {
+    input -= deviate;
+    output -= deviate;
+    size += deviate;
+    if (threadIdx.x >= deviate) {
       output[offset] = epilogue(input[offset]);
     }
     size -= blockDim.x;
     input += blockDim.x;
     output += blockDim.x;
   }
   const int last = size % (InsP * blockDim.x);
   input_t in_v[InsP];
-  LoadT *in_value = reinterpret_cast<LoadT *>(&in_v);
+  loadT *in_value = reinterpret_cast<loadT *>(&in_v);
   output_t out_v[InsP];
-  StoreT *out_value = reinterpret_cast<StoreT *>(&out_v);
+  storeT *out_value = reinterpret_cast<storeT *>(&out_v);
   for (; offset * InsP < (size - last); offset += blockDim.x) {
-    *in_value = reinterpret_cast<LoadT *>(input)[offset];
+    *in_value = reinterpret_cast<loadT *>(input)[offset];
 #pragma unroll
     for (int j = 0; j < InsP; ++j) {
       out_v[j] = epilogue(in_v[j]);
     }
-    reinterpret_cast<StoreT *>(output)[offset] = *out_value;
+    reinterpret_cast<storeT *>(output)[offset] = *out_value;
   }
   offset = size - last + threadIdx.x;
   for (; offset < size; offset += blockDim.x) {
     output[offset] = epilogue(input[offset]);
@@ -256,52 +244,43 @@ __device__ __forceinline__ void WriteResultsVectorized(int size, const int shift
 }

 template <typename Reduction, typename AccT>
-__device__ __forceinline__ AccT ReduceBlock(AccT *smem, AccT val, AccT defaultVal) {
+__device__ __forceinline__ AccT ReduceBlock(AccT *sharemen, AccT val, AccT initVal) {
   Reduction r = Reduction();
   __syncthreads();
-  smem[threadIdx.x] = val;
+  sharemen[threadIdx.x] = val;
   __syncthreads();
-  AccT warpVal = defaultVal;
+  AccT warpVal = initVal;
   uint32_t mask = (((uint64_t)1) << (blockDim.x / WARPSIZE)) - 1;
   if (threadIdx.x < WARPSIZE) {
     int lane = threadIdx.x % WARPSIZE;
     if (lane < blockDim.x / WARPSIZE) {
 #pragma unroll
       for (int i = 0; i < WARPSIZE; ++i) {
-        warpVal = r(warpVal, smem[lane * WARPSIZE + i]);
+        warpVal = r(warpVal, sharemen[lane * WARPSIZE + i]);
       }
 #ifndef __HIP_PLATFORM_HCC__
       __syncwarp(mask);
 #endif
-      smem[lane] = warpVal;
+      sharemen[lane] = warpVal;
     }
   }
   __syncthreads();
-  AccT blockVal = defaultVal;
+  AccT blockVal = initVal;
   if (threadIdx.x == 0) {
     for (int i = 0; i < blockDim.x / WARPSIZE; ++i) {
-      blockVal = r(blockVal, smem[i]);
+      blockVal = r(blockVal, sharemen[i]);
     }
-    smem[0] = blockVal;
+    sharemen[0] = blockVal;
   }
   __syncthreads();
-  return smem[0];
+  return sharemen[0];
 }

 template <template <typename, typename> class Reduction, int InsP, typename T, typename AccT>
-__device__ __forceinline__ AccT ILPReduce(int shift, T *data, int size, const Reduction<T, AccT> &r, AccT defaultVal) {
-  using LoadT = aligned_vector<T>;
-  AccT threadVal = defaultVal;
+__device__ __forceinline__ AccT ILPReduce(int shift, T *data, int size, const Reduction<T, AccT> &r, AccT initVal) {
+  using loadT = aligned_vector<T>;
+  AccT threadVal = initVal;
   int offset = threadIdx.x;
   if (shift > 0) {
     data -= shift;
     size += shift;
@@ -313,10 +292,9 @@ __device__ __forceinline__ AccT ILPReduce(int shift, T *data, int size, const Re
   }
   int last = size % (InsP * blockDim.x);
   T v[InsP];
-  LoadT *value = reinterpret_cast<LoadT *>(&v);
+  loadT *value = reinterpret_cast<loadT *>(&v);
   for (; offset * InsP < (size - last); offset += blockDim.x) {
-    *value = reinterpret_cast<LoadT *>(data)[offset];
+    *value = reinterpret_cast<loadT *>(data)[offset];
 #pragma unroll
     for (int j = 0; j < InsP; ++j) {
       threadVal = r(threadVal, v[j]);
@@ -324,7 +302,6 @@ __device__ __forceinline__ AccT ILPReduce(int shift, T *data, int size, const Re
   }
   offset = size - last + threadIdx.x;
   for (; offset < size; offset += blockDim.x) threadVal = r(threadVal, data[offset]);
-
   return threadVal;
 }
@@ -350,7 +327,6 @@ __global__ void SoftMaxWarpForward(output_t *dst, const input_t *src, int batch_
   constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-
   int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
   int local_batches = batch_size - first_batch;
   if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
   int local_idx = threadIdx.x;
@@ -449,7 +425,6 @@ __global__ void SoftMaxWarpForward(output_t *dst, const input_t *src, int batch_
     }
   }
   warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, atomic::Add>(sum);
-
 #pragma unroll
   for (int i = 0; i < WARP_BATCH; ++i) {
     if (i >= local_batches) break;
@@ -472,17 +447,15 @@ __global__ void SoftMaxWarpForward(output_t *dst, const input_t *src, int batch_
   }
 }

-template <int InsP, typename T, typename accumulate_t>
+template <int InsP, typename T, typename accumulate_t, bool is_log_softmax>
 __global__ void cunn_SoftMaxForward(T *output, T *input, int classes) {
   extern __shared__ unsigned char smem[];
   auto sdata = reinterpret_cast<accumulate_t *>(smem);
-  using LoadT = aligned_vector<T>;
-  using StoreT = aligned_vector<T>;
+  using loadT = aligned_vector<T>;
+  using storeT = aligned_vector<T>;
   input += blockIdx.x * classes;
   output += blockIdx.x * classes;
   const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(T);
   const int output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(T);
@@ -496,14 +469,13 @@ __global__ void cunn_SoftMaxForward(T *output, T *input, int classes) {
   accumulate_t sumAll = ReduceBlock<atomic::Add, accumulate_t>(sdata, threadExp, static_cast<accumulate_t>(0));
   if (shift == output_shift) {
-    WriteResultsVectorized<InsP, T, accumulate_t, T>(classes, shift, input, output, max_k, sumAll);
+    WriteResultsVectorized<InsP, T, accumulate_t, T, is_log_softmax>(classes, shift, input, output, max_k, sumAll);
   } else {
-    WriteResults<InsP, T, accumulate_t, T>(classes, input, output, max_k, sumAll);
+    WriteResults<InsP, T, accumulate_t, T, is_log_softmax>(classes, input, output, max_k, sumAll);
   }
 }
 // end of kernel function

 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
 void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride,
                               int batch_count, cudaStream_t stream, const bool *mask = nullptr, int chunk_size = -1,
@@ -567,16 +539,16 @@ void Softmax(T *input_, T *output_, size_t dim_size_, size_t outer_size_, size_t
     } else {
       constexpr int InsP = sizeof(float4) / sizeof(T);
       dim3 block = SoftMaxGetBlockSize(InsP, dim_size_);
-      cunn_SoftMaxForward<InsP, T, accumulate_t>
+      cunn_SoftMaxForward<InsP, T, accumulate_t, is_log_softmax>
           <<<grid, block, block.x * sizeof(T), cuda_stream>>>(output_, input_, dim_size_);
       CUDA_CHECK();
     }
   } else {
     uint32_t smem_size;
     dim3 grid, block;
-    SpatialSoftMax_getLaunchSizes<T>(&SpatialSoftMaxForward<T, accumulate_t, T>, outer_size_, dim_size_, inner_size_,
-                                     &grid, &block, &smem_size, device_id);
-    SpatialSoftMaxForward<T, accumulate_t, T>
+    SpatialSoftMaxGetLaunchSizes<T>(&SpatialSoftMaxForward<T, accumulate_t, T, is_log_softmax>, outer_size_, dim_size_,
+                                    inner_size_, &grid, &block, &smem_size, device_id);
+    SpatialSoftMaxForward<T, accumulate_t, T, is_log_softmax>
         <<<grid, block, smem_size, cuda_stream>>>(output_, input_, outer_size_, dim_size_, inner_size_);
     CUDA_CHECK();
   }


@@ -37,14 +37,14 @@ struct AccumulateType<double, true> {
 template <typename T, bool is_cuda>
 using acc_type = typename AccumulateType<T, is_cuda>::type;

-template <typename T, typename AccumT, typename OutT>
+template <typename T, typename AccumT, typename OutT, bool is_log_softmax>
 struct SoftMaxForwardEpilogue {
-  __device__ __forceinline__ SoftMaxForwardEpilogue(AccumT max_input, AccumT sum) : max_input(max_input), sum(sum) {}
+  __device__ __forceinline__ SoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
+      : max_input(max_input), sum(is_log_softmax == true ? std::log(sum) : sum) {}
   __device__ __forceinline__ OutT operator()(T input) const {
-    return static_cast<OutT>(std::exp((AccumT)input - max_input) / sum);
+    return is_log_softmax == true ? static_cast<OutT>((AccumT)input - max_input - sum)
+                                  : static_cast<OutT>(std::exp((AccumT)input - max_input) / sum);
   }
   const AccumT max_input;
   const AccumT sum;
 };
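
To make the two branches of the new SoftMaxForwardEpilogue easier to compare, the sketch below mirrors the same compile-time switch on the host; ForwardEpilogueSketch and its main() are hypothetical illustrations, not repository code, and simply check that log(softmax(x)) agrees with the direct log-softmax expression. Folding std::log(sum) into the constructor, as the commit does, evaluates the log once per row rather than once per output element.

// Illustrative host-side mirror of the templated epilogue pattern above: one functor
// serves both paths, selected at compile time by a bool template parameter.
// Names are hypothetical; this is not the repository's device code.
#include <cmath>
#include <cstdio>

template <typename AccumT, bool IsLogSoftmax>
struct ForwardEpilogueSketch {
  ForwardEpilogueSketch(AccumT max_input, AccumT sum)
      : max_input(max_input), sum(IsLogSoftmax ? std::log(sum) : sum) {}
  AccumT operator()(AccumT input) const {
    return IsLogSoftmax ? input - max_input - sum            // log-softmax: x - max - log(sum)
                        : std::exp(input - max_input) / sum; // softmax: exp(x - max) / sum
  }
  const AccumT max_input;
  const AccumT sum;
};

int main() {
  const double x[3] = {1.0, 2.0, 3.0};
  double max_x = 3.0, sum = 0.0;
  for (double v : x) sum += std::exp(v - max_x);
  ForwardEpilogueSketch<double, false> softmax(max_x, sum);
  ForwardEpilogueSketch<double, true> log_softmax(max_x, sum);
  for (double v : x) {
    // log(softmax(v)) and log_softmax(v) agree up to rounding error.
    std::printf("%.6f  %.6f\n", std::log(softmax(v)), log_softmax(v));
  }
  return 0;
}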