Revert "Perf: cube reuse shape and strides (#1939)" (#1967)

This reverts commit ad81a997af.
This commit is contained in:
Nathaniel Simard 2024-07-04 16:16:17 -04:00 committed by GitHub
parent 679cfd6dfb
commit f709858a8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 109 deletions

View File

@ -1,30 +1,6 @@
use burn_cube::ir as cube;
use std::fmt::Display;
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
pub struct ConstantShape {
pub position: usize,
pub dim: usize,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
pub struct ConstantStride {
pub position: usize,
pub dim: usize,
}
impl Display for ConstantStride {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("stride_{}_{}", self.position, self.dim))
}
}
impl Display for ConstantShape {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("shape_{}_{}", self.position, self.dim))
}
}
#[derive(Debug, Clone)]
pub enum Variable {
SubgroupSize,
@ -65,8 +41,6 @@ pub enum Variable {
NumWorkgroupsX,
NumWorkgroupsY,
NumWorkgroupsZ,
ConstantShape(ConstantShape),
ConstantStride(ConstantStride),
}
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
@ -132,8 +106,6 @@ impl Variable {
Variable::WorkgroupSize => true,
Variable::NumWorkgroups => true,
Variable::SubgroupSize => true,
Variable::ConstantShape(_) => true,
Variable::ConstantStride(_) => true,
}
}
pub fn index(&self, index: usize) -> IndexedVariable {
@ -183,8 +155,6 @@ impl Variable {
Self::NumWorkgroupsY => Item::Scalar(Elem::U32),
Self::NumWorkgroupsZ => Item::Scalar(Elem::U32),
Self::SubgroupSize => Item::Scalar(Elem::U32),
Self::ConstantShape(_) => Item::Scalar(Elem::U32),
Self::ConstantStride(_) => Item::Scalar(Elem::U32),
}
}
pub fn elem(&self) -> Elem {
@ -292,8 +262,6 @@ impl Display for Variable {
Variable::WorkgroupSize => f.write_str("workgroup_size_no_axis"),
Variable::NumWorkgroups => f.write_str("num_workgroups_no_axis"),
Variable::SubgroupSize => f.write_str("subgroup_size"),
Variable::ConstantShape(val) => f.write_fmt(format_args!("{val}")),
Variable::ConstantStride(val) => f.write_fmt(format_args!("{val}")),
}
}
}

View File

@ -1,6 +1,4 @@
use hashbrown::HashSet;
use super::{ConstantShape, ConstantStride, Instruction};
use super::Instruction;
use std::fmt::Display;
/// A body is composed of a list of [instructions](Instruction).
@ -14,8 +12,6 @@ pub struct Body {
pub id: bool,
pub stride: bool,
pub shape: bool,
pub constant_shapes: HashSet<ConstantShape>,
pub constant_strides: HashSet<ConstantStride>,
}
impl Display for Body {
@ -33,24 +29,6 @@ impl Display for Body {
f.write_str("let rank_2: u32 = rank * 2u;\n")?;
}
for shape in self.constant_shapes.iter() {
let declaration = Instruction::Shape {
dim: super::Variable::ConstantScalar(shape.dim as f64, super::Elem::U32),
position: shape.position,
out: super::Variable::ConstantShape(*shape),
};
f.write_fmt(format_args!("let {declaration};\n"))?;
}
for stride in self.constant_strides.iter() {
let declaration = Instruction::Stride {
dim: super::Variable::ConstantScalar(stride.dim as f64, super::Elem::U32),
position: stride.position,
out: super::Variable::ConstantStride(*stride),
};
f.write_fmt(format_args!("let {declaration};\n"))?;
}
for ops in self.instructions.iter() {
f.write_fmt(format_args!("{ops}"))?;
}

View File

@ -1,8 +1,7 @@
use super::{shader::ComputeShader, Item, SharedMemory};
use super::{ConstantShape, ConstantStride, LocalArray, Subgroup};
use super::{LocalArray, Subgroup};
use crate::compiler::wgsl;
use burn_cube::ir as cube;
use hashbrown::HashSet;
/// Wgsl Compiler.
#[derive(Clone, Default)]
@ -15,16 +14,14 @@ pub struct WgslCompiler {
workgroup_id: bool,
rank: bool,
id: bool,
stride: bool,
shape: bool,
num_workgroups: bool,
workgroup_id_no_axis: bool,
workgroup_size_no_axis: bool,
num_workgroup_no_axis: bool,
shared_memories: Vec<SharedMemory>,
local_arrays: Vec<LocalArray>,
shape: bool,
stride: bool,
constant_shapes: HashSet<ConstantShape>,
constant_strides: HashSet<ConstantStride>,
}
impl core::fmt::Debug for WgslCompiler {
@ -57,21 +54,12 @@ impl WgslCompiler {
let instructions = self.compile_scope(&mut value.body);
let extensions = register_extensions(&instructions);
let mut constant_shapes = HashSet::new();
let mut constant_strides = HashSet::new();
core::mem::swap(&mut self.constant_shapes, &mut constant_shapes);
core::mem::swap(&mut self.constant_strides, &mut constant_strides);
let body = wgsl::Body {
instructions,
rank: true,
id: self.id,
stride: self.stride,
shape: self.shape,
constant_shapes,
constant_strides,
};
wgsl::ComputeShader {
@ -438,58 +426,28 @@ impl WgslCompiler {
match metadata {
cube::Metadata::Stride { dim, var, out } => {
self.stride = true;
let position = match var {
cube::Variable::GlobalInputArray(idx, _) => idx as usize,
cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
};
let dim = self.compile_variable(dim);
let out = self.compile_variable(out);
match dim {
wgsl::Variable::ConstantScalar(val, _) => {
let var = ConstantStride {
position,
dim: val as usize,
};
self.constant_strides.insert(var);
wgsl::Instruction::Assign {
input: wgsl::Variable::ConstantStride(var),
out,
}
}
_ => wgsl::Instruction::Stride { dim, position, out },
wgsl::Instruction::Stride {
dim: self.compile_variable(dim),
position,
out: self.compile_variable(out),
}
}
cube::Metadata::Shape { dim, var, out } => {
self.shape = true;
let position = match var {
cube::Variable::GlobalInputArray(idx, _) => idx as usize,
cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
_ => panic!("Only Input and Output have a shape, got: {:?}", var),
_ => panic!("Only Input and Output have a shape, got {:?}", var),
};
let dim = self.compile_variable(dim);
let out = self.compile_variable(out);
match dim {
wgsl::Variable::ConstantScalar(val, _) => {
let var = ConstantShape {
position,
dim: val as usize,
};
self.constant_shapes.insert(var);
wgsl::Instruction::Assign {
input: wgsl::Variable::ConstantShape(var),
out,
}
}
_ => wgsl::Instruction::Shape { dim, position, out },
wgsl::Instruction::Shape {
dim: self.compile_variable(dim),
position,
out: self.compile_variable(out),
}
}
cube::Metadata::ArrayLength { var, out } => wgsl::Instruction::ArrayLength {