!33490 [MS][ops]fix fastgelu cuda10 bug
Merge pull request !33490 from KXiong/master
This commit is contained in:
commit
76db1413e6
|
@ -33,8 +33,8 @@ template <>
|
|||
__global__ void FastGeluKernel(size_t size, half *input_addr, half *output_addr) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
half x = input_addr[pos];
|
||||
half up = hexp(half(0.851) * (x - __habs(x)));
|
||||
half down = half(1) + hexp(half(-1.702) * __habs(x));
|
||||
half up = hexp(half(0.851) * (x - half(std::abs(__half2float(x)))));
|
||||
half down = half(1) + hexp(half(-1.702) * half(std::abs(__half2float(x))));
|
||||
output_addr[pos] = x / down * up;
|
||||
}
|
||||
}
|
||||
|
@ -43,8 +43,15 @@ template <>
|
|||
__global__ void FastGeluKernel(size_t size, half2 *input_addr, half2 *output_addr) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
half2 x = input_addr[pos];
|
||||
half2 up = h2exp(half2(0.851, 0.851) * (x - __habs2(x)));
|
||||
half2 down = half2(1, 1) + h2exp(half2(-1.702, -1.702) * __habs2(x));
|
||||
|
||||
float2 float2_x = __half22float2(x);
|
||||
float2 abs_x_res;
|
||||
abs_x_res.x = std::abs(float2_x.x);
|
||||
abs_x_res.y = std::abs(float2_x.y);
|
||||
half2 half2_x_abs = __float22half2_rn(abs_x_res);
|
||||
|
||||
half2 up = h2exp(half2(0.851, 0.851) * (x - half2_x_abs));
|
||||
half2 down = half2(1, 1) + h2exp(half2(-1.702, -1.702) * half2_x_abs);
|
||||
output_addr[pos] = x / down * up;
|
||||
}
|
||||
}
|
||||
|
@ -86,7 +93,14 @@ __global__ void FastGeluGradKernel(size_t size, half2 *dy_addr, half2 *x_addr, h
|
|||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
half2 x = x_addr[pos];
|
||||
half2 exp_res = h2exp(half2(-1.702, -1.702) * x);
|
||||
half2 div_up = exp_res + half2(1.702, 1.702) * x * exp_res + h2exp(half2(1.702, 1.702) * (x - __habs2(x)));
|
||||
|
||||
float2 float2_x = __half22float2(x);
|
||||
float2 abs_x_res;
|
||||
abs_x_res.x = std::abs(float2_x.x);
|
||||
abs_x_res.y = std::abs(float2_x.y);
|
||||
half2 half2_x_abs = __float22half2_rn(abs_x_res);
|
||||
|
||||
half2 div_up = exp_res + half2(1.702, 1.702) * x * exp_res + h2exp(half2(1.702, 1.702) * (x - half2_x_abs));
|
||||
half2 div_down = (exp_res + half2(1, 1)) * (exp_res + half2(1, 1));
|
||||
half2 y_res = div_up / div_down;
|
||||
dx_addr[pos] = dy_addr[pos] * y_res;
|
||||
|
@ -98,7 +112,7 @@ __global__ void FastGeluGradKernel(size_t size, half *dy_addr, half *x_addr, hal
|
|||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
half x = x_addr[pos];
|
||||
half exp_res = hexp(half(-1.702) * x);
|
||||
half div_up = exp_res + half(1.702) * x * exp_res + hexp(half(1.702) * (x - __habs(x)));
|
||||
half div_up = exp_res + half(1.702) * x * exp_res + hexp(half(1.702) * (x - half(std::abs(__half2float(x)))));
|
||||
half div_down = (exp_res + half(1)) * (exp_res + half(1));
|
||||
half y_res = div_up / div_down;
|
||||
dx_addr[pos] = dy_addr[pos] * y_res;
|
||||
|
|
Loading…
Reference in New Issue