!39281 Add CumSum dynamic shape for Ascend platform.

Merge pull request !39281 from hezhenhao1/opt_cumop
i-robot 2022-08-05 00:11:23 +00:00 committed by Gitee
commit 37e6676f1f
7 changed files with 299 additions and 208 deletions

View File

@@ -15,6 +15,7 @@
*/
#include "plugin/device/gpu/kernel/arrays/index_fill_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
namespace kernel {
@@ -23,6 +24,9 @@ constexpr size_t kIndexFillInputsNum = 4;
constexpr size_t kIndexFillOutputsNum = 1;
} // namespace
template <typename T>
using Complex = mindspore::utils::Complex<T>;
bool IndexFillGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kIndexFillInputsNum, kernel_name_);
@@ -232,6 +236,20 @@ std::vector<std::pair<KernelAttr, IndexFillGpuKernelMod::IndexFillLaunchFunc>> I
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&IndexFillGpuKernelMod::LaunchKernel<double, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&IndexFillGpuKernelMod::LaunchKernel<Complex<float>, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
&IndexFillGpuKernelMod::LaunchKernel<Complex<double>, int>},
};
std::vector<KernelAttr> IndexFillGpuKernelMod::GetOpSupport() {

View File

@@ -16,39 +16,17 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_minmax_impl.cuh"
#include <cub/cub.cuh>
#include <thrust/functional.h>
#include <thrust/tuple.h>
#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>
#include <algorithm>
#include <limits>
#include "include/cuda_fp16.h"
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
namespace {
using uint = unsigned int;
}
int GetMaxSharedMemoryPerBlock(const uint32_t &device_id) {
int max_size = 128;
(void)cudaDeviceGetAttribute(&max_size, cudaDevAttrMaxSharedMemoryPerBlock, static_cast<int>(device_id));
return max_size;
}
int GetMaxThreadsPerBlock(const uint32_t &device_id) {
int max_size = 128;
(void)cudaDeviceGetAttribute(&max_size, cudaDevAttrMaxThreadsPerBlock, static_cast<int>(device_id));
return max_size;
}
int GetMaxGridDimX(const uint32_t &device_id) {
int max_size = 128;
(void)cudaDeviceGetAttribute(&max_size, cudaDevAttrMaxGridDimX, static_cast<int>(device_id));
return max_size;
}
template <typename DataType>
__device__ __forceinline__ bool IsNan(const DataType &x) {
return isnan(x);
}
__device__ __forceinline__ bool IsNan(const half &x) { return __hisnan(x); }
template <typename DataType>
DataType NumericMax() {
@@ -72,210 +50,250 @@ half NumericMin<half>() {
return half(__half_raw{x});
}
template <typename BinaryFunctor, typename DataType, typename IndexType>
__device__ __forceinline__ void Update(BinaryFunctor fun, DataType *dst_data, IndexType *dst_index, DataType src_data,
IndexType src_index) {
if (fun(src_data, *dst_data)) {
*dst_data = src_data;
*dst_index = src_index;
int GetMaxGridDimY(const uint32_t &device_id) {
int max_size = 1 << 16;
(void)cudaDeviceGetAttribute(&max_size, cudaDevAttrMaxGridDimY, static_cast<int>(device_id));
return max_size;
}
} // namespace
template <typename DataType>
__device__ __forceinline__ bool IsNan(const DataType &x) {
return isnan(x);
}
__device__ __forceinline__ bool IsNan(const half &x) { return __hisnan(x); }
template <typename BinaryOp, typename DataType>
struct BinaryFunctor {
BinaryOp op_;
__device__ __forceinline__ bool operator()(DataType lhs, DataType rhs) {
return (IsNan(lhs) || !op_(rhs, lhs)) && !IsNan(rhs);
}
};
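// Note (added for clarity, not part of the original patch): operator()(lhs, rhs) returns true when lhs should be
// kept as the scan result in preference to rhs. A NaN operand always ends up being kept, so once a NaN appears it
// propagates through the rest of the cummin/cummax scan; for non-NaN operands the decision is delegated to BinaryOp.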
template <typename BinaryFunctor, typename TupleType>
struct BlockScanFunctor {
BinaryFunctor functor_;
explicit BlockScanFunctor(BinaryFunctor functor) : functor_(functor) {}
__device__ __forceinline__ TupleType operator()(TupleType lhs, TupleType rhs) {
return functor_(thrust::get<0>(lhs), thrust::get<0>(rhs)) ? lhs : rhs;
}
};
// Inspired by cub documentation.
template <typename BlockScanFunctor, typename TupleType>
struct BlockPrefixCallbackFunctor {
BlockScanFunctor functor_;
TupleType block_aggregate_;
// Constructor
__device__ BlockPrefixCallbackFunctor(BlockScanFunctor functor, TupleType block_aggregate)
: functor_(functor), block_aggregate_(block_aggregate) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__ __forceinline__ TupleType operator()(TupleType block_aggregate) {
TupleType old_block_aggregate = block_aggregate_;
block_aggregate_ = functor_(old_block_aggregate, block_aggregate);
return old_block_aggregate;
}
};
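// Note (added for clarity, not part of the original patch): cub invokes this functor once per BlockDim-sized tile;
// it returns the aggregate of all previously scanned tiles as the seed for the current tile and folds the new tile
// aggregate into block_aggregate_, which is what lets LargeBlockScanKernel below scan an axis longer than BlockDim.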
template <typename BlockScanFunctor, typename ValueType, typename IndexType, uint BlockDim>
__global__ void LargeBlockScanKernel(BlockScanFunctor functor, const ValueType *input_ptr, ValueType *value_ptr,
IndexType *index_ptr, uint axis_size, uint inner_size, uint axis_inner_size,
uint outer_inner_size, ValueType init) {
typedef thrust::tuple<ValueType, IndexType> DataType;
typedef cub::BlockScan<DataType, BlockDim> BlockScan;
__shared__ typename BlockScan::TempStorage share_data;
for (uint bid = blockIdx.x; bid < outer_inner_size; bid += gridDim.x) {
uint outer_idx = bid / inner_size;
uint inner_idx = bid % inner_size;
DataType init_data{init, 0};
BlockPrefixCallbackFunctor<BlockScanFunctor, DataType> cb_functor{functor, init_data};
uint axis_idx = threadIdx.x;
uint axis_offset = outer_idx * axis_inner_size + inner_idx + axis_idx * inner_size;
for (uint block_offset = 0; block_offset < axis_size; block_offset += BlockDim) {
DataType thread_data = init_data;
if (axis_idx < axis_size) {
thread_data = thrust::make_tuple(input_ptr[axis_offset], axis_idx);
}
BlockScan(share_data).template InclusiveScan(thread_data, thread_data, functor, cb_functor);
__syncthreads();
if (axis_idx < axis_size) {
thrust::tie(value_ptr[axis_offset], index_ptr[axis_offset]) = thread_data;
}
axis_idx += BlockDim;
axis_offset += BlockDim * inner_size;
}
}
}
template <typename BinaryFunctor, typename DataType, typename IndexType>
__global__ void CumMinMaxKernel(BinaryFunctor fun, const DataType *input_ptr, DataType *value_ptr, IndexType *index_ptr,
uint axis_size, uint inner_size, uint axis_inner_size, uint outer_inner_size,
DataType init) {
uint tid = threadIdx.y;
uint tid_d = tid << 1; // The suffix `d` represents double.
uint scan_per_block = blockDim.y * 2;
extern __shared__ char share_data[];
auto total_value_size = sizeof(DataType) * blockDim.x * scan_per_block;
auto share_value_ptr = reinterpret_cast<DataType *>(share_data) + threadIdx.x * scan_per_block;
auto share_index_ptr = reinterpret_cast<IndexType *>(share_data + total_value_size) + threadIdx.x * scan_per_block;
for (uint bid = threadIdx.x + blockIdx.x * blockDim.x; bid < outer_inner_size; bid += blockDim.x * gridDim.x) {
uint outer_idx = bid / inner_size;
uint inner_idx = bid % inner_size;
uint outer_inner_offset = outer_idx * axis_inner_size + inner_idx;
auto cur_input_ptr = input_ptr + outer_inner_offset;
auto cur_value_ptr = value_ptr + outer_inner_offset;
auto cur_index_ptr = index_ptr + outer_inner_offset;
DataType block_value = init;
IndexType block_index = 0;
// Each iteration processes (2 * blockDim.y) elements, since shared memory is typically larger than the number of
// threads in each block.
for (uint cid = 0; cid < axis_size; cid += scan_per_block) {
// The following parallel scan algorithm refers to:
// Figure 9.7 from David B. Kirk, et al. 'Programming Massively Parallel Processors'.
uint axis_idx = cid + tid_d;
uint axis_offset = axis_idx * inner_size;
// Initializing shared memory with the input value.
if (axis_idx < axis_size) {
share_value_ptr[tid_d] = cur_input_ptr[axis_offset];
share_index_ptr[tid_d] = axis_idx;
} else {
share_value_ptr[tid_d] = init;
}
if (axis_idx + 1 < axis_size) {
share_value_ptr[tid_d + 1] = cur_input_ptr[axis_offset + inner_size];
share_index_ptr[tid_d + 1] = axis_idx + 1;
} else {
share_value_ptr[tid_d + 1] = init;
}
// update with previous block result.
if (tid == 0) {
Update(fun, share_value_ptr, share_index_ptr, block_value, block_index);
template <typename BlockScanFunctor, typename ValueType, typename IndexType, uint BlockDimX, uint BlockDimY>
__global__ void ScanInnerMostDimKernel(BlockScanFunctor functor, const ValueType *input_ptr, ValueType *value_ptr,
IndexType *index_ptr, uint outer_size, uint axis_size, ValueType init) {
typedef thrust::tuple<ValueType, IndexType> DataType;
constexpr uint scan_per_block = BlockDimX * 2;
__shared__ ValueType share_value[BlockDimY][scan_per_block];
__shared__ IndexType share_index[BlockDimY][scan_per_block];
auto share_value_ptr = share_value[threadIdx.y];
auto share_index_ptr = share_index[threadIdx.y];
for (uint bid = blockIdx.x * blockDim.y; bid < outer_size; bid += gridDim.x * blockDim.y) {
uint outer_idx = bid + threadIdx.y;
bool is_valid = outer_idx < outer_size;
uint offset = outer_idx * axis_size;
DataType block_data{init, 0};
// The following parallel scan algorithm refers to:
// Figure 9.7 from David B. Kirk, et al. 'Programming Massively Parallel Processors'.
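// Illustration (added for clarity, not part of the original patch): with a "keep the maximum" functor and the
// chunk [3, 1, 4, 1, 5, 9, 2, 6], the up-sweep combines pairs at strides 1, 2, 4 and the down-sweep fills in the
// remaining positions, yielding the inclusive scan [3, 3, 4, 4, 5, 9, 9, 9] together with the index of the element
// that supplied each value.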
for (uint i = 0; i < axis_size; i += scan_per_block) {
// Initializing shared memory with the input values; each thread processes two elements.
uint idx1 = threadIdx.x + i;
uint idx2 = idx1 + BlockDimX;
if (is_valid) {
if (idx1 < axis_size) {
share_value_ptr[threadIdx.x] = input_ptr[offset + idx1];
share_index_ptr[threadIdx.x] = idx1;
} else {
share_value_ptr[threadIdx.x] = init;
}
if (idx2 < axis_size) {
share_value_ptr[threadIdx.x + BlockDimX] = input_ptr[offset + idx2];
share_index_ptr[threadIdx.x + BlockDimX] = idx2;
} else {
share_value_ptr[threadIdx.x + BlockDimX] = init;
}
// update with previous block result.
if (threadIdx.x == 0) {
thrust::tie(share_value_ptr[0], share_index_ptr[0]) =
functor(thrust::make_tuple(share_value_ptr[0], share_index_ptr[0]), block_data);
}
}
// up-sweep
for (uint stride = 1; stride < scan_per_block; stride <<= 1) {
__syncthreads();
uint index = (tid + 1) * (stride << 1) - 1;
if (index < scan_per_block) {
Update(fun, share_value_ptr + index, share_index_ptr + index, share_value_ptr[index - stride],
share_index_ptr[index - stride]);
uint index = (threadIdx.x + 1) * (stride << 1) - 1;
if (is_valid && index < scan_per_block) {
thrust::tie(share_value_ptr[index], share_index_ptr[index]) =
functor(thrust::make_tuple(share_value_ptr[index - stride], share_index_ptr[index - stride]),
thrust::make_tuple(share_value_ptr[index], share_index_ptr[index]));
}
}
// down-sweep
for (uint stride = scan_per_block >> 2; stride > 0; stride >>= 1) {
__syncthreads();
uint index = (tid + 1) * (stride << 1) - 1;
if (index + stride < scan_per_block) {
Update(fun, share_value_ptr + (index + stride), share_index_ptr + (index + stride), share_value_ptr[index],
share_index_ptr[index]);
uint index = (threadIdx.x + 1) * (stride << 1) - 1;
if (is_valid && index + stride < scan_per_block) {
thrust::tie(share_value_ptr[index + stride], share_index_ptr[index + stride]) =
functor(thrust::make_tuple(share_value_ptr[index], share_index_ptr[index]),
thrust::make_tuple(share_value_ptr[index + stride], share_index_ptr[index + stride]));
}
}
// write to output.
__syncthreads();
if (axis_idx < axis_size) {
cur_value_ptr[axis_offset] = share_value_ptr[tid_d];
cur_index_ptr[axis_offset] = share_index_ptr[tid_d];
if (is_valid) {
if (idx1 < axis_size) {
value_ptr[offset + idx1] = share_value_ptr[threadIdx.x];
index_ptr[offset + idx1] = share_index_ptr[threadIdx.x];
}
if (idx2 < axis_size) {
value_ptr[offset + idx2] = share_value_ptr[threadIdx.x + BlockDimX];
index_ptr[offset + idx2] = share_index_ptr[threadIdx.x + BlockDimX];
}
// update block_data
block_data = thrust::make_tuple(share_value_ptr[scan_per_block - 1], share_index_ptr[scan_per_block - 1]);
}
if (axis_idx + 1 < axis_size) {
cur_value_ptr[axis_offset + inner_size] = share_value_ptr[tid_d + 1];
cur_index_ptr[axis_offset + inner_size] = share_index_ptr[tid_d + 1];
}
// update block_value & block_index
if (tid == 0) {
block_value = share_value_ptr[scan_per_block - 1];
block_index = share_index_ptr[scan_per_block - 1];
}
__syncthreads();
}
}
}
template <typename BinaryFunctor, typename DataType, typename IndexType>
struct IndexFunctor {
const DataType *input_ptr_;
BinaryFunctor functor_;
explicit IndexFunctor(const DataType *input_ptr, BinaryFunctor functor) : input_ptr_(input_ptr), functor_(functor) {}
__device__ __forceinline__ IndexType operator()(IndexType x, IndexType y) {
auto lhs = input_ptr_[x];
auto rhs = input_ptr_[y];
return functor_(lhs, rhs) ? x : y;
}
};
template <typename BinaryFunctor, typename DataType>
struct ValueFunctor {
BinaryFunctor functor_;
explicit ValueFunctor(BinaryFunctor functor) : functor_(functor) {}
__device__ __forceinline__ DataType operator()(DataType lhs, DataType rhs) { return functor_(lhs, rhs) ? lhs : rhs; }
};
template <typename BinaryOp, typename DataType>
struct BinaryFunctor {
BinaryOp binary_op_;
__device__ __forceinline__ bool operator()(DataType lhs, DataType rhs) {
return !IsNan(rhs) && (IsNan(lhs) || !binary_op_(rhs, lhs));
}
};
template <typename BinaryFunctor, typename DataType, typename IndexType>
__global__ void CumMinMaxSlowKernel(BinaryFunctor functor, const DataType *input_ptr, DataType *value_ptr,
IndexType *index_ptr, uint axis_size, uint inner_size, uint axis_inner_size,
uint outer_inner_size) {
for (uint tid = blockIdx.x * blockDim.x + threadIdx.x; tid < outer_inner_size; tid += blockDim.x * gridDim.x) {
uint outer_idx = tid / inner_size;
uint inner_idx = tid % inner_size;
template <typename BlockScanFunctor, typename ValueType, typename IndexType>
__global__ void ScanOuterDimKernel(BlockScanFunctor functor, const ValueType *input_ptr, ValueType *value_ptr,
IndexType *index_ptr, uint axis_size, uint inner_size, uint axis_inner_size,
uint outer_inner_size, ValueType init) {
typedef thrust::tuple<ValueType, IndexType> DataType;
for (uint bid = blockIdx.x * blockDim.x + threadIdx.x; bid < outer_inner_size; bid += gridDim.x * blockDim.x) {
uint outer_idx = bid / inner_size;
uint inner_idx = bid % inner_size;
DataType out{init, 0};
uint offset = outer_idx * axis_inner_size + inner_idx;
auto cur_input_ptr = input_ptr + offset;
auto cur_value_ptr = value_ptr + offset;
auto cur_index_ptr = index_ptr + offset;
DataType out_val = *cur_value_ptr = *cur_input_ptr;
IndexType out_idx = *cur_index_ptr = 0;
for (uint j = 1; j < axis_size; j++) {
cur_input_ptr += inner_size;
cur_value_ptr += inner_size;
cur_index_ptr += inner_size;
DataType cur_val = *cur_input_ptr;
if (!functor(out_val, cur_val)) {
out_val = cur_val;
out_idx = static_cast<IndexType>(j);
}
*cur_value_ptr = out_val;
*cur_index_ptr = out_idx;
for (uint i = 0; i < axis_size; i++) {
DataType thread_data = thrust::make_tuple(input_ptr[offset], i);
out = functor(out, thread_data);
thrust::tie(value_ptr[offset], index_ptr[offset]) = out;
offset += inner_size;
}
}
}
template <typename BinaryFunctor, typename DataType, typename IndexType>
void KernelHelper(BinaryFunctor fun, DataType init, const DataType *input_ptr, DataType *value_ptr,
template <typename BinaryFunctor, typename ValueType, typename IndexType>
void KernelHelper(BinaryFunctor functor, ValueType init, const ValueType *input_ptr, ValueType *value_ptr,
IndexType *index_ptr, size_t outer_size_st, size_t axis_size_st, size_t inner_size_st,
const uint32_t &device_id, cudaStream_t cuda_stream) {
if (outer_size_st == 1 && inner_size_st == 1) {
// Special case where only one dimension needs to be computed, so using the cub library is the most efficient way.
ValueFunctor<BinaryFunctor, DataType> value_fun{fun};
IndexFunctor<BinaryFunctor, DataType, IndexType> index_fun{input_ptr, fun};
size_t value_storage_bytes = 0;
size_t index_storage_bytes = 0;
auto outer_size = static_cast<uint>(outer_size_st);
auto inner_size = static_cast<uint>(inner_size_st);
auto axis_size = static_cast<uint>(axis_size_st);
auto outer_inner_size = outer_size * inner_size;
auto axis_inner_size = axis_size * inner_size;
uint max_grid_size = GetMaxGridDimY(device_id);
typedef BlockScanFunctor<BinaryFunctor, thrust::tuple<ValueType, IndexType>> BlockScanFunctor;
BlockScanFunctor scan_op{functor};
#if defined(CUB_VERSION) && (CUB_VERSION > 100800)
// Special case where only one dimension needs to be computed, so using the cub library is the most efficient way.
if (outer_size == 1 && inner_size == 1) {
// Using thrust::zip_iterator to make an iterator for (ValueType, IndexType).
cub::CountingInputIterator<IndexType> count_iter(0);
(void)cub::DeviceScan::InclusiveScan(nullptr, value_storage_bytes, input_ptr, value_ptr, value_fun, axis_size_st,
cuda_stream);
(void)cub::DeviceScan::InclusiveScan(nullptr, index_storage_bytes, count_iter, index_ptr, index_fun, axis_size_st,
cuda_stream);
// Allocate the temporary storage only once here.
char *temp_storage_ptr = nullptr;
(void)cudaMalloc(&temp_storage_ptr, value_storage_bytes + index_storage_bytes);
void *value_storage_ptr = reinterpret_cast<void *>(temp_storage_ptr);
void *index_storage_ptr = reinterpret_cast<void *>(temp_storage_ptr + value_storage_bytes);
typedef typename thrust::detail::normal_iterator<const ValueType *> InputValueIterator;
typedef cub::CountingInputIterator<IndexType> InputIndexIterator;
typedef thrust::zip_iterator<thrust::tuple<InputValueIterator, InputIndexIterator>> InputZipIterator;
InputZipIterator input_iter(thrust::make_tuple(input_ptr, count_iter));
(void)cub::DeviceScan::InclusiveScan(value_storage_ptr, value_storage_bytes, input_ptr, value_ptr, value_fun,
axis_size_st, cuda_stream);
(void)cub::DeviceScan::InclusiveScan(index_storage_ptr, index_storage_bytes, count_iter, index_ptr, index_fun,
axis_size_st, cuda_stream);
typedef typename thrust::detail::normal_iterator<ValueType *> OutputValueIterator;
typedef typename thrust::detail::normal_iterator<IndexType *> OutputIndexIterator;
typedef thrust::zip_iterator<thrust::tuple<OutputValueIterator, OutputIndexIterator>> OutputZipIterator;
OutputZipIterator output_iter(thrust::make_tuple(value_ptr, index_ptr));
// Calculate the size of temporary storage.
size_t temp_storage_bytes = 0;
(void)cub::DeviceScan::InclusiveScan(nullptr, temp_storage_bytes, input_iter, output_iter, scan_op, axis_size,
cuda_stream);
// Allocate temporary storage.
char *temp_storage_ptr = nullptr;
(void)cudaMalloc(&temp_storage_ptr, temp_storage_bytes);
// Core computation process.
(void)cub::DeviceScan::InclusiveScan(temp_storage_ptr, temp_storage_bytes, input_iter, output_iter, scan_op,
axis_size, cuda_stream);
(void)cudaFree(temp_storage_ptr);
return;
}
// When the CUDA compute capability is below the recommended level (< 7), we fall back to the self-implemented
// scan algorithm. Otherwise, we use cub::BlockScan, which is faster than the self-implemented one.
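// Note (added for clarity, not part of the original patch): GET_MAJOR_SM and RECOMMEND_SM are assumed to come from
// the framework's CUDA helpers; an equivalent stand-alone check could query
// cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id) directly.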
const int major_sm = GET_MAJOR_SM;
const bool check_sm = mindspore::device::gpu::CudaCommon::GetInstance().check_sm();
constexpr uint threshold_large_scan_dim = 500;
if (!(check_sm && major_sm < RECOMMEND_SM) && axis_size > threshold_large_scan_dim) {
constexpr uint block_dim = 512;
uint grid_x = std::min(outer_inner_size, max_grid_size);
dim3 block{block_dim};
dim3 grid{grid_x};
LargeBlockScanKernel<BlockScanFunctor, ValueType, IndexType, block_dim><<<grid, block, 0, cuda_stream>>>(
scan_op, input_ptr, value_ptr, index_ptr, axis_size, inner_size, axis_inner_size, outer_inner_size, init);
return;
}
#endif
if (inner_size == 1) {
constexpr uint block_dim_x = 32;
constexpr uint block_dim_y = 16;
// The reason the x-dimension of the block is set to 32:
// each thread processes 2 elements, so each block processes 64 elements along its x-dimension. An obvious
// advantage is that there are no bank conflicts. In addition, we don't need `__syncthreads`, since 32 equals the
// warp size.
uint grid_x = std::min(UP_DIV(outer_size, block_dim_y), max_grid_size);
dim3 block = {block_dim_x, block_dim_y};
dim3 grid = {grid_x};
ScanInnerMostDimKernel<BlockScanFunctor, ValueType, IndexType, block_dim_x, block_dim_y>
<<<grid, block, 0, cuda_stream>>>(scan_op, input_ptr, value_ptr, index_ptr, outer_size, axis_size, init);
} else {
auto outer_size = static_cast<uint>(outer_size_st);
auto inner_size = static_cast<uint>(inner_size_st);
auto axis_size = static_cast<uint>(axis_size_st);
auto outer_inner_size = outer_size * inner_size;
auto axis_inner_size = axis_size * inner_size;
if (inner_size_st == 1) {
// The partitioning strategy is as follows:
// 1. The block has two dimensions: the y-dimension (max size 128) scans an array of length axis_size, while the
// x-dimension processes the batch dimension in parallel; its size depends on the maximum shared memory size and
// the maximum number of threads per block.
// 2. The grid has only one dimension, which takes over the remaining batch dimension.
constexpr uint max_block_y = 128;
uint max_share_size = GetMaxSharedMemoryPerBlock(device_id);
uint max_thread_size = GetMaxThreadsPerBlock(device_id);
uint max_grid_size = GetMaxGridDimX(device_id);
uint axis_power2 = 1u << Log2Ceil(axis_size);
uint block_y = std::min(max_block_y, axis_power2);
uint has_allocate = block_y * 2 * (sizeof(DataType) + sizeof(IndexType));
uint block_x = std::min(max_thread_size / block_y, max_share_size / has_allocate);
uint grid_x = std::min(max_grid_size, UP_DIV(outer_inner_size, block_x));
dim3 block = {block_x, block_y};
dim3 grid = {grid_x};
uint share_size = block_x * has_allocate;
CumMinMaxKernel<BinaryFunctor, DataType, IndexType><<<grid, block, share_size, cuda_stream>>>(
fun, input_ptr, value_ptr, index_ptr, axis_size, inner_size, axis_inner_size, outer_inner_size, init);
} else {
// A rarely hit fallback case; this branch can be deleted if it is not needed.
CumMinMaxSlowKernel<BinaryFunctor, DataType, IndexType>
<<<CUDA_BLOCKS(device_id, outer_inner_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
fun, input_ptr, value_ptr, index_ptr, axis_size, inner_size, axis_inner_size, outer_inner_size);
}
constexpr uint block_dim = 512;
uint grid_x = std::min(UP_DIV(outer_inner_size, block_dim), max_grid_size);
dim3 block{block_dim};
dim3 grid{grid_x};
ScanOuterDimKernel<<<grid, block, 0, cuda_stream>>>(scan_op, input_ptr, value_ptr, index_ptr, axis_size, inner_size,
axis_inner_size, outer_inner_size, init);
}
}

View File

@@ -17,7 +17,6 @@
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUM_OP_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUM_OP_IMPL_CUH_
#include <thrust/functional.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
enum CumOpType { CUMMIN = 0, CUMMAX, CUM_OP_INVALID_TYPE = 255 };

View File

@@ -15,6 +15,7 @@
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/index_fill_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
template <typename DataType, typename Int>
__global__ void IndexFillKernel(const int *__restrict__ index_ptr, const DataType *__restrict__ value_ptr,
@@ -109,3 +110,13 @@ template CUDA_LIB_EXPORT void IndexFill<double>(double *out_ptr, const int *inde
int64_t outer_size, int64_t dim_size, int64_t inner_size,
const double *value_ptr, bool *out_bound_ptr, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<Complex<float>>(Complex<float> *out_ptr, const int *index_ptr,
int64_t index_size, int64_t outer_size, int64_t dim_size,
int64_t inner_size, const Complex<float> *value_ptr,
bool *out_bound_ptr, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void IndexFill<Complex<double>>(Complex<double> *out_ptr, const int *index_ptr,
int64_t index_size, int64_t outer_size, int64_t dim_size,
int64_t inner_size, const Complex<double> *value_ptr,
bool *out_bound_ptr, const uint32_t &device_id,
cudaStream_t cuda_stream);
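For reference, a minimal usage sketch (not part of this diff) of the IndexFill primitive exercising the newly instantiated complex64 path; the Python-side call signature and dtypes shown here are assumptions based on the four-input kernel registration above.

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

index_fill = ops.IndexFill()
x = Tensor(np.array([[1, 2], [3, 4]], np.complex64))
dim = Tensor(1, ms.int32)                # fill along the second axis
index = Tensor(np.array([0], np.int32))  # positions to fill
value = Tensor(np.complex64(7 + 7j))     # scalar fill value, same dtype as x
y = index_fill(x, dim, index, value)     # -> [[7+7j, 2], [7+7j, 4]]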

View File

@@ -29,14 +29,13 @@
namespace mindspore {
namespace ops {
namespace {
constexpr int64_t kIndexFillInputsNum = 4;
constexpr int64_t kIndexFillOutputsNum = 1;
TypePtr IndexFillInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kIndexFillInputsNum, prim_name);
constexpr int64_t input_num = 4;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
const std::set<TypePtr> valid_data_types = common_valid_types_with_bool;
const std::set<TypePtr> valid_data_types = common_valid_types_with_complex_and_bool;
const std::set<TypePtr> valid_dim_types = {kInt32, kInt64};
// Input 'dim' can be scalar or tensor.
@@ -47,7 +46,7 @@ TypePtr IndexFillInferType(const PrimitivePtr &primitive, const std::vector<Abst
auto index_type = input_args[kInputIndex2]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("index", index_type, {kInt32}, prim_name);
// Input 'x' must must be a tensor.
// Input 'x' must be a tensor.
auto x_type = input_args[kInputIndex0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_data_types, prim_name);

View File

@@ -428,6 +428,7 @@ from .cos import _cos_tbe
from .tan import _tan_tbe
from .tan_ds import _tan_ds_tbe
from .cum_sum import _cum_sum_tbe
from .cum_sum_ds import _cum_sum_ds_tbe
from .apply_rms_prop import _apply_rms_prop_tbe
from .cumprod import _cumprop_tbe
from .reduce_prod import _reduce_prod_tbe

View File

@@ -0,0 +1,45 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CumSum op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
cum_sum_ds_op_info = TBERegOp("CumSum") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("cumsum.so") \
.compute_cost(10) \
.kernel_name("cumsum") \
.partial_flag(True) \
.dynamic_shape(True) \
.attr("axis", "optional", "int", "all", "0") \
.attr("exclusive", "optional", "bool", "true,false", "false") \
.attr("reverse", "optional", "bool", "true,false", "false") \
.input(0, "x", False, "required", "all") \
.input(1, "axis", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.get_op_info()
@op_info_register(cum_sum_ds_op_info)
def _cum_sum_ds_tbe():
"""CumSum TBE register"""
return
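For reference, a minimal usage sketch (not part of this patch) showing how the newly registered dynamic-shape CumSum kernel would be driven on the Ascend backend; the context setup and input values below are illustrative assumptions, and the dynamic-shape path is selected by the framework when input shapes are not fully known at compile time.

import numpy as np
from mindspore import Tensor, context, ops

context.set_context(device_target="Ascend")
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]], np.float32))
cumsum = ops.CumSum(exclusive=False, reverse=False)
y = cumsum(x, 0)  # cumulative sum along axis 0 -> [[1, 2, 3], [5, 7, 9]]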