!33490 [MS][ops]fix fastgelu cuda10 bug

Merge pull request !33490 from KXiong/master
This commit is contained in:
i-robot 2022-04-25 07:55:30 +00:00 committed by Gitee
commit 76db1413e6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 20 additions and 6 deletions

View File

@ -33,8 +33,8 @@ template <>
// FastGeluKernel (half overload): for every pos < size, reads x = input_addr[pos] and writes
//   output_addr[pos] = x / (1 + exp(-1.702 * |x|)) * exp(0.851 * (x - |x|)).
//
// NOTE(review): this text looks like a diff rendered without +/- markers, so `up` and `down`
// are each declared twice below. Presumably the first pair (using __habs) is the removed
// version and the second pair (using std::abs on a float) is its replacement -- confirm against
// the repository before treating this body as compilable source.
__global__ void FastGeluKernel(size_t size, half *input_addr, half *output_addr) {
// Each thread starts at its global index and advances by the total number of launched threads.
// NOTE(review): the products blockIdx.x * blockDim.x and blockDim.x * gridDim.x are formed
// before widening to size_t; if those built-ins are 32-bit unsigned, very large launches could
// wrap -- consider casting one operand to size_t. TODO confirm.
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
half x = input_addr[pos];
// Variant 1: |x| taken directly on the half value with __habs.
half up = hexp(half(0.851) * (x - __habs(x)));
half down = half(1) + hexp(half(-1.702) * __habs(x));
// Variant 2: |x| taken in float (half -> float, std::abs, float -> half); same formula otherwise.
half up = hexp(half(0.851) * (x - half(std::abs(__half2float(x)))));
half down = half(1) + hexp(half(-1.702) * half(std::abs(__half2float(x))));
// x - |x| is 0 for x >= 0 and 2x for x < 0, so `up` is exp(0) = 1 for non-negative inputs.
output_addr[pos] = x / down * up;
}
}
@ -43,8 +43,15 @@ template <>
// FastGeluKernel (half2 overload): same formula as the scalar half version, applied to both
// lanes of each half2 element:
//   output_addr[pos] = x / (1 + exp(-1.702 * |x|)) * exp(0.851 * (x - |x|)).
// `size` is compared against the half2 index, so it counts half2 elements here.
//
// NOTE(review): this text looks like a diff rendered without +/- markers, so `up` and `down`
// are each declared twice below. Presumably the __habs2 pair is the removed version and the
// float2-based pair is its replacement -- confirm against the repository before treating this
// body as compilable source.
__global__ void FastGeluKernel(size_t size, half2 *input_addr, half2 *output_addr) {
// Each thread starts at its global index and advances by the total number of launched threads.
// NOTE(review): the index products are formed before widening to size_t; if the built-ins are
// 32-bit unsigned, very large launches could wrap -- consider casting to size_t. TODO confirm.
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
half2 x = input_addr[pos];
// Variant 1: per-lane |x| taken directly on the half2 value with __habs2.
half2 up = h2exp(half2(0.851, 0.851) * (x - __habs2(x)));
half2 down = half2(1, 1) + h2exp(half2(-1.702, -1.702) * __habs2(x));
// Variant 2: per-lane |x| computed in float. Widen both lanes to float2, take std::abs of each
// component, then convert back to half2 with __float22half2_rn.
float2 float2_x = __half22float2(x);
float2 abs_x_res;
abs_x_res.x = std::abs(float2_x.x);
abs_x_res.y = std::abs(float2_x.y);
half2 half2_x_abs = __float22half2_rn(abs_x_res);
half2 up = h2exp(half2(0.851, 0.851) * (x - half2_x_abs));
half2 down = half2(1, 1) + h2exp(half2(-1.702, -1.702) * half2_x_abs);
// x - |x| is 0 for non-negative lanes and 2x for negative lanes.
output_addr[pos] = x / down * up;
}
}
@ -86,7 +93,14 @@ __global__ void FastGeluGradKernel(size_t size, half2 *dy_addr, half2 *x_addr, h
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
half2 x = x_addr[pos];
half2 exp_res = h2exp(half2(-1.702, -1.702) * x);
half2 div_up = exp_res + half2(1.702, 1.702) * x * exp_res + h2exp(half2(1.702, 1.702) * (x - __habs2(x)));
float2 float2_x = __half22float2(x);
float2 abs_x_res;
abs_x_res.x = std::abs(float2_x.x);
abs_x_res.y = std::abs(float2_x.y);
half2 half2_x_abs = __float22half2_rn(abs_x_res);
half2 div_up = exp_res + half2(1.702, 1.702) * x * exp_res + h2exp(half2(1.702, 1.702) * (x - half2_x_abs));
half2 div_down = (exp_res + half2(1, 1)) * (exp_res + half2(1, 1));
half2 y_res = div_up / div_down;
dx_addr[pos] = dy_addr[pos] * y_res;
@ -98,7 +112,7 @@ __global__ void FastGeluGradKernel(size_t size, half *dy_addr, half *x_addr, hal
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
half x = x_addr[pos];
half exp_res = hexp(half(-1.702) * x);
half div_up = exp_res + half(1.702) * x * exp_res + hexp(half(1.702) * (x - __habs(x)));
half div_up = exp_res + half(1.702) * x * exp_res + hexp(half(1.702) * (x - half(std::abs(__half2float(x)))));
half div_down = (exp_res + half(1)) * (exp_res + half(1));
half y_res = div_up / div_down;
dx_addr[pos] = dy_addr[pos] * y_res;