diff --git a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs index 83722789f..9e766ae27 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs @@ -69,7 +69,7 @@ pub enum ReadingStrategy { } impl Scope { - fn unroll_lazy(&mut self) { + pub fn unroll_lazy(&mut self) { let map = self.map.map.clone(); let map = map.read().unwrap(); diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 919670524..7a8946e15 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -1,7 +1,10 @@ -use super::{ElementWise, ElementWiseState}; +use super::{ElementWise, ElementWiseState, GraphOptimization}; use crate::{ - element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement, - IntElement, JitBackend, Runtime, + element::JitElement, + fusion::{ElementWiseBuilder, GraphBuilder}, + kernel, + tensor::JitTensor, + FloatElement, IntElement, JitBackend, Runtime, }; use burn_compute::client::ComputeClient; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; @@ -16,6 +19,7 @@ use serde::{Deserialize, Serialize}; pub enum JitOptimization { /// Element wise optimization. ElementWise(ElementWise), + Graph(GraphOptimization), } /// Fusion optimization state type for JIT. @@ -34,18 +38,21 @@ where fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle>) { match self { Self::ElementWise(op) => op.execute(context), + Self::Graph(op) => op.execute(context), } } fn len(&self) -> usize { match self { Self::ElementWise(op) => op.len(), + Self::Graph(op) => op.len(), } } fn to_state(&self) -> JitOptimizationState { match self { Self::ElementWise(value) => JitOptimizationState::ElementWise(value.to_state()), + Self::Graph(value) => todo!(), } } @@ -111,7 +118,7 @@ impl FusionRuntime for FusionJitRuntime { fn optimizations( device: R::Device, ) -> Vec>> { - vec![Box::new(ElementWiseBuilder::::new(device))] + vec![Box::new(GraphBuilder::::new(device))] } } diff --git a/crates/burn-jit/src/fusion/graph/base.rs b/crates/burn-jit/src/fusion/graph/base.rs index b189161e5..c6a70a6f6 100644 --- a/crates/burn-jit/src/fusion/graph/base.rs +++ b/crates/burn-jit/src/fusion/graph/base.rs @@ -1,11 +1,16 @@ +use std::sync::Arc; + use crate::{ fusion::{tracing::TraceBuilder, JitOptimization}, gpu::{gpu, LazyProcedure, Scope, Variable, WorkgroupSize}, Runtime, }; +use burn_common::id::IdGenerator; use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus}; use burn_tensor::repr::{BinaryOperationDescription, OperationDescription, TensorId}; +use super::{GraphKernelFactory, GraphOptimization}; + #[derive(Clone)] pub struct LaunchSettings { workgroup_size: WorkgroupSizeSettings, @@ -61,7 +66,7 @@ pub trait Node: Send + Sync { fn any_input_visited(&self) -> bool; fn any_output_visited(&self) -> bool; fn launch_settings(&self) -> LaunchSettings; - fn register(&self, builder: &mut TraceBuilder); + fn trace(&self, builder: &mut TraceBuilder); } pub enum MergeSubGraphResult { @@ -72,6 +77,7 @@ pub enum MergeSubGraphResult { pub trait SubGraph: Send + Sync { fn register(self: Box, node: NodeBoxed) -> MergingResult; fn merge(self: Box, other: SubGraphBoxed) -> MergeSubGraphResult; + fn trace(&self, builder: &mut TraceBuilder); } #[derive(new)] @@ -122,6 +128,11 @@ impl SubGraph for TwoNodesSubGraph { ), } } + + fn trace(&self, builder: &mut TraceBuilder) { + self.graph_1.trace(builder); + self.graph_2.trace(builder); + } } impl SubGraph for SingleNodeSubGraph { @@ -160,6 +171,10 @@ impl SubGraph for SingleNodeSubGraph { } } } + + fn trace(&self, builder: &mut TraceBuilder) { + self.node.trace(builder); + } } pub enum MergingResult { @@ -238,7 +253,8 @@ impl Node for FloatAddOp { self.settings.clone() } - fn register(&self, builder: &mut TraceBuilder) { + fn trace(&self, builder: &mut TraceBuilder) { + println!("FloatAddOp registering"); let lhs = builder.input(&self.desc.lhs, Variable::Id); let rhs = builder.input(&self.desc.rhs, Variable::Id); let out = builder.output(&self.desc.out, Variable::Id); @@ -253,6 +269,10 @@ impl Node for FloatAddOp { fn expand(&self, scope: &mut Scope, position: Option) -> Variable { let position = position.unwrap_or(Variable::Id); + println!("FloatAddOp expand lazy"); + // let lhs = self.lhs; + // let rhs = self.rhs; + // let out = self.out; let lhs_input = self.lhs; let rhs_input = self.rhs; @@ -260,10 +280,14 @@ impl Node for FloatAddOp { let rhs = scope.create_local(self.rhs.item()); let out = self.out; + println!("INPUT {lhs_input:?}"); + // Is local but should not. gpu!(scope, lhs = lhs_input[position]); gpu!(scope, rhs = rhs_input[position]); gpu!(scope, out = lhs + rhs); + gpu!(scope, out = lhs + rhs); + out } } @@ -277,16 +301,26 @@ pub type NodeBoxed = Box; pub struct GraphBuilder { graphs: Vec, - launch_settings: LaunchSettings, + launch_settings: Option, status: OptimizationStatus, device: R::Device, size: usize, } impl GraphBuilder { + pub fn new(device: R::Device) -> Self { + Self { + graphs: Vec::new(), + launch_settings: None, + status: OptimizationStatus::Open, + device, + size: 0, + } + } fn add_node(&mut self, node: Box) { if self.graphs.is_empty() { let settings = node.launch_settings(); + println!("First node"); self.graphs = vec![Box::new(SingleNodeSubGraph::new(node, settings))]; return; } @@ -332,13 +366,25 @@ impl GraphBuilder { if let Some(node) = node_current.take() { let node_settings = node.launch_settings(); - if self.launch_settings.is_compatible_with(&node_settings) { + if let Some(launch_settings) = &self.launch_settings { + if launch_settings.is_compatible_with(&node_settings) { + let launch_settings = + LaunchSettings::most_restrictive(launch_settings.clone(), node_settings); + + self.graphs.push(Box::new(SingleNodeSubGraph::new( + node, + launch_settings.clone(), + ))); + + self.launch_settings = Some(launch_settings); + self.size += 1; + } + } else { self.graphs.push(Box::new(SingleNodeSubGraph::new( node, - self.launch_settings.clone(), + node_settings.clone(), ))); - self.launch_settings = - LaunchSettings::most_restrictive(self.launch_settings.clone(), node_settings); + self.launch_settings = Some(node_settings); self.size += 1; } } else { @@ -391,11 +437,30 @@ impl OptimizationBuilder> for GraphBuilder { } }; + println!("Registering op {operation:?}"); self.add_node(node); } fn build(&self) -> JitOptimization { - todo!() + let mut builder = TraceBuilder::new(); + + for graph in self.graphs.iter() { + println!("Trace graph"); + graph.trace(&mut builder); + } + println!("Builder"); + + let trace = builder.build(); + let info = Arc::new(trace.compiling()); + + let grid = match self.launch_settings.as_ref().unwrap().workgroup_size { + WorkgroupSizeSettings::Any => WorkgroupSize::default(), + WorkgroupSizeSettings::Fixed(wk) => wk, + }; + let factory = GraphKernelFactory::new(IdGenerator::generate(), info.clone(), grid); + let optim = GraphOptimization::new(trace, self.size, self.device.clone(), factory); + + JitOptimization::Graph(optim) } fn reset(&mut self) { @@ -409,7 +474,7 @@ impl OptimizationBuilder> for GraphBuilder { fn properties(&self) -> OptimizationProperties { OptimizationProperties { - ready: true, + ready: self.size > 0, score: self.size as u64, } } diff --git a/crates/burn-jit/src/fusion/graph/optimization.rs b/crates/burn-jit/src/fusion/graph/optimization.rs index f54fe9c68..fc43f9110 100644 --- a/crates/burn-jit/src/fusion/graph/optimization.rs +++ b/crates/burn-jit/src/fusion/graph/optimization.rs @@ -1,8 +1,105 @@ -use crate::{fusion::tracing::Trace, Runtime}; +use std::{marker::PhantomData, sync::Arc}; + +use burn_compute::client::ComputeClient; +use burn_fusion::stream::Context; + +use crate::{ + codegen::{calculate_num_elems_dyn_rank, CompilationInfo, CompilationSettings}, + fusion::{ + kernel::{FusionKernel, FusionKernelFactory, OutputRuntimeInfo}, + tracing::Trace, + JitFusionHandle, + }, + gpu::WorkgroupSize, + kernel::elemwise_workgroup, + Runtime, +}; #[derive(new)] -pub struct SubGraphOptimization { +pub struct GraphOptimization { pub(super) trace: Trace, pub(super) num_operations: usize, pub(super) device: R::Device, + factory: GraphKernelFactory, +} + +impl GraphOptimization { + pub(crate) fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { + let client = R::client(&self.device); + + self.run_kernel(context, client, 0) + } + + fn run_kernel( + &mut self, + context: &mut Context<'_, JitFusionHandle>, + client: ComputeClient, + fastest_set_index: usize, + ) { + let info = self.trace.running(); + let kernel_set = &self.factory; + + let kernel = FusionKernel::create( + kernel_set, + &info, + context, + self.device.clone(), + client, + true, + ); + + kernel.execute(); + } + pub(crate) fn len(&self) -> usize { + self.num_operations + } +} + +impl FusionKernelFactory for GraphKernelFactory { + fn create( + &self, + handles_inputs: &[JitFusionHandle], + inputs: &[&burn_tensor::repr::TensorDescription], + outputs: &[&burn_tensor::repr::TensorDescription], + stateful: bool, // Should be set to false when running autotune. + ) -> crate::fusion::kernel::FusionKernel { + let workgroup_size_x = self.grid.x; + let workgroup_size_y = self.grid.y; + let workgroup_size = workgroup_size_x as usize; + for h in handles_inputs { + println!("h {:?}", h); + } + + for o in self.info.scope.operations.iter() { + println!("O {o:?}"); + } + + let settings = CompilationSettings::default(); + let factor = 1; + + let reference_tensor = outputs[0]; + let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape); + let workgroup = elemwise_workgroup(num_elems / factor, workgroup_size); + let output_infos = outputs.iter().enumerate().map(|(pos, tensor)| { + let size = calculate_num_elems_dyn_rank(&tensor.shape) + * self.info.outputs[pos].elem_size::(); + OutputRuntimeInfo::Array { size } + }); + + FusionKernel::new( + self.id.clone(), + self.info.clone(), + settings, + output_infos.collect(), + workgroup, + ) + } +} + +#[derive(new)] +pub struct GraphKernelFactory { + id: String, + info: Arc, + grid: WorkgroupSize, + _runtime: PhantomData, } diff --git a/crates/burn-jit/src/fusion/tracing/base.rs b/crates/burn-jit/src/fusion/tracing/base.rs index 5e2481f94..624e62982 100644 --- a/crates/burn-jit/src/fusion/tracing/base.rs +++ b/crates/burn-jit/src/fusion/tracing/base.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -#[derive(Default, Clone, Serialize, Deserialize)] +#[derive(Default, Clone, Serialize, Deserialize, Debug)] pub struct Scalars { pub(crate) num_float: usize, pub(crate) num_int: usize, diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index 9e7af250d..bad74bf3a 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -49,6 +49,7 @@ impl TraceBuilder { /// Create a variable from an input [tensor description](TensorDescription). pub fn input(&mut self, tensor: &TensorDescription, position: Variable) -> gpu::Variable { + println!("Input {tensor:?}"); let already_exists = self.tensors.contains_key(&tensor.id); let elem = tensor.dtype.into(); @@ -86,6 +87,7 @@ impl TraceBuilder { /// Create a variable from an output [tensor description](TensorDescription). pub fn output(&mut self, tensor: &TensorDescription, position: Variable) -> gpu::Variable { + println!("output {tensor:?}"); let elem = tensor.dtype.into(); // Update the tensor description to the new version. self.tensors @@ -137,7 +139,9 @@ impl TraceBuilder { } /// Build the [trace](Trace). - pub fn build(self) -> Trace { + pub fn build(mut self) -> Trace { + self.scope.unroll_lazy(); + let inputs = self.input_descriptions(); let outputs = self.output_descriptions(); let locals = outputs diff --git a/crates/burn-jit/src/fusion/tracing/trace.rs b/crates/burn-jit/src/fusion/tracing/trace.rs index 20973cb03..b05d08aef 100644 --- a/crates/burn-jit/src/fusion/tracing/trace.rs +++ b/crates/burn-jit/src/fusion/tracing/trace.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; /// captured [tensor operations](burn_tensor::repr::OperationDescription). /// /// A trace should be built using a [builder](super::TraceBuilder). -#[derive(new, Clone, Serialize, Deserialize)] +#[derive(new, Clone, Serialize, Deserialize, Debug)] pub struct Trace { inputs: Vec<(TensorDescription, gpu::Elem, gpu::Variable)>, output_writes: Vec<(TensorDescription, gpu::Elem, gpu::Variable)>, diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index 04aa39e73..48d0b715f 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -90,6 +90,7 @@ where } let source = kernel.compile().source; + println!("Source {source}"); let pipeline = self.compile_source(&source); self.pipelines.insert(kernel_id.clone(), pipeline.clone()); diff --git a/neovide_backtraces.log b/neovide_backtraces.log new file mode 100644 index 000000000..b8fc9093b --- /dev/null +++ b/neovide_backtraces.log @@ -0,0 +1,21 @@ +2024-05-16 13:52:10 - Neovide panicked with the message 'Could not parse event from neovim: invalid event format for event 'hl_attr_define' - [Integer(PosInt(2335)), Map([(String(Utf8String { s: Ok("bold") }), Boolean(true)), (String(Utf8String { s: Ok("italic") }), Boolean(true)), (String(Utf8String { s: Ok("underline") }), Boolean(true)), (String(Utf8String { s: Ok("foreground") }), Integer(PosInt(9601908))), (String(Utf8String { s: Ok("special") }), Integer(PosInt(10935295)))]), Map([(String(Utf8String { s: Ok("bold") }), Boolean(true)), (String(Utf8String { s: Ok("italic") }), Boolean(true)), (String(Utf8String { s: Ok("underline") }), Boolean(true)), (String(Utf8String { s: Ok("foreground") }), Integer(PosInt(245)))]), Array([Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("@variable") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(56)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("GruvboxGreenBold") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(696)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("GruvboxGreenBold") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(696)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("Comment") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(738)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Err(([68, 105, 97, 103, 110, 111, 115, 116, 105, 99, 85, 110, 100, 101, 114, 108, 105, 110, 148, 1, 205, 45, 134], Utf8Error { valid_up_to: 18, error_len: Some(1) })) })), (Nil, Integer(PosInt(5)))])])] - invalid event format Invalid highlight info format'. (File: src/error_handling.rs; Line: 20, Column: 5) + 0: + 1: + 2: + 3: + 4: + 5: + 6: + 7: + 8: + 9: + 10: + 11: + 12: + 13: + 14: + 15: + 16: + 17: + 18: +