Feat/cube/fma (#1947)

This commit is contained in:
Nathaniel Simard 2024-07-02 08:32:39 -04:00 committed by GitHub
parent cb6b5e7183
commit 82a883a57d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 130 additions and 6 deletions

View File

@ -52,10 +52,10 @@ fn bench<B: Backend>(
token: Option<&str>,
) {
const D: usize = 3;
let batch_size = 4048;
let m = 320;
let k = 4;
let n = 324;
let batch_size = 32;
let m = 256;
let k = 1024;
let n = 256;
let shape_lhs = [batch_size, m, k].into();
let shape_rhs = [batch_size, k, n].into();

View File

@ -0,0 +1,36 @@
use crate::{
ir::{FmaOperator, Operation, Operator},
prelude::{CubeContext, CubePrimitive, ExpandElement},
unexpanded,
};
/// Fused multiply-add: `a * b + c`.
///
/// This is the unexpanded front-end signature; its body is only the
/// `unexpanded!()` placeholder. The operation itself is registered by
/// [fma_expand].
#[allow(unused_variables)] // The arguments exist for the signature only; the body never reads them.
pub fn fma<C>(a: C, b: C, c: C) -> C
where
    C: CubePrimitive,
{
    unexpanded!()
}
/// Expand method of [fma].
///
/// Creates the output through `context.create_local`, registers an
/// [Operator::Fma] operation on `context` that reads `a`, `b` and `c` and
/// writes to that output, then returns the output element.
pub fn fma_expand<C: CubePrimitive>(
    context: &mut CubeContext,
    a: ExpandElement,
    b: ExpandElement,
    c: ExpandElement,
) -> ExpandElement {
    // The output takes its item from `a` only; `b` and `c` are not consulted.
    let output = context.create_local(a.item());

    context.register(Operation::Operator(Operator::Fma(FmaOperator {
        a: *a,
        b: *b,
        c: *c,
        out: *output,
    })));

    output
}

View File

@ -2,10 +2,12 @@ mod assignation;
mod base;
mod binary;
mod cmp;
mod fma;
mod unary;
pub use assignation::*;
pub use base::*;
pub use binary::*;
pub use cmp::*;
pub use fma::*;
pub use unary::*;

View File

@ -26,6 +26,7 @@ pub enum Operation {
#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags
pub enum Operator {
Add(BinaryOperator),
Fma(FmaOperator),
Sub(BinaryOperator),
Mul(BinaryOperator),
Div(BinaryOperator),
@ -133,6 +134,15 @@ pub struct ReadGlobalWithLayoutOperator {
pub tensor_layout_pos: usize,
}
/// Operands and destination of a fused multiply-add, `out = a * b + c`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FmaOperator {
    /// First factor of the product.
    pub a: Variable,
    /// Second factor of the product.
    pub b: Variable,
    /// Value added to the product `a * b`.
    pub c: Variable,
    /// Variable that receives the result.
    pub out: Variable,
}
impl From<Operator> for Operation {
fn from(val: Operator) -> Self {
Operation::Operator(val)

View File

@ -1,6 +1,6 @@
use super::{
BinaryOperator, ClampOperator, InitOperator, Item, Operation, Operator, Subcube, UnaryOperator,
Variable,
BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator, Subcube,
UnaryOperator, Variable,
};
pub type Vectorization = u8;
@ -33,6 +33,7 @@ impl Operator {
Operator::Max(op) => Operator::Max(op.vectorize(vectorization)),
Operator::Min(op) => Operator::Min(op.vectorize(vectorization)),
Operator::Add(op) => Operator::Add(op.vectorize(vectorization)),
Operator::Fma(op) => Operator::Fma(op.vectorize(vectorization)),
Operator::Index(op) => Operator::Index(op.vectorize(vectorization)),
Operator::UncheckedIndex(op) => Operator::UncheckedIndex(op.vectorize(vectorization)),
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
@ -140,6 +141,17 @@ impl ClampOperator {
}
}
impl FmaOperator {
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
Self {
a: self.a.vectorize(vectorization),
b: self.b.vectorize(vectorization),
c: self.c.vectorize(vectorization),
out: self.out.vectorize(vectorization),
}
}
}
impl Variable {
pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Self {
match self {

View File

@ -346,6 +346,12 @@ impl CudaCompiler {
gpu::Operator::Floor(op) => Instruction::Floor(self.compile_unary(op)),
gpu::Operator::Ceil(op) => Instruction::Ceil(self.compile_unary(op)),
gpu::Operator::Remainder(_op) => todo!(),
gpu::Operator::Fma(op) => Instruction::Fma {
a: self.compile_variable(op.a),
b: self.compile_variable(op.b),
c: self.compile_variable(op.c),
out: self.compile_variable(op.out),
},
}
}

View File

@ -27,6 +27,12 @@ pub enum Instruction {
},
Modulo(BinaryInstruction),
Add(BinaryInstruction),
Fma {
a: Variable,
b: Variable,
c: Variable,
out: Variable,
},
Div(BinaryInstruction),
Mul(BinaryInstruction),
Sub(BinaryInstruction),
@ -246,7 +252,38 @@ for (uint {i} = {start}; {i} < {end}; {i}++) {{
))
}
Instruction::Wrap(it) => f.write_fmt(format_args!("{it}")),
Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
Instruction::Wmma(it) => f.write_fmt(format_args!("{it}")),
}
}
}
/// Formatter for the `Instruction::Fma` variant.
struct Fma;

impl Fma {
    /// Writes one `out[i] = fma(a[i], b[i], c[i]);` statement per component.
    ///
    /// The component count comes from the item of `out`: 1 for a scalar and
    /// 2, 3 or 4 for the vector items. Each operand is accessed through its
    /// `index` method. The first write error stops the output and is
    /// returned.
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        a: &Variable,
        b: &Variable,
        c: &Variable,
        out: &Variable,
    ) -> core::fmt::Result {
        let components = match out.item() {
            super::Item::Scalar(_) => 1,
            super::Item::Vec2(_) => 2,
            super::Item::Vec3(_) => 3,
            super::Item::Vec4(_) => 4,
        };

        for component in 0..components {
            let lhs = a.index(component);
            let rhs = b.index(component);
            let addend = c.index(component);
            let target = out.index(component);
            writeln!(f, "{target} = fma({lhs}, {rhs}, {addend});")?;
        }

        Ok(())
    }
}

View File

@ -195,6 +195,12 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operator::Fma(op) => {
mark(&op.a, &mut local_tensor_ids_input);
mark(&op.b, &mut local_tensor_ids_input);
mark(&op.c, &mut local_tensor_ids_input);
mark(&op.out, &mut local_tensor_ids_output);
}
Operator::Max(op) => mark_binary(
op,
&mut local_tensor_ids_input,

View File

@ -516,6 +516,12 @@ impl WgslCompiler {
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
cube::Operator::Fma(op) => wgsl::Instruction::Fma {
a: self.compile_variable(op.a),
b: self.compile_variable(op.b),
c: self.compile_variable(op.c),
out: self.compile_variable(op.out),
},
cube::Operator::Index(op) => wgsl::Instruction::Index {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),

View File

@ -26,6 +26,12 @@ pub enum Instruction {
rhs: Variable,
out: Variable,
},
Fma {
a: Variable,
b: Variable,
c: Variable,
out: Variable,
},
If {
cond: Variable,
instructions: Vec<Instruction>,
@ -239,6 +245,9 @@ impl Display for Instruction {
Instruction::Add { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n"))
}
Instruction::Fma { a, b, c, out } => {
f.write_fmt(format_args!("{out} = fma({a}, {b}, {c});\n"))
}
Instruction::Min { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = min({lhs}, {rhs});\n"))
}