fix gpu random op performance

baochong 2022-11-30 11:18:04 +08:00 committed by baochong
parent a1231c9f05
commit a0c3c53d18
1 changed file with 72 additions and 38 deletions


@@ -16,11 +16,20 @@
 #include "random_op_impl.cuh"
 #include "include/cuda_fp16.h"
+__global__ void SetupKernel(int seed, curandState *globalState) {
+  auto id = blockIdx.x * blockDim.x + threadIdx.x;
+  curand_init(seed, id, 0, &globalState[id]);
+}
 template <typename T>
-__global__ void NormalKernel(int seed, curandState *globalState, T *output, size_t count) {
-  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    curand_init(seed, i, 0, &globalState[i]);
-    output[i] = (T)curand_normal(&globalState[i]);
+__global__ void NormalKernel(curandState *globalState, T *output, size_t count) {
+  auto id = blockIdx.x * blockDim.x + threadIdx.x;
+  auto localState = globalState[id];
+  while (id < count) {
+    globalState[id] = localState;
+    output[id] = (T)curand_normal(&localState);
+    id += blockDim.x * gridDim.x;
   }
   return;
 }
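This first hunk is the core of the fix: instead of calling curand_init for every output element inside the sampling loop, a new SetupKernel seeds one curandState per thread once, and NormalKernel then draws from a register-resident copy of that state in a grid-stride loop, writing the state back as it goes. curand_init performs an expensive subsequence skip-ahead, so hoisting it out of the per-element path is where the speedup comes from; the hunks below apply the same rewrite to the remaining sampling kernels and their host-side launchers.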
@@ -28,62 +37,83 @@ __global__ void NormalKernel(int seed, curandState *globalState, T *output, size
 __device__ bool dev_error_res = false;
 template <typename T>
-__global__ void UniformIntKernel(int seed, curandState *globalState, T *input1, size_t input_size_1,
-                                 T *input2, size_t input_size_2, T *output, size_t count) {
+__global__ void UniformIntKernel(curandState *globalState, T *input1, size_t input_size_1, T *input2,
+                                 size_t input_size_2, T *output, size_t count) {
   if (!(input1[0] < input2[0])) {
     dev_error_res = false;
     return;
   }
-  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    curand_init(seed, i, 0, &globalState[i]);
-    output[i] = (T)(curand_uniform(&globalState[i]) * (input2[0] - input1[0])) + input1[0];
+  auto id = blockIdx.x * blockDim.x + threadIdx.x;
+  auto localState = globalState[id];
+  while (id < count) {
+    globalState[id] = localState;
+    output[id] = (T)(curand_uniform(&localState) * (input2[0] - input1[0])) + input1[0];
+    id += blockDim.x * gridDim.x;
   }
   dev_error_res = true;
   return;
 }
 template <typename T>
-__global__ void UniformRealKernel(int seed, curandState *globalState, T *output, size_t count) {
-  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    curand_init(seed, i, 0, &globalState[i]);
-    output[i] = (T)curand_uniform(&globalState[i]);
+__global__ void UniformRealKernel(curandState *globalState, T *output, size_t count) {
+  auto id = blockIdx.x * blockDim.x + threadIdx.x;
+  auto localState = globalState[id];
+  while (id < count) {
+    globalState[id] = localState;
+    output[id] = (T)curand_uniform(&localState);
+    id += blockDim.x * gridDim.x;
   }
   return;
 }
 template<typename S>
-__global__ void TruncatedNormalKernel(int seed, curandState *globalState, S *output, size_t count) {
-  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+__global__ void TruncatedNormalKernel(curandState *globalState, S *output, size_t count) {
+  auto id = blockIdx.x * blockDim.x + threadIdx.x;
+  auto localState = globalState[id];
+  while (id < count) {
     S random_data;
-    curand_init(seed, i, 0, &globalState[i]);
-    random_data = (S)curand_normal(&globalState[i]);
+    globalState[id] = localState;
+    random_data = (S)curand_normal(&localState);
+    auto curState = localState;
     do {
-      random_data = (S)curand_normal(&globalState[i]);
+      random_data = (S)curand_normal(&curState);
     }while(random_data < -(S)0.2 || random_data > (S)0.2);
-    output[i] = random_data;
+    output[id] = random_data;
+    id += blockDim.x * gridDim.x;
   }
   return;
 }
 template <typename R, typename T>
-__global__ void RandomPoissonKernel(int seed, curandState *globalState, R *rate, int rate_size, T *output,
-                                    size_t count) {
-  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    curand_init(seed, i, 0, &globalState[i]);
-    auto j = i % rate_size;
-    output[i] = (T)curand_poisson(&globalState[i], rate[j]);
+__global__ void RandomPoissonKernel(curandState *globalState, R *rate, int rate_size, T *output, size_t count) {
+  auto id = blockIdx.x * blockDim.x + threadIdx.x;
+  auto localState = globalState[id];
+  while (id < count) {
+    auto j = id % rate_size;
+    globalState[id] = localState;
+    output[id] = (T)curand_poisson(&localState, rate[j]);
+    id += blockDim.x * gridDim.x;
   }
   return;
 }
 template <typename T>
-__global__ void StandardLaplaceKernel(int seed, curandState *globalState, T *output, size_t count, T min_num) {
-  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    curand_init(seed, i, 0, &globalState[i]);
-    T temp = (T)(curand_uniform(&globalState[i]) * 2 - 1);
+__global__ void StandardLaplaceKernel(curandState *globalState, T *output, size_t count, T min_num) {
+  auto id = blockIdx.x * blockDim.x + threadIdx.x;
+  auto localState = globalState[id];
+  while (id < count) {
+    globalState[id] = localState;
+    T temp = (T)(curand_uniform(&localState) * 2 - 1);
     T temp2 = temp < 0 ? temp + min_num : temp - min_num;
     T sign = std::copysignf(1.0, temp2);
-    output[i] = -sign * std::log(1.0 - std::abs(temp2));
+    output[id] = -sign * std::log(1.0 - std::abs(temp2));
+    id += blockDim.x * gridDim.x;
   }
   return;
 }
@@ -99,7 +129,8 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si
   } else {
     RNG_seed = static_cast<int>(rd());
   }
-  NormalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState, output, count);
+  SetupKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState);
+  NormalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(globalState, output, count);
   return;
 }
@@ -116,8 +147,9 @@ bool UniformInt(int seed, int seed2, curandState *globalState, T *input1, size_t
     RNG_seed = static_cast<int>(rd());
   }
   bool host_error_res = false;
+  SetupKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState);
   UniformIntKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>
-    (RNG_seed, globalState, input1, input_size_1, input2, input_size_2, output, count);
+    (globalState, input1, input_size_1, input2, input_size_2, output, count);
   cudaDeviceSynchronize();
   cudaMemcpyFromSymbol(&host_error_res, dev_error_res, sizeof(bool));
   return host_error_res;
@@ -134,7 +166,8 @@ void UniformReal(int seed, int seed2, curandState *globalState, T *output, size_
   } else {
     RNG_seed = static_cast<int>(rd());
   }
-  UniformRealKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState, output, count);
+  SetupKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState);
+  UniformRealKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(globalState, output, count);
   return;
 }
@@ -149,7 +182,8 @@ void TruncatedNormal(int seed, int seed2, curandState *globalState, S *output, s
   } else {
     RNG_seed = static_cast<int>(rd());
   }
-  TruncatedNormalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState, output, count);
+  SetupKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState);
+  TruncatedNormalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(globalState, output, count);
   return;
 }
@@ -165,8 +199,8 @@ void RandomPoisson(int seed, int seed2, curandState *globalState, R *rate, int64
   } else {
     RNG_seed = static_cast<int>(rd());
  }
-  RandomPoissonKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState, rate, rate_size,
-                                                                          output, count);
+  SetupKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState);
+  RandomPoissonKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(globalState, rate, rate_size, output, count);
   return;
 }
@@ -182,8 +216,8 @@ void StandardLaplace(int seed, int seed2, curandState *globalState, T *output, s
     RNG_seed = static_cast<int>(rd());
   }
   T min_num = std::nextafter(0, 1);
-  StandardLaplaceKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState, output, count,
-                                                                            min_num);
+  SetupKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, globalState);
+  StandardLaplaceKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(globalState, output, count, min_num);
   return;
 }
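For reference, here is a minimal self-contained sketch of the setup-once / generate-many cuRAND pattern the commit adopts. It is illustrative, not the committed code: the launch shape (kBlocks, kThreads), the fixed seed, and writing the state back once after the loop rather than on every iteration are assumptions made to keep the example short, and MindSpore's GET_BLOCKS/GET_THREADS helpers are replaced with constants.

// sketch: seed curandState once per thread, then draw many samples from a
// register-resident copy of the state (the approach used in this commit)
#include <cstdio>
#include <cuda_runtime.h>
#include <curand_kernel.h>

__global__ void SetupKernel(int seed, curandState *states) {
  const unsigned tid = blockIdx.x * blockDim.x + threadIdx.x;
  curand_init(seed, tid, 0, &states[tid]);  // expensive skip-ahead, done once per thread
}

__global__ void NormalKernel(curandState *states, float *out, size_t count) {
  const unsigned tid = blockIdx.x * blockDim.x + threadIdx.x;
  curandState localState = states[tid];  // register copy: no global-memory round trips
  for (size_t i = tid; i < count; i += blockDim.x * gridDim.x) {
    out[i] = curand_normal(&localState);  // cheap generator update per sample
  }
  states[tid] = localState;  // persist so a later launch continues the sequence
}

int main() {
  constexpr unsigned kThreads = 256, kBlocks = 64;  // illustrative launch shape
  constexpr size_t kCount = 1 << 20;
  curandState *states = nullptr;
  float *out = nullptr;
  cudaMalloc(&states, kBlocks * kThreads * sizeof(curandState));
  cudaMalloc(&out, kCount * sizeof(float));
  SetupKernel<<<kBlocks, kThreads>>>(42, states);            // one-time setup
  NormalKernel<<<kBlocks, kThreads>>>(states, out, kCount);  // bulk generation
  cudaDeviceSynchronize();
  float host[4];
  cudaMemcpy(host, out, sizeof(host), cudaMemcpyDeviceToHost);
  printf("%f %f %f %f\n", host[0], host[1], host[2], host[3]);
  cudaFree(states);
  cudaFree(out);
  return 0;
}

Compiled with nvcc, this prints four standard-normal samples; the curand_init skip-ahead runs once per thread, while each additional sample costs only a generator update, which is the trade the commit makes for every random op above.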