[MSLITE] Fix op bugs and add TensorScatterAdd and Conv2DBackpropInput ops for GPU

张勇贤 2023-01-18 09:54:23 +08:00
parent f30e968c53
commit ebb98ad11a
10 changed files with 741 additions and 137 deletions

View File

@@ -18,9 +18,10 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
template <typename T, typename S>
__global__ void TensorScatterUpdateKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
const size_t input_size, const size_t output_size, const size_t indices_dim_0,
const size_t indices_dim_1, S *indices_stride, S *work_shape) {
__global__ void TensorScatterUpdateKernel(const T *input, const S *indices, const T *update, T *output,
const size_t block_size, const size_t input_size, const size_t output_size,
const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride,
S *work_shape) {
int i, j;
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
@@ -45,9 +46,10 @@ __global__ void TensorScatterUpdateKernel(T *input, S *indices, T *update, T *ou
}
template <typename T, typename S>
__global__ void TensorScatterMinKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
const size_t input_size, const size_t output_size, const size_t indices_dim_0,
const size_t indices_dim_1, S *indices_stride, S *work_shape) {
__global__ void TensorScatterMinKernel(const T *input, const S *indices, const T *update, T *output,
const size_t block_size, const size_t input_size, const size_t output_size,
const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride,
S *work_shape) {
int i, j;
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
@@ -72,9 +74,10 @@ __global__ void TensorScatterMinKernel(T *input, S *indices, T *update, T *outpu
}
template <typename T, typename S>
__global__ void TensorScatterMaxKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
const size_t input_size, const size_t output_size, const size_t indices_dim_0,
const size_t indices_dim_1, S *indices_stride, S *work_shape) {
__global__ void TensorScatterMaxKernel(const T *input, const S *indices, const T *update, T *output,
const size_t block_size, const size_t input_size, const size_t output_size,
const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride,
S *work_shape) {
int i, j;
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
@@ -99,9 +102,10 @@ __global__ void TensorScatterMaxKernel(T *input, S *indices, T *update, T *outpu
}
template <typename T, typename S>
__global__ void TensorScatterAddKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
const size_t input_size, const size_t output_size, const size_t indices_dim_0,
const size_t indices_dim_1, S *indices_stride, S *work_shape) {
__global__ void TensorScatterAddKernel(const T *input, const S *indices, const T *update, T *output,
const size_t block_size, const size_t input_size, const size_t output_size,
const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride,
S *work_shape) {
int i, j;
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
@@ -126,9 +130,10 @@ __global__ void TensorScatterAddKernel(T *input, S *indices, T *update, T *outpu
}
template <typename T, typename S>
__global__ void TensorScatterSubKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
const size_t input_size, const size_t output_size, const size_t indices_dim_0,
const size_t indices_dim_1, S *indices_stride, S *work_shape) {
__global__ void TensorScatterSubKernel(const T *input, const S *indices, const T *update, T *output,
const size_t block_size, const size_t input_size, const size_t output_size,
const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride,
S *work_shape) {
int i, j;
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
@@ -153,9 +158,10 @@ __global__ void TensorScatterSubKernel(T *input, S *indices, T *update, T *outpu
}
template <typename T, typename S>
__global__ void TensorScatterMulKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
const size_t input_size, const size_t output_size, const size_t indices_dim_0,
const size_t indices_dim_1, S *indices_stride, S *work_shape) {
__global__ void TensorScatterMulKernel(const T *input, const S *indices, const T *update, T *output,
const size_t block_size, const size_t input_size, const size_t output_size,
const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride,
S *work_shape) {
int i, j;
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
@@ -180,9 +186,10 @@ __global__ void TensorScatterMulKernel(T *input, S *indices, T *update, T *outpu
}
template <typename T, typename S>
__global__ void TensorScatterDivKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
const size_t input_size, const size_t output_size, const size_t indices_dim_0,
const size_t indices_dim_1, S *indices_stride, S *work_shape) {
__global__ void TensorScatterDivKernel(const T *input, const S *indices, const T *update, T *output,
const size_t block_size, const size_t input_size, const size_t output_size,
const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride,
S *work_shape) {
int i, j;
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
@@ -207,10 +214,11 @@ __global__ void TensorScatterDivKernel(T *input, S *indices, T *update, T *outpu
}
template <typename T, typename S>
void TensorScatterArithmetic(const enum TensorScatterArithmeticFunctionType &func_type, T *input, S *indices, T *update,
T *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, S *work_shape,
uint32_t device_id, cudaStream_t stream) {
void TensorScatterArithmetic(const enum TensorScatterArithmeticFunctionType &func_type, const T *input,
const S *indices, const T *update, T *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, S *indices_stride, S *work_shape, uint32_t device_id,
cudaStream_t stream) {
switch (func_type) {
case TENSOR_SCATTER_FUNC_UPDATE:
return TensorScatterUpdateKernel<<<CUDA_BLOCKS(device_id, output_size), CUDA_THREADS(device_id), 0, stream>>>(
@@ -246,7 +254,7 @@ void TensorScatterArithmetic(const enum TensorScatterArithmeticFunctionType &fun
}
template <typename T, typename S>
void CallTensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size,
void CallTensorScatterUpdate(const T *input, const S *indices, const T *update, T *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, S *indices_stride, S *work_shape, uint32_t device_id,
cudaStream_t stream) {
@@ -256,156 +264,165 @@ void CallTensorScatterUpdate(T *input, S *indices, T *update, T *output, const s
}
template CUDA_LIB_EXPORT void TensorScatterArithmetic<half, int>(
const enum TensorScatterArithmeticFunctionType &func_type, half *input, int *indices, half *update, half *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream);
const enum TensorScatterArithmeticFunctionType &func_type, const half *input, const int *indices, const half *update,
half *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id,
cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<float, int>(
const enum TensorScatterArithmeticFunctionType &func_type, float *input, int *indices, float *update, float *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream);
const enum TensorScatterArithmeticFunctionType &func_type, const float *input, const int *indices,
const float *update, float *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id,
cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<double, int>(
const enum TensorScatterArithmeticFunctionType &func_type, double *input, int *indices, double *update,
double *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const double *input, const int *indices,
const double *update, double *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id,
cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<char, int>(
const enum TensorScatterArithmeticFunctionType &func_type, char *input, int *indices, char *update, char *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<unsigned char, int>(
const enum TensorScatterArithmeticFunctionType &func_type, unsigned char *input, int *indices, unsigned char *update,
unsigned char *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const char *input, const int *indices, const char *update,
char *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id,
cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<unsigned char, int>(
const enum TensorScatterArithmeticFunctionType &func_type, const unsigned char *input, const int *indices,
const unsigned char *update, unsigned char *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride,
int *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<int16_t, int>(
const enum TensorScatterArithmeticFunctionType &func_type, int16_t *input, int *indices, int16_t *update,
int16_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const int16_t *input, const int *indices,
const int16_t *update, int16_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id,
cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<uint16_t, int>(
const enum TensorScatterArithmeticFunctionType &func_type, uint16_t *input, int *indices, uint16_t *update,
uint16_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id,
cudaStream_t stream);
const enum TensorScatterArithmeticFunctionType &func_type, const uint16_t *input, const int *indices,
const uint16_t *update, uint16_t *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride,
int *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<int, int>(
const enum TensorScatterArithmeticFunctionType &func_type, int *input, int *indices, int *update, int *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<uint32_t, int>(
const enum TensorScatterArithmeticFunctionType &func_type, uint32_t *input, int *indices, uint32_t *update,
uint32_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const int *input, const int *indices, const int *update,
int *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id,
cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<uint32_t, int>(
const enum TensorScatterArithmeticFunctionType &func_type, const uint32_t *input, const int *indices,
const uint32_t *update, uint32_t *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride,
int *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<int64_t, int>(
const enum TensorScatterArithmeticFunctionType &func_type, int64_t *input, int *indices, int64_t *update,
int64_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const int64_t *input, const int *indices,
const int64_t *update, int64_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id,
cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<uint64_t, int>(
const enum TensorScatterArithmeticFunctionType &func_type, uint64_t *input, int *indices, uint64_t *update,
uint64_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const uint64_t *input, const int *indices,
const uint64_t *update, uint64_t *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride,
int *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<bool, int>(
const enum TensorScatterArithmeticFunctionType &func_type, const bool *input, const int *indices, const bool *update,
bool *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id,
cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<bool, int>(
const enum TensorScatterArithmeticFunctionType &func_type, bool *input, int *indices, bool *update, bool *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<half, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, half *input, int64_t *indices, half *update, half *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
const enum TensorScatterArithmeticFunctionType &func_type, const half *input, const int64_t *indices,
const half *update, half *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<float, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, float *input, int64_t *indices, float *update,
float *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const float *input, const int64_t *indices,
const float *update, float *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<double, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, double *input, int64_t *indices, double *update,
double *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const double *input, const int64_t *indices,
const double *update, double *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<char, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, char *input, int64_t *indices, char *update, char *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
const enum TensorScatterArithmeticFunctionType &func_type, const char *input, const int64_t *indices,
const char *update, char *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<unsigned char, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, unsigned char *input, int64_t *indices,
unsigned char *update, unsigned char *output, const size_t &block_size, const size_t &input_size,
const enum TensorScatterArithmeticFunctionType &func_type, const unsigned char *input, const int64_t *indices,
const unsigned char *update, unsigned char *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride,
int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<int16_t, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, int16_t *input, int64_t *indices, int16_t *update,
int16_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const int16_t *input, const int64_t *indices,
const int16_t *update, int16_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<uint16_t, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, uint16_t *input, int64_t *indices, uint16_t *update,
uint16_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
uint32_t device_id, cudaStream_t stream);
const enum TensorScatterArithmeticFunctionType &func_type, const uint16_t *input, const int64_t *indices,
const uint16_t *update, uint16_t *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride,
int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<int, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, int *input, int64_t *indices, int *update, int *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<uint32_t, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, uint32_t *input, int64_t *indices, uint32_t *update,
uint32_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const int *input, const int64_t *indices,
const int *update, int *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<uint32_t, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, const uint32_t *input, const int64_t *indices,
const uint32_t *update, uint32_t *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride,
int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<int64_t, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, int64_t *input, int64_t *indices, int64_t *update,
int64_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const int64_t *input, const int64_t *indices,
const int64_t *update, int64_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<uint64_t, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, uint64_t *input, int64_t *indices, uint64_t *update,
uint64_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const enum TensorScatterArithmeticFunctionType &func_type, const uint64_t *input, const int64_t *indices,
const uint64_t *update, uint64_t *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride,
int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<bool, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, const bool *input, const int64_t *indices,
const bool *update, bool *output, const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void TensorScatterArithmetic<bool, int64_t>(
const enum TensorScatterArithmeticFunctionType &func_type, bool *input, int64_t *indices, bool *update, bool *output,
template CUDA_LIB_EXPORT void CallTensorScatterUpdate<Complex<float>, int64_t>(
const Complex<float> *input, const int64_t *indices, const Complex<float> *update, Complex<float> *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void CallTensorScatterUpdate<Complex<float>, int64_t>(
Complex<float> *input, int64_t *indices, Complex<float> *update, Complex<float> *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void CallTensorScatterUpdate<Complex<float>, int>(
Complex<float> *input, int *indices, Complex<float> *update, Complex<float> *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream);
const Complex<float> *input, const int *indices, const Complex<float> *update, Complex<float> *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void CallTensorScatterUpdate<Complex<double>, int64_t>(
Complex<double> *input, int64_t *indices, Complex<double> *update, Complex<double> *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
const Complex<double> *input, const int64_t *indices, const Complex<double> *update, Complex<double> *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void CallTensorScatterUpdate<Complex<double>, int>(
Complex<double> *input, int *indices, Complex<double> *update, Complex<double> *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream);
const Complex<double> *input, const int *indices, const Complex<double> *update, Complex<double> *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream);
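The change above only const-qualifies the read-only kernel and launcher arguments; the scatter addressing itself is untouched. For orientation, a minimal host-side sketch of that addressing, with names and layout assumed from the kernel signatures rather than taken from this commit: each read_index into the updates tensor selects an indices row i and an offset j inside a block of block_size elements, and indices_stride turns that row into a flat output offset.

#include <cstddef>
#include <vector>

// Hypothetical host-side sketch of the ScatterND addressing shared by the
// TensorScatter*Kernel functions above: update element read_index is copied
// (or accumulated) into output element write_index.
size_t ScatterWriteIndex(size_t read_index, size_t block_size, size_t indices_dim_1,
                         const std::vector<int> &indices,           // [indices_dim_0, indices_dim_1], row-major
                         const std::vector<int> &indices_stride) {  // flat output stride per indexed dim
  size_t i = read_index / block_size;  // which indices row this element belongs to
  size_t j = read_index % block_size;  // offset inside the block addressed by that row
  size_t write_index = j;
  for (size_t k = 0; k < indices_dim_1; ++k) {
    write_index += static_cast<size_t>(indices[i * indices_dim_1 + k]) *
                   static_cast<size_t>(indices_stride[k]);
  }
  return write_index;
}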

View File

@@ -30,14 +30,15 @@ enum TensorScatterArithmeticFunctionType {
};
template <typename T, typename S>
CUDA_LIB_EXPORT void TensorScatterArithmetic(const enum TensorScatterArithmeticFunctionType &func_type, T *input,
S *indices, T *update, T *output, const size_t &block_size,
CUDA_LIB_EXPORT void TensorScatterArithmetic(const enum TensorScatterArithmeticFunctionType &func_type, const T *input,
const S *indices, const T *update, T *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1,
S *indices_stride, S *work_shape, uint32_t device_id, cudaStream_t stream);
template <typename T, typename S>
CUDA_LIB_EXPORT void CallTensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1,
S *indices_stride, S *work_shape, uint32_t device_id, cudaStream_t stream);
CUDA_LIB_EXPORT void CallTensorScatterUpdate(const T *input, const S *indices, const T *update, T *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, S *indices_stride, S *work_shape,
uint32_t device_id, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TENSOR_SCATTER_ARITHMETIC_CUH_

View File

@@ -8,6 +8,7 @@ file(GLOB_RECURSE CUDA_KERNEL_SRC
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumsum_impl.cu
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/batchtospace_impl.cu
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/spacetobatch_impl.cu
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/depthtospace_impl.cu
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/select_impl.cu
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxpool_with_argmax_impl.cu
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/roi_align_impl.cu
@@ -15,6 +16,7 @@ file(GLOB_RECURSE CUDA_KERNEL_SRC
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/boundingbox_decode_impl.cu
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/where_impl.cu
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/one_hot_impl.cu
${CCSRC_DIR}/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cu
)
set_source_files_properties(${CUDA_KERNEL_SRC} PROPERTIES CUDA_SOURCE_PROPERTY_FORMAT OBJ)

View File

@@ -0,0 +1,146 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.h"
#include <memory>
#include "nnacl/pack.h"
namespace mindspore::lite {
int Conv2dBackpropInputTensorRT::IsSupport(const BaseOperatorPtr &base_operator,
const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) {
if (in_tensors.size() != INPUT_SIZE3) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
int Conv2dBackpropInputTensorRT::AddInnerOp(TensorRTContext *ctx) {
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "context or network is invalid";
return RET_ERROR;
}
auto deconv_op = AsOps<ops::Conv2DBackpropInputFusion>();
if (deconv_op == nullptr) {
MS_LOG(ERROR) << "op action convert failed";
return RET_ERROR;
}
nvinfer1::ITensor *deconv_input = input(ctx, 0).trt_tensor_;
// transpose weight
const auto &weight_tensor = in_tensors_[1];
nvinfer1::Weights kernelWeights = lite::ConvertWeight(weight_tensor);
// deconv basic params
int nbOutputMaps = weight_tensor.Shape()[1];
if (nbOutputMaps <= 0) {
MS_LOG(ERROR) << "out_channel is invalid";
return RET_ERROR;
}
auto kernel_size = deconv_op->get_kernel_size();
if (kernel_size.empty()) {
MS_LOG(ERROR) << "kernel_size is null";
return RET_ERROR;
}
nvinfer1::Dims kernelSize = lite::ConvertCudaDims(std::vector<int64_t>(kernel_size.begin(), kernel_size.end()));
if (kernelSize.nbDims == -1) {
MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
return RET_ERROR;
}
// bias
nvinfer1::Weights biasWeights{};
biasWeights.type = ConvertDataType(weight_tensor.DataType());
biasWeights.count = 0;
biasWeights.values = nullptr;
nvinfer1::IDeconvolutionLayer *deconv_layer =
ctx->network()->addDeconvolutionNd(*deconv_input, nbOutputMaps, kernelSize, kernelWeights, biasWeights);
if (deconv_layer == nullptr) {
MS_LOG(ERROR) << "DeconvolutionLayer failed";
return RET_ERROR;
}
deconv_layer->setName((op_name_ + "_deconv").c_str());
this->layer_ = deconv_layer;
// set extra params
SetAttributes(deconv_op, deconv_layer);
nvinfer1::ITensor *out_tensor = deconv_layer->getOutput(0);
ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name());
return RET_OK;
}
void Conv2dBackpropInputTensorRT::SetAttributes(const std::shared_ptr<ops::Conv2DBackpropInputFusion> &ms_op,
nvinfer1::IDeconvolutionLayer *decon_layer) {
// kernel_size
auto kernel_size = ms_op->get_kernel_size();
if (!kernel_size.empty()) {
auto kernel_size_val = std::vector<int64_t>(kernel_size.begin(), kernel_size.end());
nvinfer1::Dims kernel_size_dims = lite::ConvertCudaDims(kernel_size_val);
if (kernel_size_dims.nbDims == -1) {
MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
return;
}
decon_layer->setKernelSizeNd(kernel_size_dims);
}
// nbOutputMaps
int nbOutputMaps = in_tensors_[1].Shape()[1];
decon_layer->setNbOutputMaps(nbOutputMaps);
// stride
auto stride = ms_op->get_stride();
if (!stride.empty()) {
auto stride_val = std::vector<int64_t>(stride.begin() + INPUT_SIZE2, stride.end());
nvinfer1::Dims stride_dims = lite::ConvertCudaDims(stride_val);
if (stride_dims.nbDims == -1) {
MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
return;
}
decon_layer->setStrideNd(stride_dims);
}
// nbGroups
int32_t nbGroups = static_cast<int32_t>(ms_op->get_group());
decon_layer->setNbGroups(nbGroups);
// padding
PadMode pad_mode = ms_op->get_pad_mode();
if (pad_mode == PadMode::SAME) {
decon_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
} else {
auto padding = ms_op->get_pad_list();
auto padding_val = std::vector<int64_t>(padding.begin(), padding.end());
nvinfer1::Dims dims_pre{};
dims_pre.nbDims = DIMENSION_2D;
dims_pre.d[0] = padding_val[0]; // up
dims_pre.d[1] = padding_val[INPUT_SIZE2]; // left
decon_layer->setPrePadding(dims_pre);
nvinfer1::Dims dims_post{};
dims_post.nbDims = DIMENSION_2D;
dims_post.d[0] = padding_val[1];
dims_post.d[1] = padding_val[INPUT_SIZE3];
decon_layer->setPostPadding(dims_post);
}
}
REGISTER_TENSORRT_CREATOR(ops::kNameConv2DBackpropInputFusion, Conv2dBackpropInputTensorRT)
} // namespace mindspore::lite
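Conv2DBackpropInput is lowered to an IDeconvolutionLayer, with get_pad_list() split into pre padding (top, left) and post padding (bottom, right) in SetAttributes. A hedged sketch of the size relation this relies on, based on the standard transposed-convolution arithmetic rather than on code from this commit:

#include <cassert>

// Assumed transposed-convolution size relation behind the pre/post padding
// set above (TensorRT applies this internally; shown only as a sanity check).
int DeconvOutputDim(int in, int kernel, int stride, int pad_pre, int pad_post) {
  return (in - 1) * stride + kernel - pad_pre - pad_post;
}

int main() {
  // Example: 4x4 feature map, 4x4 kernel, stride 2, pad_list {1, 1, 1, 1}
  // recovers an 8x8 input gradient: (4 - 1) * 2 + 4 - 1 - 1 = 8.
  assert(DeconvOutputDim(4, 4, 2, 1, 1) == 8);
  return 0;
}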

View File

@@ -0,0 +1,43 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV2DBACKPROPINPUT_TENSORRT_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV2DBACKPROPINPUT_TENSORRT_H_
#include <string>
#include <vector>
#include <memory>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
#include "ops/fusion/conv2d_backprop_input_fusion.h"
namespace mindspore::lite {
class Conv2dBackpropInputTensorRT : public TensorRTOp {
public:
Conv2dBackpropInputTensorRT(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors, const std::string &name)
: TensorRTOp(base_operator, in_tensors, out_tensors, name) {}
~Conv2dBackpropInputTensorRT() = default;
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) override;
private:
void SetAttributes(const std::shared_ptr<ops::Conv2DBackpropInputFusion> &conv_op,
nvinfer1::IDeconvolutionLayer *decon_layer);
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV2DBACKPROPINPUT_TENSORRT_H_

View File

@@ -0,0 +1,130 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h"
#include <cuda_runtime.h>
#include <numeric>
#include <memory>
#include <vector>
#include <functional>
#include <unordered_map>
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "NvInferRuntimeCommon.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/depthtospace_impl.cuh"
#include "ops/depth_to_space.h"
namespace mindspore::lite {
int DepthToSpaceTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) {
if (in_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() < 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
int DepthToSpaceTensorRT::AddInnerOp(TensorRTContext *ctx) {
nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;
auto op = AsOps<ops::DepthToSpace>();
int block_size = op->get_block_size();
auto plugin = std::make_shared<DepthToSpacePlugin>(input_tensor->getName(), block_size, device_id_);
if (plugin == nullptr) {
MS_LOG(ERROR) << "add depthtospace plugin failed for" << op_name_;
return RET_ERROR;
}
nvinfer1::ITensor *inputTensors[] = {input_tensor};
nvinfer1::IPluginV2Layer *layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin);
if (layer == nullptr) {
MS_LOG(ERROR) << "add depthtospace op failed for TensorRT.";
return RET_ERROR;
}
layer->setName(op_name_.c_str());
nvinfer1::ITensor *out_tensor = layer->getOutput(0);
ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name());
this->layer_ = layer;
return RET_OK;
}
REGISTER_TENSORRT_PLUGIN(DepthToSpacePluginCreater);
template class TensorRTPluginCreater<DepthToSpacePlugin>;
template <class T>
nvinfer1::PluginFieldCollection TensorRTPluginCreater<T>::field_collection_{};
template <class T>
std::vector<nvinfer1::PluginField> TensorRTPluginCreater<T>::fields_;
int DepthToSpacePlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) noexcept {
return RunCudaDepthToSpace(inputDesc, inputs, outputs, stream);
}
int DepthToSpacePlugin::RunCudaDepthToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs,
void *const *outputs, cudaStream_t stream) {
nvinfer1::Dims input_dims = inputDesc[0].dims;
int in = input_dims.d[0];
int ic = input_dims.d[1];
int ih = input_dims.d[2];
int iw = input_dims.d[3];
int on = in;
int oc = ic / block_size_ / block_size_;
int oh = ih * block_size_;
int ow = iw * block_size_;
int size = on * oc * oh * ow;
CalDepthToSpace<float>(size, static_cast<const float *>(inputs[0]), in, ic, ih, iw, on, oc, oh, ow, block_size_,
static_cast<float *>(outputs[0]), device_id_, stream);
return RET_OK;
}
nvinfer1::IPluginV2DynamicExt *DepthToSpacePlugin::clone() const noexcept {
auto *plugin = new (std::nothrow) DepthToSpacePlugin(*this);
if (plugin == nullptr) {
MS_LOG(ERROR) << "new plugin failed!";
return nullptr;
}
plugin->setPluginNamespace(name_space_.c_str());
return plugin;
}
size_t DepthToSpacePlugin::getSerializationSize() const noexcept { return sizeof(int); }
nvinfer1::DimsExprs DepthToSpacePlugin::getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs,
int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept {
nvinfer1::DimsExprs dims;
dims.nbDims = inputs[0].nbDims;
dims.d[0] = inputs[0].d[0];
dims.d[1] = inputs[0].d[1];
auto block_size_sqrt = exprBuilder.constant(block_size_ * block_size_);
dims.d[1] = exprBuilder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, *inputs[0].d[1], *block_size_sqrt);
auto block_size = exprBuilder.constant(block_size_);
dims.d[INPUT_SIZE2] =
exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[INPUT_SIZE2], *block_size);
dims.d[INPUT_SIZE3] =
exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[INPUT_SIZE3], *block_size);
return dims;
}
void DepthToSpacePlugin::serialize(void *buffer) const noexcept { SerializeValue(&buffer, &block_size_, sizeof(int)); }
REGISTER_TENSORRT_CREATOR(ops::kNameDepthToSpace, DepthToSpaceTensorRT)
} // namespace mindspore::lite
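The plugin output shape mirrors getOutputDimensions: (N, C / b^2, H * b, W * b) for block size b, and enqueue simply forwards to CalDepthToSpace with those dimensions. A small CPU reference of the element mapping, offered as an illustration only; it assumes the DCR channel ordering, while the actual behavior is whatever CalDepthToSpace implements.

#include <vector>

// CPU reference for DepthToSpace in NCHW layout with block size b
// (illustrative; assumes DCR ordering of the depth dimension).
std::vector<float> DepthToSpaceRef(const std::vector<float> &in, int n, int c, int h, int w, int b) {
  int oc = c / (b * b), oh = h * b, ow = w * b;
  std::vector<float> out(static_cast<size_t>(n) * oc * oh * ow);
  for (int ni = 0; ni < n; ++ni)
    for (int ci = 0; ci < oc; ++ci)
      for (int hi = 0; hi < oh; ++hi)
        for (int wi = 0; wi < ow; ++wi) {
          int ic = (hi % b) * b * oc + (wi % b) * oc + ci;  // DCR: block offsets first in depth
          int idx_in = ((ni * c + ic) * h + hi / b) * w + wi / b;
          int idx_out = ((ni * oc + ci) * oh + hi) * ow + wi;
          out[idx_out] = in[idx_in];
        }
  return out;
}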

View File

@@ -0,0 +1,79 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_
#include <string>
#include <vector>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h"
namespace mindspore::lite {
class DepthToSpaceTensorRT : public TensorRTOp {
public:
DepthToSpaceTensorRT(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors, std::string name)
: TensorRTOp(base_operator, in_tensors, out_tensors, name) {}
~DepthToSpaceTensorRT() override = default;
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) override;
};
constexpr auto DEPTHTOSPACETENSORRT_PLUGIN_NAME{"DepthToSpacePlugin"};
class DepthToSpacePlugin : public TensorRTPlugin {
public:
DepthToSpacePlugin(const std::string name, int block_size, uint32_t device_id)
: TensorRTPlugin(name, std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME), device_id), block_size_(block_size) {}
DepthToSpacePlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
: TensorRTPlugin(std::string(name), std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) {
const nvinfer1::PluginField *fields = fc->fields;
block_size_ = static_cast<const int *>(fields[0].data)[0];
}
DepthToSpacePlugin(const char *name, const void *serialData, size_t serialLength)
: TensorRTPlugin(std::string(name), std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) {
DeserializeValue(&serialData, &serialLength, &block_size_, sizeof(int));
}
DepthToSpacePlugin() = delete;
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void *buffer) const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
private:
int RunCudaDepthToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs,
cudaStream_t stream);
int block_size_;
const std::string layer_name_;
std::string name_space_;
};
class DepthToSpacePluginCreater : public TensorRTPluginCreater<DepthToSpacePlugin> {
public:
DepthToSpacePluginCreater() : TensorRTPluginCreater(std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) {}
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_

View File

@@ -49,23 +49,6 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
scatter_input.format_ = Format::NCHW;
ctx->RegisterTensor(scatter_input, in_tensors_[0].Name());
}
if (type_ == ops::kNameTensorScatterAdd) {
nvinfer1::ITensor *value_tensor = ctx->ConvertTo1DTensor(0.f);
if (in_tensors_[0].DataType() == DataType::kNumberTypeInt32) {
value_tensor = ctx->ConvertTo1DTensor(0);
}
auto unsqueeze_layer = ctx->network()->addShuffle(*value_tensor);
CHECK_NULL_RETURN(unsqueeze_layer);
auto shape = ctx->network()->addShape(*input(ctx, 0).trt_tensor_)->getOutput(0);
int rank = shape->getDimensions().d[0];
nvinfer1::Dims unsqueeze{rank};
std::fill(unsqueeze.d, unsqueeze.d + rank, 1);
unsqueeze_layer->setReshapeDimensions(unsqueeze);
unsqueeze_layer->setZeroIsPlaceholder(false);
value_tensor = unsqueeze_layer->getOutput(0);
CHECK_NULL_RETURN(value_tensor);
scatter_input.trt_tensor_ = Broadcast(ctx, value_tensor, shape);
}
ITensorHelper indices_helper;
int ret = PreprocessInputs2SameDim(ctx, input(ctx, 1), &indices_helper);
if (ret != RET_OK || indices_helper.trt_tensor_ == nullptr) {
@@ -87,11 +70,6 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
}
nvinfer1::ITensor *out_tensor = scatter_layer->getOutput(0);
if (type_ == ops::kNameTensorScatterAdd) {
out_tensor = ctx->network()
->addElementWise(*out_tensor, *input(ctx, 0).trt_tensor_, nvinfer1::ElementWiseOperation::kSUM)
->getOutput(0);
}
ctx->RegisterTensor(ITensorHelper{out_tensor, scatter_input.format_, scatter_input.same_format_},
out_tensors_[0].Name());
this->layer_ = scatter_layer;
@@ -103,5 +81,4 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
}
REGISTER_TENSORRT_CREATOR(ops::kNameScatterNdUpdate, ScatterNdTensorRT)
REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterUpdate, ScatterNdTensorRT)
REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterAdd, ScatterNdTensorRT)
} // namespace mindspore::lite
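The removed branch emulated TensorScatterAdd by scattering the updates onto a broadcast zero tensor through the ScatterND layer and then adding the original input with kSUM; registration for kNameTensorScatterAdd now points at the dedicated plugin added below. One reason such an emulation is fragile, offered as my reading of the change rather than anything stated in the commit: when indices contains duplicates, a scatter update keeps only one of the conflicting writes, while TensorScatterAdd must accumulate all of them. A minimal 1-D illustration:

#include <vector>

// Scatter update: last write wins for duplicate indices.
void ScatterUpdate1D(std::vector<float> *data, const std::vector<int> &idx, const std::vector<float> &upd) {
  for (size_t k = 0; k < idx.size(); ++k) (*data)[idx[k]] = upd[k];
}
// Scatter add: duplicate indices accumulate.
void ScatterAdd1D(std::vector<float> *data, const std::vector<int> &idx, const std::vector<float> &upd) {
  for (size_t k = 0; k < idx.size(); ++k) (*data)[idx[k]] += upd[k];
}
// With idx = {0, 0} and upd = {1.f, 2.f} on a zero tensor, the update variant
// leaves 2 (or 1, depending on write order) at position 0, while add leaves 3.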

View File

@@ -0,0 +1,135 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.h"
#include <numeric>
#include <memory>
#include <functional>
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "ops/tensor_scatter_add.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cuh"
namespace mindspore::lite {
int TensorScatterAddTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) {
if (in_tensors.size() != INPUT_SIZE3) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_;
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size() << " : " << op_name_;
return RET_ERROR;
}
return RET_OK;
}
int TensorScatterAddTensorRT::AddInnerOp(TensorRTContext *ctx) {
if (in_tensors_[0].IsConst()) {
ITensorHelper scatter_input;
scatter_input.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[0], op_name_);
scatter_input.format_ = Format::NCHW;
ctx->RegisterTensor(scatter_input, in_tensors_[0].Name());
}
nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input(ctx, 1).trt_tensor_,
input(ctx, INPUT_SIZE2).trt_tensor_};
auto plugin = std::make_shared<TensorScatterAddPlugin>(input(ctx, 0).trt_tensor_->getName(), device_id_);
nvinfer1::IPluginV2Layer *scatter_layer = ctx->network()->addPluginV2(inputTensors, 3, *plugin);
if (scatter_layer == nullptr) {
MS_LOG(ERROR) << "addScatter failed for TensorRT.";
return RET_ERROR;
}
nvinfer1::ITensor *out_tensor = scatter_layer->getOutput(0);
ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name());
this->layer_ = scatter_layer;
return RET_OK;
}
REGISTER_TENSORRT_PLUGIN(TensorScatterAddPluginCreater);
template class TensorRTPluginCreater<TensorScatterAddPlugin>;
template <class T>
nvinfer1::PluginFieldCollection TensorRTPluginCreater<T>::field_collection_{};
template <class T>
std::vector<nvinfer1::PluginField> TensorRTPluginCreater<T>::fields_;
int TensorScatterAddPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) noexcept {
return RunCudaTensorScatterAdd(inputDesc, inputs, outputs, stream);
}
int TensorScatterAddPlugin::RunCudaTensorScatterAdd(const nvinfer1::PluginTensorDesc *inputDesc,
const void *const *inputs, void *const *outputs,
cudaStream_t stream) {
nvinfer1::Dims input_dims = inputDesc[0].dims;
size_t input_num = std::accumulate(input_dims.d, input_dims.d + input_dims.nbDims, 1, std::multiplies<int>());
nvinfer1::Dims update_dims = inputDesc[INPUT_SIZE2].dims;
size_t update_num = std::accumulate(update_dims.d, update_dims.d + update_dims.nbDims, 1, std::multiplies<int>());
size_t indice_dim_0 = inputDesc[1].dims.d[0];
size_t indice_dim_1 = inputDesc[1].dims.d[1];
int block_size = 1;
for (int i = indice_dim_1; i != input_dims.nbDims; ++i) {
block_size *= input_dims.d[i];
}
std::vector<int> indice_stride(indice_dim_1, 0);
indice_stride[indice_stride.size() - 1] = block_size;
for (int i = indice_dim_1 - 1; i > 0; --i) {
indice_stride[i - 1] = indice_stride[i] * input_dims.d[i];
}
int *indice_stride_dptr{nullptr};
cudaMalloc(&indice_stride_dptr, indice_stride.size() * sizeof(int));
cudaMemcpy(indice_stride_dptr, indice_stride.data(), indice_stride.size() * sizeof(int), cudaMemcpyHostToDevice);
int *input_shape_dptr{nullptr};
cudaMalloc(&input_shape_dptr, input_dims.nbDims * sizeof(int));
cudaMemcpy(input_shape_dptr, input_dims.d, input_dims.nbDims * sizeof(int), cudaMemcpyHostToDevice);
cudaMemcpy(outputs[0], inputs[0], input_num * sizeof(float), cudaMemcpyDeviceToDevice);
TensorScatterArithmetic(TensorScatterArithmeticFunctionType::TENSOR_SCATTER_FUNC_ADD,
static_cast<const float *>(inputs[0]), static_cast<const int *>(inputs[1]),
static_cast<const float *>(inputs[INPUT_SIZE2]), static_cast<float *>(outputs[0]), block_size,
update_num, input_num, indice_dim_0, indice_dim_1, indice_stride_dptr, input_shape_dptr,
device_id_, stream);
cudaFree(indice_stride_dptr);
cudaFree(input_shape_dptr);
return RET_OK;
}
nvinfer1::IPluginV2DynamicExt *TensorScatterAddPlugin::clone() const noexcept {
auto *plugin = new TensorScatterAddPlugin(*this);
plugin->setPluginNamespace(name_space_.c_str());
return plugin;
}
bool TensorScatterAddPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc,
int nbInputs, int nbOutputs) noexcept {
if (tensorsDesc[pos].format != nvinfer1::TensorFormat::kLINEAR) {
return false;
}
return true;
}
size_t TensorScatterAddPlugin::getSerializationSize() const noexcept { return 0; }
void TensorScatterAddPlugin::serialize(void *buffer) const noexcept {}
REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterAdd, TensorScatterAddTensorRT)
} // namespace mindspore::lite
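RunCudaTensorScatterAdd first copies the input into the output, then rebuilds the device-side metadata the CUDA launcher expects: block_size is the number of elements addressed by one indices row, and indice_stride is the flat output stride of each indexed dimension. A worked example of that computation; the shapes are hypothetical and only mirror the host code above.

#include <vector>

// Host-side metadata computation mirroring RunCudaTensorScatterAdd.
std::vector<int> IndiceStride(const std::vector<int> &input_dims, int indice_dim_1, int *block_size) {
  *block_size = 1;
  for (size_t i = static_cast<size_t>(indice_dim_1); i < input_dims.size(); ++i) {
    *block_size *= input_dims[i];
  }
  std::vector<int> stride(indice_dim_1, 0);
  stride[indice_dim_1 - 1] = *block_size;
  for (int i = indice_dim_1 - 1; i > 0; --i) {
    stride[i - 1] = stride[i] * input_dims[i];
  }
  return stride;
}
// Example: input_dims = {4, 5, 6} and indices of shape [N, 2] give
// indice_dim_1 = 2, block_size = 6, stride = {30, 6}; an indices row (a, b)
// accumulates a block of 6 update values starting at flat output offset
// a * 30 + b * 6, on top of the input copy made by the cudaMemcpy above.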

View File

@@ -0,0 +1,74 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_TENSOR_SCATTER_ADD_TENSORRT_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_TENSOR_SCATTER_ADD_TENSORRT_H_
#include <string>
#include <vector>
#include <algorithm>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h"
namespace mindspore::lite {
class TensorScatterAddTensorRT : public TensorRTOp {
public:
TensorScatterAddTensorRT(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors, std::string name)
: TensorRTOp(base_operator, in_tensors, out_tensors, name) {}
~TensorScatterAddTensorRT() override = default;
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) override;
};
constexpr auto TENSORSCATTERADD_PLUGIN_NAME{"TensorScatterAddPlugin"};
class TensorScatterAddPlugin : public TensorRTPlugin {
public:
TensorScatterAddPlugin(const std::string &name, int device_id)
: TensorRTPlugin(name, std::string(TENSORSCATTERADD_PLUGIN_NAME), device_id) {}
TensorScatterAddPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
: TensorRTPlugin(std::string(name), std::string(TENSORSCATTERADD_PLUGIN_NAME)) {}
TensorScatterAddPlugin(const char *name, const void *serialData, size_t serialLength)
: TensorRTPlugin(std::string(name), std::string(TENSORSCATTERADD_PLUGIN_NAME)) {}
TensorScatterAddPlugin() = delete;
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void *buffer) const noexcept override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept override;
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
noexcept override {
return inputTypes[0];
}
private:
int RunCudaTensorScatterAdd(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs,
void *const *outputs, cudaStream_t stream);
};
class TensorScatterAddPluginCreater : public TensorRTPluginCreater<TensorScatterAddPlugin> {
public:
TensorScatterAddPluginCreater() : TensorRTPluginCreater(std::string(TENSORSCATTERADD_PLUGIN_NAME)) {}
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_TENSOR_SCATTER_ADD_TENSORRT_H_