mirror of https://github.com/tracel-ai/burn.git
This reverts commit ad81a997af
.
This commit is contained in:
parent
679cfd6dfb
commit
f709858a8b
|
@ -1,30 +1,6 @@
|
|||
use burn_cube::ir as cube;
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
|
||||
pub struct ConstantShape {
|
||||
pub position: usize,
|
||||
pub dim: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
|
||||
pub struct ConstantStride {
|
||||
pub position: usize,
|
||||
pub dim: usize,
|
||||
}
|
||||
|
||||
impl Display for ConstantStride {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!("stride_{}_{}", self.position, self.dim))
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ConstantShape {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!("shape_{}_{}", self.position, self.dim))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Variable {
|
||||
SubgroupSize,
|
||||
|
@ -65,8 +41,6 @@ pub enum Variable {
|
|||
NumWorkgroupsX,
|
||||
NumWorkgroupsY,
|
||||
NumWorkgroupsZ,
|
||||
ConstantShape(ConstantShape),
|
||||
ConstantStride(ConstantStride),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
|
@ -132,8 +106,6 @@ impl Variable {
|
|||
Variable::WorkgroupSize => true,
|
||||
Variable::NumWorkgroups => true,
|
||||
Variable::SubgroupSize => true,
|
||||
Variable::ConstantShape(_) => true,
|
||||
Variable::ConstantStride(_) => true,
|
||||
}
|
||||
}
|
||||
pub fn index(&self, index: usize) -> IndexedVariable {
|
||||
|
@ -183,8 +155,6 @@ impl Variable {
|
|||
Self::NumWorkgroupsY => Item::Scalar(Elem::U32),
|
||||
Self::NumWorkgroupsZ => Item::Scalar(Elem::U32),
|
||||
Self::SubgroupSize => Item::Scalar(Elem::U32),
|
||||
Self::ConstantShape(_) => Item::Scalar(Elem::U32),
|
||||
Self::ConstantStride(_) => Item::Scalar(Elem::U32),
|
||||
}
|
||||
}
|
||||
pub fn elem(&self) -> Elem {
|
||||
|
@ -292,8 +262,6 @@ impl Display for Variable {
|
|||
Variable::WorkgroupSize => f.write_str("workgroup_size_no_axis"),
|
||||
Variable::NumWorkgroups => f.write_str("num_workgroups_no_axis"),
|
||||
Variable::SubgroupSize => f.write_str("subgroup_size"),
|
||||
Variable::ConstantShape(val) => f.write_fmt(format_args!("{val}")),
|
||||
Variable::ConstantStride(val) => f.write_fmt(format_args!("{val}")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
use hashbrown::HashSet;
|
||||
|
||||
use super::{ConstantShape, ConstantStride, Instruction};
|
||||
use super::Instruction;
|
||||
use std::fmt::Display;
|
||||
|
||||
/// A body is composed of a list of [instructions](Instruction).
|
||||
|
@ -14,8 +12,6 @@ pub struct Body {
|
|||
pub id: bool,
|
||||
pub stride: bool,
|
||||
pub shape: bool,
|
||||
pub constant_shapes: HashSet<ConstantShape>,
|
||||
pub constant_strides: HashSet<ConstantStride>,
|
||||
}
|
||||
|
||||
impl Display for Body {
|
||||
|
@ -33,24 +29,6 @@ impl Display for Body {
|
|||
f.write_str("let rank_2: u32 = rank * 2u;\n")?;
|
||||
}
|
||||
|
||||
for shape in self.constant_shapes.iter() {
|
||||
let declaration = Instruction::Shape {
|
||||
dim: super::Variable::ConstantScalar(shape.dim as f64, super::Elem::U32),
|
||||
position: shape.position,
|
||||
out: super::Variable::ConstantShape(*shape),
|
||||
};
|
||||
f.write_fmt(format_args!("let {declaration};\n"))?;
|
||||
}
|
||||
|
||||
for stride in self.constant_strides.iter() {
|
||||
let declaration = Instruction::Stride {
|
||||
dim: super::Variable::ConstantScalar(stride.dim as f64, super::Elem::U32),
|
||||
position: stride.position,
|
||||
out: super::Variable::ConstantStride(*stride),
|
||||
};
|
||||
f.write_fmt(format_args!("let {declaration};\n"))?;
|
||||
}
|
||||
|
||||
for ops in self.instructions.iter() {
|
||||
f.write_fmt(format_args!("{ops}"))?;
|
||||
}
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
use super::{shader::ComputeShader, Item, SharedMemory};
|
||||
use super::{ConstantShape, ConstantStride, LocalArray, Subgroup};
|
||||
use super::{LocalArray, Subgroup};
|
||||
use crate::compiler::wgsl;
|
||||
use burn_cube::ir as cube;
|
||||
use hashbrown::HashSet;
|
||||
|
||||
/// Wgsl Compiler.
|
||||
#[derive(Clone, Default)]
|
||||
|
@ -15,16 +14,14 @@ pub struct WgslCompiler {
|
|||
workgroup_id: bool,
|
||||
rank: bool,
|
||||
id: bool,
|
||||
stride: bool,
|
||||
shape: bool,
|
||||
num_workgroups: bool,
|
||||
workgroup_id_no_axis: bool,
|
||||
workgroup_size_no_axis: bool,
|
||||
num_workgroup_no_axis: bool,
|
||||
shared_memories: Vec<SharedMemory>,
|
||||
local_arrays: Vec<LocalArray>,
|
||||
shape: bool,
|
||||
stride: bool,
|
||||
constant_shapes: HashSet<ConstantShape>,
|
||||
constant_strides: HashSet<ConstantStride>,
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for WgslCompiler {
|
||||
|
@ -57,21 +54,12 @@ impl WgslCompiler {
|
|||
|
||||
let instructions = self.compile_scope(&mut value.body);
|
||||
let extensions = register_extensions(&instructions);
|
||||
|
||||
let mut constant_shapes = HashSet::new();
|
||||
let mut constant_strides = HashSet::new();
|
||||
|
||||
core::mem::swap(&mut self.constant_shapes, &mut constant_shapes);
|
||||
core::mem::swap(&mut self.constant_strides, &mut constant_strides);
|
||||
|
||||
let body = wgsl::Body {
|
||||
instructions,
|
||||
rank: true,
|
||||
id: self.id,
|
||||
stride: self.stride,
|
||||
shape: self.shape,
|
||||
constant_shapes,
|
||||
constant_strides,
|
||||
};
|
||||
|
||||
wgsl::ComputeShader {
|
||||
|
@ -438,58 +426,28 @@ impl WgslCompiler {
|
|||
match metadata {
|
||||
cube::Metadata::Stride { dim, var, out } => {
|
||||
self.stride = true;
|
||||
|
||||
let position = match var {
|
||||
cube::Variable::GlobalInputArray(idx, _) => idx as usize,
|
||||
cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
|
||||
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
|
||||
};
|
||||
|
||||
let dim = self.compile_variable(dim);
|
||||
let out = self.compile_variable(out);
|
||||
|
||||
match dim {
|
||||
wgsl::Variable::ConstantScalar(val, _) => {
|
||||
let var = ConstantStride {
|
||||
position,
|
||||
dim: val as usize,
|
||||
};
|
||||
self.constant_strides.insert(var);
|
||||
|
||||
wgsl::Instruction::Assign {
|
||||
input: wgsl::Variable::ConstantStride(var),
|
||||
out,
|
||||
}
|
||||
}
|
||||
_ => wgsl::Instruction::Stride { dim, position, out },
|
||||
wgsl::Instruction::Stride {
|
||||
dim: self.compile_variable(dim),
|
||||
position,
|
||||
out: self.compile_variable(out),
|
||||
}
|
||||
}
|
||||
cube::Metadata::Shape { dim, var, out } => {
|
||||
self.shape = true;
|
||||
|
||||
let position = match var {
|
||||
cube::Variable::GlobalInputArray(idx, _) => idx as usize,
|
||||
cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
|
||||
_ => panic!("Only Input and Output have a shape, got: {:?}", var),
|
||||
_ => panic!("Only Input and Output have a shape, got {:?}", var),
|
||||
};
|
||||
|
||||
let dim = self.compile_variable(dim);
|
||||
let out = self.compile_variable(out);
|
||||
|
||||
match dim {
|
||||
wgsl::Variable::ConstantScalar(val, _) => {
|
||||
let var = ConstantShape {
|
||||
position,
|
||||
dim: val as usize,
|
||||
};
|
||||
self.constant_shapes.insert(var);
|
||||
|
||||
wgsl::Instruction::Assign {
|
||||
input: wgsl::Variable::ConstantShape(var),
|
||||
out,
|
||||
}
|
||||
}
|
||||
_ => wgsl::Instruction::Shape { dim, position, out },
|
||||
wgsl::Instruction::Shape {
|
||||
dim: self.compile_variable(dim),
|
||||
position,
|
||||
out: self.compile_variable(out),
|
||||
}
|
||||
}
|
||||
cube::Metadata::ArrayLength { var, out } => wgsl::Instruction::ArrayLength {
|
||||
|
|
Loading…
Reference in New Issue