mirror of https://github.com/tracel-ai/burn.git
Merge branch 'main' into fix/devauto_tests
This commit is contained in:
commit
bb18452136
|
@ -6,6 +6,8 @@ use std::fmt::Display;
|
|||
pub enum Function {
|
||||
Powf(Elem),
|
||||
Erf(Elem),
|
||||
#[cfg(target_os = "macos")]
|
||||
SafeTanh(Elem),
|
||||
}
|
||||
|
||||
impl Display for Function {
|
||||
|
@ -13,6 +15,8 @@ impl Display for Function {
|
|||
match self {
|
||||
Function::Powf(elem) => format_powf(f, elem),
|
||||
Function::Erf(elem) => format_erf(f, elem),
|
||||
#[cfg(target_os = "macos")]
|
||||
Function::SafeTanh(elem) => format_safe_tanh(f, elem),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -69,3 +73,19 @@ fn erf(x: {elem}) -> {elem} {{
|
|||
"
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
fn format_safe_tanh(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
|
||||
fn safe_tanh(x: {elem}) -> {elem} {{
|
||||
if x > 43.0 {{
|
||||
return 1.0;
|
||||
}} else {{
|
||||
return tanh(x);
|
||||
}}
|
||||
}}
|
||||
"
|
||||
))
|
||||
}
|
||||
|
|
|
@ -162,6 +162,10 @@ impl ElemWiseKernelCodegen<BodyPhase> {
|
|||
Operator::Erf { input: _, out: _ } => {
|
||||
register_function(Function::Erf(Elem::F32));
|
||||
}
|
||||
#[cfg(target_os = "macos")]
|
||||
Operator::Tanh { input: _, out: _ } => {
|
||||
register_function(Function::SafeTanh(Elem::F32))
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
self.operations.push(ops.clone());
|
||||
|
|
|
@ -164,7 +164,12 @@ impl Display for Operator {
|
|||
Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")),
|
||||
Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")),
|
||||
Operator::Tanh { input, out } => {
|
||||
f.write_fmt(format_args!("let {out} = tanh({input});"))
|
||||
#[cfg(target_os = "macos")]
|
||||
let result = f.write_fmt(format_args!("let {out} = safe_tanh({input});"));
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let result = f.write_fmt(format_args!("let {out} = tanh({input});"));
|
||||
|
||||
result
|
||||
}
|
||||
Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")),
|
||||
Operator::Recip { input, out } => {
|
||||
|
|
Loading…
Reference in New Issue