Fix/wgpu/tanh (#1090)

This commit is contained in:
Louis Fortier-Dubois 2023-12-21 14:12:49 -05:00 committed by GitHub
parent d82e6b157b
commit b070706310
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 1 deletions

View File

@ -6,6 +6,8 @@ use std::fmt::Display;
pub enum Function {
Powf(Elem),
Erf(Elem),
#[cfg(target_os = "macos")]
SafeTanh(Elem),
}
impl Display for Function {
@ -13,6 +15,8 @@ impl Display for Function {
match self {
Function::Powf(elem) => format_powf(f, elem),
Function::Erf(elem) => format_erf(f, elem),
#[cfg(target_os = "macos")]
Function::SafeTanh(elem) => format_safe_tanh(f, elem),
}
}
}
@ -69,3 +73,19 @@ fn erf(x: {elem}) -> {elem} {{
"
))
}
#[cfg(target_os = "macos")]
fn format_safe_tanh(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
f.write_fmt(format_args!(
"
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
fn safe_tanh(x: {elem}) -> {elem} {{
if x > 43.0 {{
return 1.0;
}} else {{
return tanh(x);
}}
}}
"
))
}

View File

@ -162,6 +162,10 @@ impl ElemWiseKernelCodegen<BodyPhase> {
Operator::Erf { input: _, out: _ } => {
register_function(Function::Erf(Elem::F32));
}
#[cfg(target_os = "macos")]
Operator::Tanh { input: _, out: _ } => {
register_function(Function::SafeTanh(Elem::F32))
}
_ => {}
}
self.operations.push(ops.clone());

View File

@ -164,7 +164,12 @@ impl Display for Operator {
Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")),
Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")),
Operator::Tanh { input, out } => {
f.write_fmt(format_args!("let {out} = tanh({input});"))
#[cfg(target_os = "macos")]
let result = f.write_fmt(format_args!("let {out} = safe_tanh({input});"));
#[cfg(not(target_os = "macos"))]
let result = f.write_fmt(format_args!("let {out} = tanh({input});"));
result
}
Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")),
Operator::Recip { input, out } => {