forked from mindspore-Ecosystem/mindspore
!47541 Performance optimization of the LogSoftmax operator on GPU devices: 50+ times faster
Merge pull request !47541 from happy徐/GPU_OPT
commit b562edf89e
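Note (editorial, not part of the commit): the diff below threads a compile-time is_log_softmax flag through the shared Softmax CUDA kernels so the write-out epilogue can produce log-softmax directly. A minimal host-side sketch of the numerically stable identity the new epilogue branch relies on, for reference only:

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative reference only: stable log-softmax over one row.
// log_softmax(x_i) = (x_i - m) - log(sum_j exp(x_j - m)), with m = max_j x_j,
// which matches the is_log_softmax branch of SoftMaxForwardEpilogue in the last hunk.
std::vector<float> LogSoftmaxReference(const std::vector<float> &x) {
  const float m = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - m);
  const float log_sum = std::log(sum);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = (x[i] - m) - log_sum;
  return y;
}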
@@ -36,43 +36,48 @@ constexpr int ALIGN_BYTES = 16;
 } \
 }
 
-inline dim3 SpatialSoftMax_getGridSize(dim3 *block, uint32_t max_active_blocks, uint64_t outer_size, uint64_t dim_size,
-                                       uint64_t inner_size) {
-  // First, tile as many blocks as we can over the y axis
-  uint32_t inner_blocks = (inner_size + block->y - 1) / block->y;
-  if (inner_blocks > max_active_blocks) inner_blocks = max_active_blocks;
-  // Fill the x axis with as many blocks as we can fit (a little more is ok too)
-  uint32_t outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
-  if (outer_blocks > outer_size) outer_blocks = outer_size;
-  return dim3(outer_blocks, inner_blocks);
+inline dim3 SpatialSoftMaxGetGridSize(dim3 *block, uint32_t activate_block, uint64_t outer_size, uint64_t dim_size,
+                                      uint64_t inner_size) {
+  uint32_t inner = (inner_size + block->y - 1) / block->y;
+  if (inner > activate_block) {
+    inner = activate_block;
+  }
+  uint32_t outer = (activate_block + inner - 1) / inner;
+  if (outer > outer_size) {
+    outer = outer_size;
+  }
+  return dim3(outer, inner);
 }
 
-inline dim3 SpatialSoftMax_getBlockSize(uint64_t outer_size, uint64_t dim_size, uint64_t inner_size) {
-  uint32_t inner_threads = inner_size;
-  inner_threads = std::min(inner_threads, static_cast<uint32_t>(max_threads));
+inline dim3 SpatialSoftMaxGetBlockSize(uint64_t outer_size, uint64_t dim_size, uint64_t inner_size) {
+  uint32_t inner_ths = inner_size;
+  inner_ths = std::min(inner_ths, static_cast<uint32_t>(max_threads));
   uint32_t dim_threads = 1;
-  if (inner_threads <= 64 && dim_size >= 64) {
-    while (inner_threads * dim_threads <= max_threads && dim_threads <= dim_size) dim_threads *= 2;
+  if (inner_ths <= 64 && dim_size >= 64) {
+    while ((inner_ths * dim_threads <= max_threads) && (dim_threads <= dim_size)) {
+      dim_threads *= 2;
+    }
     dim_threads /= 2;
   }
-  return dim3(dim_threads, inner_threads);
+  return dim3(dim_threads, inner_ths);
 }
 
 template <typename accumulate_t, typename Kernel>
-void SpatialSoftMax_getLaunchSizes(Kernel k, uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, dim3 *grid,
-                                   dim3 *block, uint32_t *smem_size, uint32_t device_id) {
-  *block = SpatialSoftMax_getBlockSize(outer_size, dim_size, inner_size);
-  uint32_t block_threads = block->x * block->y;
-  *smem_size = block->x == 1 ? 0 : block_threads * sizeof(accumulate_t);
-
-  int max_active_blocks;
-
-  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, k, block_threads, *smem_size);
+void SpatialSoftMaxGetLaunchSizes(Kernel k, uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, dim3 *grid,
+                                  dim3 *block, uint32_t *smem_size, uint32_t device_id) {
+  *block = SpatialSoftMaxGetBlockSize(outer_size, dim_size, inner_size);
+  uint32_t block_ths = block->x * block->y;
+  if (block->x == 1) {
+    *smem_size = 0;
+  } else {
+    *smem_size = block_ths * sizeof(accumulate_t);
+  }
+  int activate_size;
+  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&activate_size, k, block_ths, *smem_size);
   cudaDeviceProp prop;
   (void)cudaGetDeviceProperties(&prop, device_id);
-  max_active_blocks *= prop.multiProcessorCount;
-
-  *grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
+  activate_size *= prop.multiProcessorCount;
+  *grid = SpatialSoftMaxGetGridSize(block, activate_size, outer_size, dim_size, inner_size);
 }
 
 int log2_ceil(int val) {
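Aside (illustrative, not part of the diff): the launch-size helpers above cap the grid with cudaOccupancyMaxActiveBlocksPerMultiprocessor scaled by the SM count. A self-contained sketch of that query, using a hypothetical ToyKernel and MaxResidentBlocks helper:

#include <cuda_runtime.h>

__global__ void ToyKernel(float *out) { out[blockIdx.x * blockDim.x + threadIdx.x] = 0.0f; }

// Query how many blocks of `block_threads` threads (with `smem` bytes of dynamic
// shared memory) can be resident per SM, then scale by the SM count, as the hunk above does.
int MaxResidentBlocks(int block_threads, size_t smem, int device_id) {
  int per_sm = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&per_sm, ToyKernel, block_threads, smem);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  return per_sm * prop.multiProcessorCount;  // upper bound used to cap the grid
}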
@@ -135,63 +140,58 @@ __forceinline__ __device__ T SpatialBlockReduceX(T *memsha, T val) {
   return memsha[0];
 }
 
-template <typename input_t, typename accumulate_t, typename output_t>
+template <typename input_t, typename accumulate_t, typename output_t, bool is_log_softmax>
 __global__ void SpatialSoftMaxForward(output_t *output, input_t *input, uint32_t outer_size, uint32_t dim_size,
                                       uint32_t inner_size) {
   extern __shared__ unsigned char smem[];
   auto sdata = reinterpret_cast<accumulate_t *>(smem);
   const uint32_t outer_stride = inner_size * dim_size;
   const uint32_t dim_stride = inner_size;
-
   for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
     const uint32_t outer_offset = outer_index * outer_stride;
     for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size;
          inner_index += blockDim.y * gridDim.y) {
-      const uint32_t data_offset = outer_offset + inner_index;
-
+      const uint32_t offset = outer_offset + inner_index;
       if (blockDim.x > 1) {
-        accumulate_t max_input = std::numeric_limits<accumulate_t>::lowest();
+        accumulate_t max_data_input = std::numeric_limits<accumulate_t>::lowest();
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
-          const accumulate_t value = static_cast<accumulate_t>(input[data_offset + d * dim_stride]);
-          max_input = atomic::Max()(max_input, value);
+          const accumulate_t value = static_cast<accumulate_t>(input[offset + d * dim_stride]);
+          max_data_input = atomic::Max()(max_data_input, value);
         }
-        max_input = SpatialBlockReduceX<accumulate_t, atomic::Max>(sdata, max_input);
-
+        max_data_input = SpatialBlockReduceX<accumulate_t, atomic::Max>(sdata, max_data_input);
         accumulate_t sum = 0;
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          sum += std::exp(static_cast<accumulate_t>(input[data_offset + d * dim_stride]) - max_input);
+          sum += std::exp(static_cast<accumulate_t>(input[offset + d * dim_stride]) - max_data_input);
         sum = SpatialBlockReduceX<accumulate_t, atomic::Add>(sdata, sum);
-
-        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t> epilogue(max_input, sum);
+        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t, is_log_softmax> epilogue(max_data_input, sum);
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
+          output[offset + d * dim_stride] = epilogue(input[offset + d * dim_stride]);
       } else {
-        accumulate_t max_input = std::numeric_limits<accumulate_t>::lowest();
+        accumulate_t max_data_input = std::numeric_limits<accumulate_t>::lowest();
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
-          const accumulate_t value = static_cast<accumulate_t>(input[data_offset + d * dim_stride]);
-          max_input = atomic::Max()(max_input, value);
+          const accumulate_t value = static_cast<accumulate_t>(input[offset + d * dim_stride]);
+          max_data_input = atomic::Max()(max_data_input, value);
         }
         accumulate_t sum = 0;
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          sum += std::exp(static_cast<accumulate_t>(input[data_offset + d * dim_stride]) - max_input);
-        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t> epilogue(max_input, sum);
+          sum += std::exp(static_cast<accumulate_t>(input[offset + d * dim_stride]) - max_data_input);
+
+        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t, is_log_softmax> epilogue(max_data_input, sum);
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
+          output[offset + d * dim_stride] = epilogue(input[offset + d * dim_stride]);
       }
     }
   }
 }
 
-template <int InsP, typename input_t, typename accum_t, typename output_t>
+template <int InsP, typename input_t, typename accum_t, typename output_t, bool is_log_softmax>
 __device__ __forceinline__ void WriteResults(int clas, input_t *input, output_t *output, input_t max_k,
                                              input_t sum_all) {
-  SoftMaxForwardEpilogue<input_t, accum_t, output_t> epilogue(max_k, sum_all);
+  SoftMaxForwardEpilogue<input_t, accum_t, output_t, is_log_softmax> epilogue(max_k, sum_all);
   int offset = threadIdx.x;
-
   int last = clas % (InsP * blockDim.x);
   for (; offset < clas - last; offset += blockDim.x * InsP) {
     input_t tmp[InsP];
-
 #pragma unroll
     for (int j = 0; j < InsP; ++j) {
       tmp[j] = input[offset + j * blockDim.x];
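Aside (illustrative, not part of the diff): SpatialSoftMaxForward treats the input as [outer_size, dim_size, inner_size] with outer_stride = dim_size * inner_size and dim_stride = inner_size. A sequential CPU reference of that indexing, written here for the log-softmax case under those assumptions:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Illustrative CPU reference for the indexing used by SpatialSoftMaxForward:
// element (o, d, i) lives at o * dim_size * inner_size + d * inner_size + i.
void SpatialLogSoftmaxReference(const float *input, float *output, uint32_t outer_size, uint32_t dim_size,
                                uint32_t inner_size) {
  const uint32_t outer_stride = dim_size * inner_size;
  const uint32_t dim_stride = inner_size;
  for (uint32_t o = 0; o < outer_size; ++o) {
    for (uint32_t i = 0; i < inner_size; ++i) {
      const uint32_t base = o * outer_stride + i;
      float max_val = std::numeric_limits<float>::lowest();
      for (uint32_t d = 0; d < dim_size; ++d) max_val = std::max(max_val, input[base + d * dim_stride]);
      float sum = 0.0f;
      for (uint32_t d = 0; d < dim_size; ++d) sum += std::exp(input[base + d * dim_stride] - max_val);
      const float log_sum = std::log(sum);
      for (uint32_t d = 0; d < dim_size; ++d) {
        output[base + d * dim_stride] = input[base + d * dim_stride] - max_val - log_sum;
      }
    }
  }
}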
@@ -201,54 +201,42 @@ __device__ __forceinline__ void WriteResults(int clas, input_t *input, output_t
       output[offset + j * blockDim.x] = epilogue(tmp[j]);
     }
   }
-
   for (; offset < clas; offset += blockDim.x) {
     output[offset] = epilogue(input[offset]);
   }
 }
 
-template <int InsP, typename input_t, typename accum_t, typename output_t>
-__device__ __forceinline__ void WriteResultsVectorized(int size, const int shift, input_t *input, output_t *output,
+template <int InsP, typename input_t, typename accum_t, typename output_t, bool is_log_softmax>
+__device__ __forceinline__ void WriteResultsVectorized(int size, const int deviate, input_t *input, output_t *output,
                                                        input_t max_k, input_t sum_all) {
-  SoftMaxForwardEpilogue<input_t, accum_t, output_t> epilogue(max_k, sum_all);
-
-  using LoadT = aligned_vector<input_t>;
-  using StoreT = aligned_vector<output_t>;
-
+  SoftMaxForwardEpilogue<input_t, accum_t, output_t, is_log_softmax> epilogue(max_k, sum_all);
+  using loadT = aligned_vector<input_t>;
+  using storeT = aligned_vector<output_t>;
   int offset = threadIdx.x;
-
-  if (shift > 0) {
-    input -= shift;
-    output -= shift;
-    size += shift;
-
-    if (threadIdx.x >= shift) {
+  if (deviate > 0) {
+    input -= deviate;
+    output -= deviate;
+    size += deviate;
+    if (threadIdx.x >= deviate) {
       output[offset] = epilogue(input[offset]);
     }
     size -= blockDim.x;
     input += blockDim.x;
     output += blockDim.x;
   }
-
   const int last = size % (InsP * blockDim.x);
-
   input_t in_v[InsP];
-  LoadT *in_value = reinterpret_cast<LoadT *>(&in_v);
-
+  loadT *in_value = reinterpret_cast<loadT *>(&in_v);
   output_t out_v[InsP];
-  StoreT *out_value = reinterpret_cast<StoreT *>(&out_v);
-
+  storeT *out_value = reinterpret_cast<storeT *>(&out_v);
   for (; offset * InsP < (size - last); offset += blockDim.x) {
-    *in_value = reinterpret_cast<LoadT *>(input)[offset];
-
+    *in_value = reinterpret_cast<loadT *>(input)[offset];
 #pragma unroll
     for (int j = 0; j < InsP; ++j) {
       out_v[j] = epilogue(in_v[j]);
     }
-
-    reinterpret_cast<StoreT *>(output)[offset] = *out_value;
+    reinterpret_cast<storeT *>(output)[offset] = *out_value;
   }
-
   offset = size - last + threadIdx.x;
   for (; offset < size; offset += blockDim.x) {
     output[offset] = epilogue(input[offset]);
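Aside (illustrative, not part of the diff): the vectorized write path relies on the repository's aligned_vector type so that InsP = sizeof(float4) / sizeof(T) elements move per 16-byte transaction. A sketch of the same pattern with a hypothetical AlignedVec stand-in (the real type is defined elsewhere in the repo):

#include <cstdint>

// Hypothetical stand-in for the repository's aligned_vector<T>: a 16-byte-aligned
// pack of N elements so one load/store instruction moves N values at once.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVec {
  T val[N];
};

template <typename T, int N>
__device__ void CopyVectorized(const T *in, T *out, int size) {
  using Vec = AlignedVec<T, N>;
  const int tid = threadIdx.x;
  const int vec_count = size / N;  // assumes `in`/`out` are 16-byte aligned; the remainder is handled below
  for (int v = tid; v < vec_count; v += blockDim.x) {
    Vec pack = reinterpret_cast<const Vec *>(in)[v];
    reinterpret_cast<Vec *>(out)[v] = pack;
  }
  for (int i = vec_count * N + tid; i < size; i += blockDim.x) out[i] = in[i];  // scalar tail
}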
@@ -256,52 +244,43 @@ __device__ __forceinline__ void WriteResultsVectorized(int size, const int shift
 }
 
 template <typename Reduction, typename AccT>
-__device__ __forceinline__ AccT ReduceBlock(AccT *smem, AccT val, AccT defaultVal) {
+__device__ __forceinline__ AccT ReduceBlock(AccT *sharemen, AccT val, AccT initVal) {
   Reduction r = Reduction();
-
   __syncthreads();
-
-  smem[threadIdx.x] = val;
-
+  sharemen[threadIdx.x] = val;
   __syncthreads();
-
-  AccT warpVal = defaultVal;
-
+  AccT warpVal = initVal;
   uint32_t mask = (((uint64_t)1) << (blockDim.x / WARPSIZE)) - 1;
   if (threadIdx.x < WARPSIZE) {
     int lane = threadIdx.x % WARPSIZE;
     if (lane < blockDim.x / WARPSIZE) {
 #pragma unroll
       for (int i = 0; i < WARPSIZE; ++i) {
-        warpVal = r(warpVal, smem[lane * WARPSIZE + i]);
+        warpVal = r(warpVal, sharemen[lane * WARPSIZE + i]);
       }
 #ifndef __HIP_PLATFORM_HCC__
       __syncwarp(mask);
 #endif
-      smem[lane] = warpVal;
+      sharemen[lane] = warpVal;
     }
   }
-
   __syncthreads();
-  AccT blockVal = defaultVal;
-
+  AccT blockVal = initVal;
   if (threadIdx.x == 0) {
     for (int i = 0; i < blockDim.x / WARPSIZE; ++i) {
-      blockVal = r(blockVal, smem[i]);
+      blockVal = r(blockVal, sharemen[i]);
     }
-    smem[0] = blockVal;
+    sharemen[0] = blockVal;
   }
-
   __syncthreads();
-  return smem[0];
+  return sharemen[0];
 }
 
 template <template <typename, typename> class Reduction, int InsP, typename T, typename AccT>
-__device__ __forceinline__ AccT ILPReduce(int shift, T *data, int size, const Reduction<T, AccT> &r, AccT defaultVal) {
-  using LoadT = aligned_vector<T>;
-  AccT threadVal = defaultVal;
+__device__ __forceinline__ AccT ILPReduce(int shift, T *data, int size, const Reduction<T, AccT> &r, AccT initVal) {
+  using loadT = aligned_vector<T>;
+  AccT threadVal = initVal;
   int offset = threadIdx.x;
-
   if (shift > 0) {
     data -= shift;
     size += shift;
@@ -313,10 +292,9 @@ __device__ __forceinline__ AccT ILPReduce(int shift, T *data, int size, const Re
   }
   int last = size % (InsP * blockDim.x);
   T v[InsP];
-  LoadT *value = reinterpret_cast<LoadT *>(&v);
+  loadT *value = reinterpret_cast<loadT *>(&v);
   for (; offset * InsP < (size - last); offset += blockDim.x) {
-    *value = reinterpret_cast<LoadT *>(data)[offset];
-
+    *value = reinterpret_cast<loadT *>(data)[offset];
 #pragma unroll
     for (int j = 0; j < InsP; ++j) {
       threadVal = r(threadVal, v[j]);
@@ -324,7 +302,6 @@ __device__ __forceinline__ AccT ILPReduce(int shift, T *data, int size, const Re
   }
   offset = size - last + threadIdx.x;
   for (; offset < size; offset += blockDim.x) threadVal = r(threadVal, data[offset]);
-
   return threadVal;
 }
 
@@ -350,7 +327,6 @@ __global__ void SoftMaxWarpForward(output_t *dst, const input_t *src, int batch_
   constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
 
   int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
-
   int local_batches = batch_size - first_batch;
   if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
   int local_idx = threadIdx.x;
@@ -449,7 +425,6 @@ __global__ void SoftMaxWarpForward(output_t *dst, const input_t *src, int batch_
     }
   }
   warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, atomic::Add>(sum);
-
 #pragma unroll
   for (int i = 0; i < WARP_BATCH; ++i) {
     if (i >= local_batches) break;
@@ -472,17 +447,15 @@ __global__ void SoftMaxWarpForward(output_t *dst, const input_t *src, int batch_
   }
 }
 
-template <int InsP, typename T, typename accumulate_t>
+template <int InsP, typename T, typename accumulate_t, bool is_log_softmax>
 __global__ void cunn_SoftMaxForward(T *output, T *input, int classes) {
   extern __shared__ unsigned char smem[];
   auto sdata = reinterpret_cast<accumulate_t *>(smem);
-
-  using LoadT = aligned_vector<T>;
-  using StoreT = aligned_vector<T>;
-
+  using loadT = aligned_vector<T>;
+  using storeT = aligned_vector<T>;
   input += blockIdx.x * classes;
   output += blockIdx.x * classes;
 
   const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(T);
   const int output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(T);
 
@@ -496,14 +469,13 @@ __global__ void cunn_SoftMaxForward(T *output, T *input, int classes) {
   accumulate_t sumAll = ReduceBlock<atomic::Add, accumulate_t>(sdata, threadExp, static_cast<accumulate_t>(0));
 
   if (shift == output_shift) {
-    WriteResultsVectorized<InsP, T, accumulate_t, T>(classes, shift, input, output, max_k, sumAll);
+    WriteResultsVectorized<InsP, T, accumulate_t, T, is_log_softmax>(classes, shift, input, output, max_k, sumAll);
   } else {
-    WriteResults<InsP, T, accumulate_t, T>(classes, input, output, max_k, sumAll);
+    WriteResults<InsP, T, accumulate_t, T, is_log_softmax>(classes, input, output, max_k, sumAll);
   }
 }
-
 // end of kernel function
 
 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
 void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride,
                               int batch_count, cudaStream_t stream, const bool *mask = nullptr, int chunk_size = -1,
@@ -567,16 +539,16 @@ void Softmax(T *input_, T *output_, size_t dim_size_, size_t outer_size_, size_t
     } else {
      constexpr int InsP = sizeof(float4) / sizeof(T);
      dim3 block = SoftMaxGetBlockSize(InsP, dim_size_);
-      cunn_SoftMaxForward<InsP, T, accumulate_t>
+      cunn_SoftMaxForward<InsP, T, accumulate_t, is_log_softmax>
        <<<grid, block, block.x * sizeof(T), cuda_stream>>>(output_, input_, dim_size_);
      CUDA_CHECK();
    }
  } else {
    uint32_t smem_size;
    dim3 grid, block;
-    SpatialSoftMax_getLaunchSizes<T>(&SpatialSoftMaxForward<T, accumulate_t, T>, outer_size_, dim_size_, inner_size_,
-                                     &grid, &block, &smem_size, device_id);
-    SpatialSoftMaxForward<T, accumulate_t, T>
+    SpatialSoftMaxGetLaunchSizes<T>(&SpatialSoftMaxForward<T, accumulate_t, T, is_log_softmax>, outer_size_, dim_size_,
+                                    inner_size_, &grid, &block, &smem_size, device_id);
+    SpatialSoftMaxForward<T, accumulate_t, T, is_log_softmax>
      <<<grid, block, smem_size, cuda_stream>>>(output_, input_, outer_size_, dim_size_, inner_size_);
    CUDA_CHECK();
  }
@@ -37,14 +37,14 @@ struct AccumulateType<double, true> {
 template <typename T, bool is_cuda>
 using acc_type = typename AccumulateType<T, is_cuda>::type;
 
-template <typename T, typename AccumT, typename OutT>
+template <typename T, typename AccumT, typename OutT, bool is_log_softmax>
 struct SoftMaxForwardEpilogue {
-  __device__ __forceinline__ SoftMaxForwardEpilogue(AccumT max_input, AccumT sum) : max_input(max_input), sum(sum) {}
-
+  __device__ __forceinline__ SoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
+      : max_input(max_input), sum(is_log_softmax == true ? std::log(sum) : sum) {}
   __device__ __forceinline__ OutT operator()(T input) const {
-    return static_cast<OutT>(std::exp((AccumT)input - max_input) / sum);
+    return is_log_softmax == true ? static_cast<OutT>((AccumT)input - max_input - sum)
+                                  : static_cast<OutT>(std::exp((AccumT)input - max_input) / sum);
   }
-
   const AccumT max_input;
   const AccumT sum;
 };
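Aside (illustrative, not part of the diff): with the constructor storing log(sum) when is_log_softmax is true, the two operator() branches satisfy log_softmax(x) = log(softmax(x)). A host-side analogue of that consistency check, using arbitrary sample values:

#include <cassert>
#include <cmath>

// Host-side analogue of the two SoftMaxForwardEpilogue branches for one element, illustrative only.
float SoftmaxElem(float x, float max_input, float sum) { return std::exp(x - max_input) / sum; }
float LogSoftmaxElem(float x, float max_input, float log_sum) { return x - max_input - log_sum; }

int main() {
  const float x = 1.5f, max_input = 2.0f, sum = 3.7f;  // arbitrary sample values
  // The is_log_softmax branch receives log(sum), so the two results agree up to rounding.
  assert(std::fabs(LogSoftmaxElem(x, max_input, std::log(sum)) - std::log(SoftmaxElem(x, max_input, sum))) < 1e-6f);
  return 0;
}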