mirror of https://github.com/tracel-ai/burn.git
Feat/cube/fma (#1947)
This commit is contained in:
parent
cb6b5e7183
commit
82a883a57d
|
@ -52,10 +52,10 @@ fn bench<B: Backend>(
|
|||
token: Option<&str>,
|
||||
) {
|
||||
const D: usize = 3;
|
||||
let batch_size = 4048;
|
||||
let m = 320;
|
||||
let k = 4;
|
||||
let n = 324;
|
||||
let batch_size = 32;
|
||||
let m = 256;
|
||||
let k = 1024;
|
||||
let n = 256;
|
||||
let shape_lhs = [batch_size, m, k].into();
|
||||
let shape_rhs = [batch_size, k, n].into();
|
||||
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
use crate::{
|
||||
ir::{FmaOperator, Operation, Operator},
|
||||
prelude::{CubeContext, CubePrimitive, ExpandElement},
|
||||
unexpanded,
|
||||
};
|
||||
|
||||
/// Fused multiply-add `A*B+C`.
|
||||
#[allow(unused_variables)]
|
||||
pub fn fma<C: CubePrimitive>(a: C, b: C, c: C) -> C {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Expand method of [fma].
|
||||
#[allow(unused_variables)]
|
||||
pub fn fma_expand<C: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
a: ExpandElement,
|
||||
b: ExpandElement,
|
||||
c: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(a.item());
|
||||
|
||||
let out = *output;
|
||||
let a = *a;
|
||||
let b = *b;
|
||||
let c = *c;
|
||||
|
||||
context.register(Operation::Operator(Operator::Fma(FmaOperator {
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
out,
|
||||
})));
|
||||
|
||||
output
|
||||
}
|
|
@ -2,10 +2,12 @@ mod assignation;
|
|||
mod base;
|
||||
mod binary;
|
||||
mod cmp;
|
||||
mod fma;
|
||||
mod unary;
|
||||
|
||||
pub use assignation::*;
|
||||
pub use base::*;
|
||||
pub use binary::*;
|
||||
pub use cmp::*;
|
||||
pub use fma::*;
|
||||
pub use unary::*;
|
||||
|
|
|
@ -26,6 +26,7 @@ pub enum Operation {
|
|||
#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags
|
||||
pub enum Operator {
|
||||
Add(BinaryOperator),
|
||||
Fma(FmaOperator),
|
||||
Sub(BinaryOperator),
|
||||
Mul(BinaryOperator),
|
||||
Div(BinaryOperator),
|
||||
|
@ -133,6 +134,15 @@ pub struct ReadGlobalWithLayoutOperator {
|
|||
pub tensor_layout_pos: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct FmaOperator {
|
||||
pub a: Variable,
|
||||
pub b: Variable,
|
||||
pub c: Variable,
|
||||
pub out: Variable,
|
||||
}
|
||||
|
||||
impl From<Operator> for Operation {
|
||||
fn from(val: Operator) -> Self {
|
||||
Operation::Operator(val)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::{
|
||||
BinaryOperator, ClampOperator, InitOperator, Item, Operation, Operator, Subcube, UnaryOperator,
|
||||
Variable,
|
||||
BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator, Subcube,
|
||||
UnaryOperator, Variable,
|
||||
};
|
||||
|
||||
pub type Vectorization = u8;
|
||||
|
@ -33,6 +33,7 @@ impl Operator {
|
|||
Operator::Max(op) => Operator::Max(op.vectorize(vectorization)),
|
||||
Operator::Min(op) => Operator::Min(op.vectorize(vectorization)),
|
||||
Operator::Add(op) => Operator::Add(op.vectorize(vectorization)),
|
||||
Operator::Fma(op) => Operator::Fma(op.vectorize(vectorization)),
|
||||
Operator::Index(op) => Operator::Index(op.vectorize(vectorization)),
|
||||
Operator::UncheckedIndex(op) => Operator::UncheckedIndex(op.vectorize(vectorization)),
|
||||
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
|
||||
|
@ -140,6 +141,17 @@ impl ClampOperator {
|
|||
}
|
||||
}
|
||||
|
||||
impl FmaOperator {
|
||||
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
Self {
|
||||
a: self.a.vectorize(vectorization),
|
||||
b: self.b.vectorize(vectorization),
|
||||
c: self.c.vectorize(vectorization),
|
||||
out: self.out.vectorize(vectorization),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Variable {
|
||||
pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Self {
|
||||
match self {
|
||||
|
|
|
@ -346,6 +346,12 @@ impl CudaCompiler {
|
|||
gpu::Operator::Floor(op) => Instruction::Floor(self.compile_unary(op)),
|
||||
gpu::Operator::Ceil(op) => Instruction::Ceil(self.compile_unary(op)),
|
||||
gpu::Operator::Remainder(_op) => todo!(),
|
||||
gpu::Operator::Fma(op) => Instruction::Fma {
|
||||
a: self.compile_variable(op.a),
|
||||
b: self.compile_variable(op.b),
|
||||
c: self.compile_variable(op.c),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,12 @@ pub enum Instruction {
|
|||
},
|
||||
Modulo(BinaryInstruction),
|
||||
Add(BinaryInstruction),
|
||||
Fma {
|
||||
a: Variable,
|
||||
b: Variable,
|
||||
c: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Div(BinaryInstruction),
|
||||
Mul(BinaryInstruction),
|
||||
Sub(BinaryInstruction),
|
||||
|
@ -246,7 +252,38 @@ for (uint {i} = {start}; {i} < {end}; {i}++) {{
|
|||
))
|
||||
}
|
||||
Instruction::Wrap(it) => f.write_fmt(format_args!("{it}")),
|
||||
Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
|
||||
Instruction::Wmma(it) => f.write_fmt(format_args!("{it}")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Fma;
|
||||
|
||||
impl Fma {
|
||||
fn format(
|
||||
f: &mut core::fmt::Formatter<'_>,
|
||||
a: &Variable,
|
||||
b: &Variable,
|
||||
c: &Variable,
|
||||
out: &Variable,
|
||||
) -> core::fmt::Result {
|
||||
let num = match out.item() {
|
||||
super::Item::Vec4(_) => 4,
|
||||
super::Item::Vec3(_) => 3,
|
||||
super::Item::Vec2(_) => 2,
|
||||
super::Item::Scalar(_) => 1,
|
||||
};
|
||||
|
||||
for i in 0..num {
|
||||
let ai = a.index(i);
|
||||
let bi = b.index(i);
|
||||
let ci = c.index(i);
|
||||
let outi = out.index(i);
|
||||
|
||||
f.write_fmt(format_args!("{outi} = fma({ai}, {bi}, {ci});\n"))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -195,6 +195,12 @@ impl TraceBuilder {
|
|||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operator::Fma(op) => {
|
||||
mark(&op.a, &mut local_tensor_ids_input);
|
||||
mark(&op.b, &mut local_tensor_ids_input);
|
||||
mark(&op.c, &mut local_tensor_ids_input);
|
||||
mark(&op.out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Max(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
|
|
|
@ -516,6 +516,12 @@ impl WgslCompiler {
|
|||
rhs: self.compile_variable(op.rhs),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
cube::Operator::Fma(op) => wgsl::Instruction::Fma {
|
||||
a: self.compile_variable(op.a),
|
||||
b: self.compile_variable(op.b),
|
||||
c: self.compile_variable(op.c),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
cube::Operator::Index(op) => wgsl::Instruction::Index {
|
||||
lhs: self.compile_variable(op.lhs),
|
||||
rhs: self.compile_variable(op.rhs),
|
||||
|
|
|
@ -26,6 +26,12 @@ pub enum Instruction {
|
|||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Fma {
|
||||
a: Variable,
|
||||
b: Variable,
|
||||
c: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
If {
|
||||
cond: Variable,
|
||||
instructions: Vec<Instruction>,
|
||||
|
@ -239,6 +245,9 @@ impl Display for Instruction {
|
|||
Instruction::Add { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n"))
|
||||
}
|
||||
Instruction::Fma { a, b, c, out } => {
|
||||
f.write_fmt(format_args!("{out} = fma({a}, {b}, {c});\n"))
|
||||
}
|
||||
Instruction::Min { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = min({lhs}, {rhs});\n"))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue