@@ -16,39 +16,17 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_minmax_impl.cuh"
#include <cub/cub.cuh>
#include <thrust/functional.h>
#include <thrust/tuple.h>
#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>
#include <algorithm>
#include <limits>
#include "include/cuda_fp16.h"
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"

namespace {
using uint = unsigned int;
}

int GetMaxSharedMemoryPerBlock(const uint32_t &device_id) {
  int max_size = 128;
  (void)cudaDeviceGetAttribute(&max_size, cudaDevAttrMaxSharedMemoryPerBlock, static_cast<int>(device_id));
  return max_size;
}

int GetMaxThreadsPerBlock(const uint32_t &device_id) {
  int max_size = 128;
  (void)cudaDeviceGetAttribute(&max_size, cudaDevAttrMaxThreadsPerBlock, static_cast<int>(device_id));
  return max_size;
}

int GetMaxGridDimX(const uint32_t &device_id) {
  int max_size = 128;
  (void)cudaDeviceGetAttribute(&max_size, cudaDevAttrMaxGridDimX, static_cast<int>(device_id));
  return max_size;
}
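// Note: if the cudaDeviceGetAttribute query fails, max_size keeps the conservative default set above; the error
// code is intentionally discarded via the (void) cast.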

template <typename DataType>
__device__ __forceinline__ bool IsNan(const DataType &x) {
  return isnan(x);
}

__device__ __forceinline__ bool IsNan(const half &x) { return __hisnan(x); }

template <typename DataType>
DataType NumericMax() {
@@ -72,210 +50,250 @@ half NumericMin<half>() {
  return half(__half_raw{x});
}

template <typename BinaryFunctor, typename DataType, typename IndexType>
__device__ __forceinline__ void Update(BinaryFunctor fun, DataType *dst_data, IndexType *dst_index, DataType src_data,
                                       IndexType src_index) {
  if (fun(src_data, *dst_data)) {
    *dst_data = src_data;
    *dst_index = src_index;
int GetMaxGridDimY(const uint32_t &device_id) {
  int max_size = 1 << 16;
  (void)cudaDeviceGetAttribute(&max_size, cudaDevAttrMaxGridDimY, static_cast<int>(device_id));
  return max_size;
}
} // namespace

template <typename DataType>
__device__ __forceinline__ bool IsNan(const DataType &x) {
  return isnan(x);
}

__device__ __forceinline__ bool IsNan(const half &x) { return __hisnan(x); }

template <typename BinaryOp, typename DataType>
struct BinaryFunctor {
  BinaryOp op_;
  __device__ __forceinline__ bool operator()(DataType lhs, DataType rhs) {
    return (IsNan(lhs) || !op_(rhs, lhs)) && !IsNan(rhs);
  }
};
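// Semantics sketch of the predicate above (assuming BinaryOp is e.g. thrust::less<DataType> for cummin or
// thrust::greater<DataType> for cummax): functor(lhs, rhs) answers "should lhs win over rhs?". A non-NaN lhs wins
// whenever rhs is not strictly better, so the earlier index is kept on ties, and a NaN lhs wins against every
// non-NaN rhs, so NaN values propagate along the scan axis once encountered.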

template <typename BinaryFunctor, typename TupleType>
struct BlockScanFunctor {
  BinaryFunctor functor_;
  explicit BlockScanFunctor(BinaryFunctor functor) : functor_(functor) {}
  __device__ __forceinline__ TupleType operator()(TupleType lhs, TupleType rhs) {
    return functor_(thrust::get<0>(lhs), thrust::get<0>(rhs)) ? lhs : rhs;
  }
};

// Inspired by cub documentation.
template <typename BlockScanFunctor, typename TupleType>
struct BlockPrefixCallbackFunctor {
  BlockScanFunctor functor_;
  TupleType block_aggregate_;
  // Constructor
  __device__ BlockPrefixCallbackFunctor(BlockScanFunctor functor, TupleType block_aggregate)
      : functor_(functor), block_aggregate_(block_aggregate) {}
  // Callback operator to be entered by the first warp of threads in the block.
  // Thread-0 is responsible for returning a value for seeding the block-wide scan.
  __device__ __forceinline__ TupleType operator()(TupleType block_aggregate) {
    TupleType old_block_aggregate = block_aggregate_;
    block_aggregate_ = functor_(old_block_aggregate, block_aggregate);
    return old_block_aggregate;
  }
};
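// How the callback is used below: cub invokes it once per InclusiveScan call (i.e. once per tile) with that tile's
// block-wide aggregate; it returns the aggregate of all previous tiles as the running prefix and folds the new
// aggregate into block_aggregate_, so consecutive BlockDim-sized tiles chain into one long scan.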

template <typename BlockScanFunctor, typename ValueType, typename IndexType, uint BlockDim>
__global__ void LargeBlockScanKernel(BlockScanFunctor functor, const ValueType *input_ptr, ValueType *value_ptr,
                                     IndexType *index_ptr, uint axis_size, uint inner_size, uint axis_inner_size,
                                     uint outer_inner_size, ValueType init) {
  typedef thrust::tuple<ValueType, IndexType> DataType;
  typedef cub::BlockScan<DataType, BlockDim> BlockScan;
  __shared__ typename BlockScan::TempStorage share_data;
  for (uint bid = blockIdx.x; bid < outer_inner_size; bid += gridDim.x) {
    uint outer_idx = bid / inner_size;
    uint inner_idx = bid % inner_size;
    DataType init_data{init, 0};
    BlockPrefixCallbackFunctor<BlockScanFunctor, DataType> cb_functor{functor, init_data};
    uint axis_idx = threadIdx.x;
    uint axis_offset = outer_idx * axis_inner_size + inner_idx + axis_idx * inner_size;
    for (uint block_offset = 0; block_offset < axis_size; block_offset += BlockDim) {
      DataType thread_data = init_data;
      if (axis_idx < axis_size) {
        thread_data = thrust::make_tuple(input_ptr[axis_offset], axis_idx);
      }
      BlockScan(share_data).template InclusiveScan(thread_data, thread_data, functor, cb_functor);
      __syncthreads();
      if (axis_idx < axis_size) {
        thrust::tie(value_ptr[axis_offset], index_ptr[axis_offset]) = thread_data;
      }
      axis_idx += BlockDim;
      axis_offset += BlockDim * inner_size;
    }
  }
}
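// Note on the tile loop above: the shared TempStorage is reused by every iteration, so the __syncthreads() after
// InclusiveScan is required before starting the next tile, while cb_functor carries the running (value, index)
// aggregate across tiles.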

template <typename BinaryFunctor, typename DataType, typename IndexType>
__global__ void CumMinMaxKernel(BinaryFunctor fun, const DataType *input_ptr, DataType *value_ptr, IndexType *index_ptr,
                                uint axis_size, uint inner_size, uint axis_inner_size, uint outer_inner_size,
                                DataType init) {
  uint tid = threadIdx.y;
  uint tid_d = tid << 1;  // The suffix `d` represents double.
  uint scan_per_block = blockDim.y * 2;
  extern __shared__ char share_data[];
  auto total_value_size = sizeof(DataType) * blockDim.x * scan_per_block;
  auto share_value_ptr = reinterpret_cast<DataType *>(share_data) + threadIdx.x * scan_per_block;
  auto share_index_ptr = reinterpret_cast<IndexType *>(share_data + total_value_size) + threadIdx.x * scan_per_block;
  for (uint bid = threadIdx.x + blockIdx.x * blockDim.x; bid < outer_inner_size; bid += blockDim.x * gridDim.x) {
    uint outer_idx = bid / inner_size;
    uint inner_idx = bid % inner_size;
    uint outer_inner_offset = outer_idx * axis_inner_size + inner_idx;
    auto cur_input_ptr = input_ptr + outer_inner_offset;
    auto cur_value_ptr = value_ptr + outer_inner_offset;
    auto cur_index_ptr = index_ptr + outer_inner_offset;
    DataType block_value = init;
    IndexType block_index = 0;
    // Each iteration processes (2 * blockDim.y) elements, since shared memory is typically larger than the number
    // of threads per block.
    for (uint cid = 0; cid < axis_size; cid += scan_per_block) {
      // The following parallel scan algorithm refers to:
      // Figure 9.7 from David B. Kirk, et al. 'Programming Massively Parallel Processors'.
      uint axis_idx = cid + tid_d;
      uint axis_offset = axis_idx * inner_size;
      // Initialize shared memory with input values.
      if (axis_idx < axis_size) {
        share_value_ptr[tid_d] = cur_input_ptr[axis_offset];
        share_index_ptr[tid_d] = axis_idx;
      } else {
        share_value_ptr[tid_d] = init;
      }
      if (axis_idx + 1 < axis_size) {
        share_value_ptr[tid_d + 1] = cur_input_ptr[axis_offset + inner_size];
        share_index_ptr[tid_d + 1] = axis_idx + 1;
      } else {
        share_value_ptr[tid_d + 1] = init;
      }
      // Update with the previous block result.
      if (tid == 0) {
        Update(fun, share_value_ptr, share_index_ptr, block_value, block_index);
template <typename BlockScanFunctor, typename ValueType, typename IndexType, uint BlockDimX, uint BlockDimY>
__global__ void ScanInnerMostDimKernel(BlockScanFunctor functor, const ValueType *input_ptr, ValueType *value_ptr,
                                       IndexType *index_ptr, uint outer_size, uint axis_size, ValueType init) {
  typedef thrust::tuple<ValueType, IndexType> DataType;
  constexpr uint scan_per_block = BlockDimX * 2;
  __shared__ ValueType share_value[BlockDimY][scan_per_block];
  __shared__ IndexType share_index[BlockDimY][scan_per_block];
  auto share_value_ptr = share_value[threadIdx.y];
  auto share_index_ptr = share_index[threadIdx.y];
  for (uint bid = blockIdx.x * blockDim.y; bid < outer_size; bid += gridDim.x * blockDim.y) {
    uint outer_idx = bid + threadIdx.y;
    bool is_valid = outer_idx < outer_size;
    uint offset = outer_idx * axis_size;
    DataType block_data{init, 0};
    // The following parallel scan algorithm refers to:
    // Figure 9.7 from David B. Kirk, et al. 'Programming Massively Parallel Processors'.
    for (uint i = 0; i < axis_size; i += scan_per_block) {
      // Initialize shared memory with input values; each thread processes two elements.
      uint idx1 = threadIdx.x + i;
      uint idx2 = idx1 + BlockDimX;
      if (is_valid) {
        if (idx1 < axis_size) {
          share_value_ptr[threadIdx.x] = input_ptr[offset + idx1];
          share_index_ptr[threadIdx.x] = idx1;
        } else {
          share_value_ptr[threadIdx.x] = init;
        }
        if (idx2 < axis_size) {
          share_value_ptr[threadIdx.x + BlockDimX] = input_ptr[offset + idx2];
          share_index_ptr[threadIdx.x + BlockDimX] = idx2;
        } else {
          share_value_ptr[threadIdx.x + BlockDimX] = init;
        }
        // Update with the previous block result.
        if (threadIdx.x == 0) {
          thrust::tie(share_value_ptr[0], share_index_ptr[0]) =
            functor(thrust::make_tuple(share_value_ptr[0], share_index_ptr[0]), block_data);
        }
      }
      // up-sweep
      for (uint stride = 1; stride < scan_per_block; stride <<= 1) {
        __syncthreads();
        uint index = (tid + 1) * (stride << 1) - 1;
        if (index < scan_per_block) {
          Update(fun, share_value_ptr + index, share_index_ptr + index, share_value_ptr[index - stride],
                 share_index_ptr[index - stride]);
        uint index = (threadIdx.x + 1) * (stride << 1) - 1;
        if (is_valid && index < scan_per_block) {
          thrust::tie(share_value_ptr[index], share_index_ptr[index]) =
            functor(thrust::make_tuple(share_value_ptr[index - stride], share_index_ptr[index - stride]),
                    thrust::make_tuple(share_value_ptr[index], share_index_ptr[index]));
        }
      }
      // down-sweep
      for (uint stride = scan_per_block >> 2; stride > 0; stride >>= 1) {
        __syncthreads();
        uint index = (tid + 1) * (stride << 1) - 1;
        if (index + stride < scan_per_block) {
          Update(fun, share_value_ptr + (index + stride), share_index_ptr + (index + stride), share_value_ptr[index],
                 share_index_ptr[index]);
        uint index = (threadIdx.x + 1) * (stride << 1) - 1;
        if (is_valid && index + stride < scan_per_block) {
          thrust::tie(share_value_ptr[index + stride], share_index_ptr[index + stride]) =
            functor(thrust::make_tuple(share_value_ptr[index], share_index_ptr[index]),
                    thrust::make_tuple(share_value_ptr[index + stride], share_index_ptr[index + stride]));
        }
      }
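      // Sketch of the two phases above (Brent-Kung style scan): the up-sweep combines pairs at growing power-of-two
      // strides so that positions of the form k * 2 * stride - 1 accumulate the elements before them; the down-sweep
      // then pushes those partial results back at shrinking strides, after which element i holds the inclusive scan
      // of elements [0, i] of the current tile.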
      // write to output.
      __syncthreads();
      if (axis_idx < axis_size) {
        cur_value_ptr[axis_offset] = share_value_ptr[tid_d];
        cur_index_ptr[axis_offset] = share_index_ptr[tid_d];
      if (is_valid) {
        if (idx1 < axis_size) {
          value_ptr[offset + idx1] = share_value_ptr[threadIdx.x];
          index_ptr[offset + idx1] = share_index_ptr[threadIdx.x];
        }
        if (idx2 < axis_size) {
          value_ptr[offset + idx2] = share_value_ptr[threadIdx.x + BlockDimX];
          index_ptr[offset + idx2] = share_index_ptr[threadIdx.x + BlockDimX];
        }
        // update block_data
        block_data = thrust::make_tuple(share_value_ptr[scan_per_block - 1], share_index_ptr[scan_per_block - 1]);
      }
      if (axis_idx + 1 < axis_size) {
        cur_value_ptr[axis_offset + inner_size] = share_value_ptr[tid_d + 1];
        cur_index_ptr[axis_offset + inner_size] = share_index_ptr[tid_d + 1];
      }
      // update block_value & block_index
      if (tid == 0) {
        block_value = share_value_ptr[scan_per_block - 1];
        block_index = share_index_ptr[scan_per_block - 1];
      }
      __syncthreads();
    }
  }
}

template <typename BinaryFunctor, typename DataType, typename IndexType>
struct IndexFunctor {
  const DataType *input_ptr_;
  BinaryFunctor functor_;
  explicit IndexFunctor(const DataType *input_ptr, BinaryFunctor functor) : input_ptr_(input_ptr), functor_(functor) {}
  __device__ __forceinline__ IndexType operator()(IndexType x, IndexType y) {
    auto lhs = input_ptr_[x];
    auto rhs = input_ptr_[y];
    return functor_(lhs, rhs) ? x : y;
  }
};

template <typename BinaryFunctor, typename DataType>
struct ValueFunctor {
  BinaryFunctor functor_;
  explicit ValueFunctor(BinaryFunctor functor) : functor_(functor) {}
  __device__ __forceinline__ DataType operator()(DataType lhs, DataType rhs) { return functor_(lhs, rhs) ? lhs : rhs; }
};

template <typename BinaryOp, typename DataType>
struct BinaryFunctor {
  BinaryOp binary_op_;
  __device__ __forceinline__ bool operator()(DataType lhs, DataType rhs) {
    return !IsNan(rhs) && (IsNan(lhs) || !binary_op_(rhs, lhs));
  }
};

template <typename BinaryFunctor, typename DataType, typename IndexType>
__global__ void CumMinMaxSlowKernel(BinaryFunctor functor, const DataType *input_ptr, DataType *value_ptr,
                                    IndexType *index_ptr, uint axis_size, uint inner_size, uint axis_inner_size,
                                    uint outer_inner_size) {
  for (uint tid = blockIdx.x * blockDim.x + threadIdx.x; tid < outer_inner_size; tid += blockDim.x * gridDim.x) {
    uint outer_idx = tid / inner_size;
    uint inner_idx = tid % inner_size;
template <typename BlockScanFunctor, typename ValueType, typename IndexType>
__global__ void ScanOuterDimKernel(BlockScanFunctor functor, const ValueType *input_ptr, ValueType *value_ptr,
                                   IndexType *index_ptr, uint axis_size, uint inner_size, uint axis_inner_size,
                                   uint outer_inner_size, ValueType init) {
  typedef thrust::tuple<ValueType, IndexType> DataType;
  for (uint bid = blockIdx.x * blockDim.x + threadIdx.x; bid < outer_inner_size; bid += gridDim.x * blockDim.x) {
    uint outer_idx = bid / inner_size;
    uint inner_idx = bid % inner_size;
    DataType out{init, 0};
    uint offset = outer_idx * axis_inner_size + inner_idx;
    auto cur_input_ptr = input_ptr + offset;
    auto cur_value_ptr = value_ptr + offset;
    auto cur_index_ptr = index_ptr + offset;
    DataType out_val = *cur_value_ptr = *cur_input_ptr;
    IndexType out_idx = *cur_index_ptr = 0;
    for (uint j = 1; j < axis_size; j++) {
      cur_input_ptr += inner_size;
      cur_value_ptr += inner_size;
      cur_index_ptr += inner_size;
      DataType cur_val = *cur_input_ptr;
      if (!functor(out_val, cur_val)) {
        out_val = cur_val;
        out_idx = static_cast<IndexType>(j);
      }
      *cur_value_ptr = out_val;
      *cur_index_ptr = out_idx;
    for (uint i = 0; i < axis_size; i++) {
      DataType thread_data = thrust::make_tuple(input_ptr[offset], i);
      out = functor(out, thread_data);
      thrust::tie(value_ptr[offset], index_ptr[offset]) = out;
      offset += inner_size;
    }
  }
}
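// Note on ScanOuterDimKernel: each thread owns one (outer, inner) column and walks its axis_size elements
// sequentially, so parallelism comes only from the outer * inner batch dimension; it is the fallback used when
// inner_size > 1 and none of the specialized kernels above applies.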

template <typename BinaryFunctor, typename DataType, typename IndexType>
void KernelHelper(BinaryFunctor fun, DataType init, const DataType *input_ptr, DataType *value_ptr,
template <typename BinaryFunctor, typename ValueType, typename IndexType>
void KernelHelper(BinaryFunctor functor, ValueType init, const ValueType *input_ptr, ValueType *value_ptr,
                  IndexType *index_ptr, size_t outer_size_st, size_t axis_size_st, size_t inner_size_st,
                  const uint32_t &device_id, cudaStream_t cuda_stream) {
  if (outer_size_st == 1 && inner_size_st == 1) {
    // Special case where only one dimension needs to be computed, so using the cub library is the most efficient way.
    ValueFunctor<BinaryFunctor, DataType> value_fun{fun};
    IndexFunctor<BinaryFunctor, DataType, IndexType> index_fun{input_ptr, fun};
    size_t value_storage_bytes = 0;
    size_t index_storage_bytes = 0;
  auto outer_size = static_cast<uint>(outer_size_st);
  auto inner_size = static_cast<uint>(inner_size_st);
  auto axis_size = static_cast<uint>(axis_size_st);
  auto outer_inner_size = outer_size * inner_size;
  auto axis_inner_size = axis_size * inner_size;
  uint max_grid_size = GetMaxGridDimY(device_id);
  typedef BlockScanFunctor<BinaryFunctor, thrust::tuple<ValueType, IndexType>> BlockScanFunctor;
  BlockScanFunctor scan_op{functor};
#if defined(CUB_VERSION) && (CUB_VERSION > 100800)
  // Special case where only one dimension needs to be computed, so using the cub library is the most efficient way.
  if (outer_size == 1 && inner_size == 1) {
    // Using thrust::zip_iterator to make an iterator for (ValueType, IndexType).
    cub::CountingInputIterator<IndexType> count_iter(0);
    (void)cub::DeviceScan::InclusiveScan(nullptr, value_storage_bytes, input_ptr, value_ptr, value_fun, axis_size_st,
                                         cuda_stream);
    (void)cub::DeviceScan::InclusiveScan(nullptr, index_storage_bytes, count_iter, index_ptr, index_fun, axis_size_st,
                                         cuda_stream);
    // Allocate the temporary storage only once.
    char *temp_storage_ptr = nullptr;
    (void)cudaMalloc(&temp_storage_ptr, value_storage_bytes + index_storage_bytes);
    void *value_storage_ptr = reinterpret_cast<void *>(temp_storage_ptr);
    void *index_storage_ptr = reinterpret_cast<void *>(temp_storage_ptr + value_storage_bytes);
    typedef typename thrust::detail::normal_iterator<const ValueType *> InputValueIterator;
    typedef cub::CountingInputIterator<IndexType> InputIndexIterator;
    typedef thrust::zip_iterator<thrust::tuple<InputValueIterator, InputIndexIterator>> InputZipIterator;
    InputZipIterator input_iter(thrust::make_tuple(input_ptr, count_iter));

    (void)cub::DeviceScan::InclusiveScan(value_storage_ptr, value_storage_bytes, input_ptr, value_ptr, value_fun,
                                         axis_size_st, cuda_stream);
    (void)cub::DeviceScan::InclusiveScan(index_storage_ptr, index_storage_bytes, count_iter, index_ptr, index_fun,
                                         axis_size_st, cuda_stream);
    typedef typename thrust::detail::normal_iterator<ValueType *> OutputValueIterator;
    typedef typename thrust::detail::normal_iterator<IndexType *> OutputIndexIterator;
    typedef thrust::zip_iterator<thrust::tuple<OutputValueIterator, OutputIndexIterator>> OutputZipIterator;
    OutputZipIterator output_iter(thrust::make_tuple(value_ptr, index_ptr));

    // Calculate the size of temporary storage.
    size_t temp_storage_bytes = 0;
    (void)cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes, input_iter, output_iter, scan_op, axis_size,
                                         cuda_stream);
    // Allocate temporary storage.
    char *temp_storage_ptr = nullptr;
    (void)cudaMalloc(&temp_storage_ptr, temp_storage_bytes);
    // Core computation process.
    (void)cub::DeviceScan::InclusiveScan(temp_storage_ptr, temp_storage_bytes, input_iter, output_iter, scan_op,
                                         axis_size, cuda_stream);
    (void)cudaFree(temp_storage_ptr);
    return;
  }
  // When the CUDA compute capability is below the recommended value (< 7), we instead use the self-implemented scan
  // algorithm. Otherwise, we use cub::BlockScan, which is faster than the self-implemented one.
  const int major_sm = GET_MAJOR_SM;
  const bool check_sm = mindspore::device::gpu::CudaCommon::GetInstance().check_sm();
  constexpr uint threshold_large_scan_dim = 500;
  if (!(check_sm && major_sm < RECOMMEND_SM) && axis_size > threshold_large_scan_dim) {
    constexpr uint block_dim = 512;
    uint grid_x = std::min(outer_inner_size, max_grid_size);
    dim3 block{block_dim};
    dim3 grid{grid_x};
    LargeBlockScanKernel<BlockScanFunctor, ValueType, IndexType, block_dim><<<grid, block, 0, cuda_stream>>>(
      scan_op, input_ptr, value_ptr, index_ptr, axis_size, inner_size, axis_inner_size, outer_inner_size, init);
    return;
  }
#endif
  if (inner_size == 1) {
    constexpr uint block_dim_x = 32;
    constexpr uint block_dim_y = 16;
    // The reason why the x-dimension of the block is set to 32:
    // Each thread processes 2 elements, so each x-row of the block processes 64 elements. An obvious advantage is no
    // bank conflict. In addition, we don't need `__syncthreads`, since 32 equals the warp size.
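    // For example: with block_dim_x = 32 and block_dim_y = 16 a block has 512 threads, each x-row scans a tile of
    // 2 * 32 = 64 elements along the axis, and 16 rows of the outer batch are processed in parallel per block.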
    uint grid_x = std::min(UP_DIV(outer_size, block_dim_y), max_grid_size);
    dim3 block = {block_dim_x, block_dim_y};
    dim3 grid = {grid_x};
    ScanInnerMostDimKernel<BlockScanFunctor, ValueType, IndexType, block_dim_x, block_dim_y>
      <<<grid, block, 0, cuda_stream>>>(scan_op, input_ptr, value_ptr, index_ptr, outer_size, axis_size, init);
  } else {
    auto outer_size = static_cast<uint>(outer_size_st);
    auto inner_size = static_cast<uint>(inner_size_st);
    auto axis_size = static_cast<uint>(axis_size_st);
    auto outer_inner_size = outer_size * inner_size;
    auto axis_inner_size = axis_size * inner_size;
    if (inner_size_st == 1) {
      // The partitioning strategy is as follows:
      // 1. The block has two dimensions: the y dimension (capped at 128) scans an array of length axis_size, while
      //    the x dimension processes the batch dimension in parallel; its size depends on the maximum shared memory
      //    size and the maximum number of threads per block.
      // 2. The grid has only one dimension, which takes over the remaining batch dimension.
      constexpr uint max_block_y = 128;
      uint max_share_size = GetMaxSharedMemoryPerBlock(device_id);
      uint max_thread_size = GetMaxThreadsPerBlock(device_id);
      uint max_grid_size = GetMaxGridDimX(device_id);
      uint axis_power2 = 1u << Log2Ceil(axis_size);
      uint block_y = std::min(max_block_y, axis_power2);
      uint has_allocate = block_y * 2 * (sizeof(DataType) + sizeof(IndexType));
      uint block_x = std::min(max_thread_size / block_y, max_share_size / has_allocate);
      uint grid_x = std::min(max_grid_size, UP_DIV(outer_inner_size, block_x));
      dim3 block = {block_x, block_y};
      dim3 grid = {grid_x};
      uint share_size = block_x * has_allocate;
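      // Worked example under hypothetical limits (4-byte values and indices, 48 KB shared memory, 1024 threads per
      // block, axis_size >= 128): block_y = 128, has_allocate = 128 * 2 * 8 = 2048 bytes per batch column,
      // block_x = min(1024 / 128, 49152 / 2048) = 8 columns per block, and share_size = 8 * 2048 = 16384 bytes.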
      CumMinMaxKernel<BinaryFunctor, DataType, IndexType><<<grid, block, share_size, cuda_stream>>>(
        fun, input_ptr, value_ptr, index_ptr, axis_size, inner_size, axis_inner_size, outer_inner_size, init);
    } else {
      // A useless case. If you don't like this branch, please delete it.
      CumMinMaxSlowKernel<BinaryFunctor, DataType, IndexType>
        <<<CUDA_BLOCKS(device_id, outer_inner_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
          fun, input_ptr, value_ptr, index_ptr, axis_size, inner_size, axis_inner_size, outer_inner_size);
    }
    constexpr uint block_dim = 512;
    uint grid_x = std::min(UP_DIV(outer_inner_size, block_dim), max_grid_size);
    dim3 block{block_dim};
    dim3 grid{grid_x};
    ScanOuterDimKernel<<<grid, block, 0, cuda_stream>>>(scan_op, input_ptr, value_ptr, index_ptr, axis_size, inner_size,
                                                        axis_inner_size, outer_inner_size, init);
  }
}