Fix for metal tanh. (#2475)
This commit is contained in:
parent
b60faebea4
commit
c09afc211c
|
@ -56,7 +56,7 @@ template <typename T> METAL_FUNC T gelu(T x) {
|
|||
T x_cube = x_sq * x;
|
||||
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(precise::tanh(beta)));
|
||||
}
|
||||
template <typename T> METAL_FUNC T relu(T in){
|
||||
if (in < 0) {
|
||||
|
@ -154,7 +154,6 @@ UNARY_OP(floor)
|
|||
UNARY_OP(round)
|
||||
UNARY_OP(gelu_erf)
|
||||
UNARY_OP(erf)
|
||||
UNARY_OP(tanh)
|
||||
UNARY_OP(recip)
|
||||
UNARY_OP(relu)
|
||||
UNARY_OP(sign)
|
||||
|
@ -164,6 +163,11 @@ UNARY(id, half, copy_f16, copy_f16_strided)
|
|||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||
|
||||
// tanh may create NaN on large values, e.g. 45, rather than outputting 1.
|
||||
// This has been an issue for the encodec example.
|
||||
UNARY(precise::tanh, float, tanh_f32, tanh_f32_strided);
|
||||
UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||
COPY2D(copy2d_i64, int64_t)
|
||||
|
@ -185,7 +189,6 @@ BFLOAT_UNARY_OP(floor)
|
|||
BFLOAT_UNARY_OP(round)
|
||||
BFLOAT_UNARY_OP(gelu_erf)
|
||||
BFLOAT_UNARY_OP(erf)
|
||||
BFLOAT_UNARY_OP(tanh)
|
||||
BFLOAT_UNARY_OP(recip)
|
||||
BFLOAT_UNARY_OP(relu)
|
||||
BFLOAT_UNARY_OP(sign)
|
||||
|
@ -193,5 +196,7 @@ BFLOAT_UNARY_OP(sigmoid)
|
|||
|
||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
|
||||
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
|
||||
|
||||
COPY2D(copy2d_bf16, bfloat)
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue