This commit is contained in:
nathaniel 2024-05-16 12:59:15 -04:00
parent d1742a1be6
commit 5879c31a96
9 changed files with 214 additions and 19 deletions

View File

@ -69,7 +69,7 @@ pub enum ReadingStrategy {
}
impl Scope {
fn unroll_lazy(&mut self) {
pub fn unroll_lazy(&mut self) {
let map = self.map.map.clone();
let map = map.read().unwrap();

View File

@ -1,7 +1,10 @@
use super::{ElementWise, ElementWiseState};
use super::{ElementWise, ElementWiseState, GraphOptimization};
use crate::{
element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement,
IntElement, JitBackend, Runtime,
element::JitElement,
fusion::{ElementWiseBuilder, GraphBuilder},
kernel,
tensor::JitTensor,
FloatElement, IntElement, JitBackend, Runtime,
};
use burn_compute::client::ComputeClient;
use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
@ -16,6 +19,7 @@ use serde::{Deserialize, Serialize};
pub enum JitOptimization<R: Runtime> {
/// Element wise optimization.
ElementWise(ElementWise<R>),
Graph(GraphOptimization<R>),
}
/// Fusion optimization state type for JIT.
@ -34,18 +38,21 @@ where
fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle<R>>) {
match self {
Self::ElementWise(op) => op.execute(context),
Self::Graph(op) => op.execute(context),
}
}
fn len(&self) -> usize {
match self {
Self::ElementWise(op) => op.len(),
Self::Graph(op) => op.len(),
}
}
fn to_state(&self) -> JitOptimizationState {
match self {
Self::ElementWise(value) => JitOptimizationState::ElementWise(value.to_state()),
Self::Graph(value) => todo!(),
}
}
@ -111,7 +118,7 @@ impl<R: Runtime> FusionRuntime for FusionJitRuntime<R> {
fn optimizations(
device: R::Device,
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
vec![Box::new(ElementWiseBuilder::<R>::new(device))]
vec![Box::new(GraphBuilder::<R>::new(device))]
}
}

View File

@ -1,11 +1,16 @@
use std::sync::Arc;
use crate::{
fusion::{tracing::TraceBuilder, JitOptimization},
gpu::{gpu, LazyProcedure, Scope, Variable, WorkgroupSize},
Runtime,
};
use burn_common::id::IdGenerator;
use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus};
use burn_tensor::repr::{BinaryOperationDescription, OperationDescription, TensorId};
use super::{GraphKernelFactory, GraphOptimization};
#[derive(Clone)]
pub struct LaunchSettings {
workgroup_size: WorkgroupSizeSettings,
@ -61,7 +66,7 @@ pub trait Node: Send + Sync {
fn any_input_visited(&self) -> bool;
fn any_output_visited(&self) -> bool;
fn launch_settings(&self) -> LaunchSettings;
fn register(&self, builder: &mut TraceBuilder);
fn trace(&self, builder: &mut TraceBuilder);
}
pub enum MergeSubGraphResult {
@ -72,6 +77,7 @@ pub enum MergeSubGraphResult {
pub trait SubGraph: Send + Sync {
fn register(self: Box<Self>, node: NodeBoxed) -> MergingResult;
fn merge(self: Box<Self>, other: SubGraphBoxed) -> MergeSubGraphResult;
fn trace(&self, builder: &mut TraceBuilder);
}
#[derive(new)]
@ -122,6 +128,11 @@ impl SubGraph for TwoNodesSubGraph {
),
}
}
fn trace(&self, builder: &mut TraceBuilder) {
self.graph_1.trace(builder);
self.graph_2.trace(builder);
}
}
impl SubGraph for SingleNodeSubGraph {
@ -160,6 +171,10 @@ impl SubGraph for SingleNodeSubGraph {
}
}
}
fn trace(&self, builder: &mut TraceBuilder) {
self.node.trace(builder);
}
}
pub enum MergingResult {
@ -238,7 +253,8 @@ impl Node for FloatAddOp {
self.settings.clone()
}
fn register(&self, builder: &mut TraceBuilder) {
fn trace(&self, builder: &mut TraceBuilder) {
println!("FloatAddOp registering");
let lhs = builder.input(&self.desc.lhs, Variable::Id);
let rhs = builder.input(&self.desc.rhs, Variable::Id);
let out = builder.output(&self.desc.out, Variable::Id);
@ -253,6 +269,10 @@ impl Node for FloatAddOp {
fn expand(&self, scope: &mut Scope, position: Option<Variable>) -> Variable {
let position = position.unwrap_or(Variable::Id);
println!("FloatAddOp expand lazy");
// let lhs = self.lhs;
// let rhs = self.rhs;
// let out = self.out;
let lhs_input = self.lhs;
let rhs_input = self.rhs;
@ -260,10 +280,14 @@ impl Node for FloatAddOp {
let rhs = scope.create_local(self.rhs.item());
let out = self.out;
println!("INPUT {lhs_input:?}");
// Is local but should not.
gpu!(scope, lhs = lhs_input[position]);
gpu!(scope, rhs = rhs_input[position]);
gpu!(scope, out = lhs + rhs);
gpu!(scope, out = lhs + rhs);
out
}
}
@ -277,16 +301,26 @@ pub type NodeBoxed = Box<dyn Node>;
pub struct GraphBuilder<R: Runtime> {
graphs: Vec<SubGraphBoxed>,
launch_settings: LaunchSettings,
launch_settings: Option<LaunchSettings>,
status: OptimizationStatus,
device: R::Device,
size: usize,
}
impl<R: Runtime> GraphBuilder<R> {
pub fn new(device: R::Device) -> Self {
Self {
graphs: Vec::new(),
launch_settings: None,
status: OptimizationStatus::Open,
device,
size: 0,
}
}
fn add_node(&mut self, node: Box<dyn Node>) {
if self.graphs.is_empty() {
let settings = node.launch_settings();
println!("First node");
self.graphs = vec![Box::new(SingleNodeSubGraph::new(node, settings))];
return;
}
@ -332,13 +366,25 @@ impl<R: Runtime> GraphBuilder<R> {
if let Some(node) = node_current.take() {
let node_settings = node.launch_settings();
if self.launch_settings.is_compatible_with(&node_settings) {
if let Some(launch_settings) = &self.launch_settings {
if launch_settings.is_compatible_with(&node_settings) {
let launch_settings =
LaunchSettings::most_restrictive(launch_settings.clone(), node_settings);
self.graphs.push(Box::new(SingleNodeSubGraph::new(
node,
self.launch_settings.clone(),
launch_settings.clone(),
)));
self.launch_settings =
LaunchSettings::most_restrictive(self.launch_settings.clone(), node_settings);
self.launch_settings = Some(launch_settings);
self.size += 1;
}
} else {
self.graphs.push(Box::new(SingleNodeSubGraph::new(
node,
node_settings.clone(),
)));
self.launch_settings = Some(node_settings);
self.size += 1;
}
} else {
@ -391,11 +437,30 @@ impl<R: Runtime> OptimizationBuilder<JitOptimization<R>> for GraphBuilder<R> {
}
};
println!("Registering op {operation:?}");
self.add_node(node);
}
fn build(&self) -> JitOptimization<R> {
todo!()
let mut builder = TraceBuilder::new();
for graph in self.graphs.iter() {
println!("Trace graph");
graph.trace(&mut builder);
}
println!("Builder");
let trace = builder.build();
let info = Arc::new(trace.compiling());
let grid = match self.launch_settings.as_ref().unwrap().workgroup_size {
WorkgroupSizeSettings::Any => WorkgroupSize::default(),
WorkgroupSizeSettings::Fixed(wk) => wk,
};
let factory = GraphKernelFactory::new(IdGenerator::generate(), info.clone(), grid);
let optim = GraphOptimization::new(trace, self.size, self.device.clone(), factory);
JitOptimization::Graph(optim)
}
fn reset(&mut self) {
@ -409,7 +474,7 @@ impl<R: Runtime> OptimizationBuilder<JitOptimization<R>> for GraphBuilder<R> {
fn properties(&self) -> OptimizationProperties {
OptimizationProperties {
ready: true,
ready: self.size > 0,
score: self.size as u64,
}
}

View File

@ -1,8 +1,105 @@
use crate::{fusion::tracing::Trace, Runtime};
use std::{marker::PhantomData, sync::Arc};
use burn_compute::client::ComputeClient;
use burn_fusion::stream::Context;
use crate::{
codegen::{calculate_num_elems_dyn_rank, CompilationInfo, CompilationSettings},
fusion::{
kernel::{FusionKernel, FusionKernelFactory, OutputRuntimeInfo},
tracing::Trace,
JitFusionHandle,
},
gpu::WorkgroupSize,
kernel::elemwise_workgroup,
Runtime,
};
#[derive(new)]
pub struct SubGraphOptimization<R: Runtime> {
pub struct GraphOptimization<R: Runtime> {
pub(super) trace: Trace,
pub(super) num_operations: usize,
pub(super) device: R::Device,
factory: GraphKernelFactory<R>,
}
impl<R: Runtime> GraphOptimization<R> {
pub(crate) fn execute(&mut self, context: &mut Context<'_, JitFusionHandle<R>>) {
let client = R::client(&self.device);
self.run_kernel(context, client, 0)
}
fn run_kernel(
&mut self,
context: &mut Context<'_, JitFusionHandle<R>>,
client: ComputeClient<R::Server, R::Channel>,
fastest_set_index: usize,
) {
let info = self.trace.running();
let kernel_set = &self.factory;
let kernel = FusionKernel::create(
kernel_set,
&info,
context,
self.device.clone(),
client,
true,
);
kernel.execute();
}
pub(crate) fn len(&self) -> usize {
self.num_operations
}
}
impl<R: Runtime> FusionKernelFactory<R> for GraphKernelFactory<R> {
fn create(
&self,
handles_inputs: &[JitFusionHandle<R>],
inputs: &[&burn_tensor::repr::TensorDescription],
outputs: &[&burn_tensor::repr::TensorDescription],
stateful: bool, // Should be set to false when running autotune.
) -> crate::fusion::kernel::FusionKernel<R> {
let workgroup_size_x = self.grid.x;
let workgroup_size_y = self.grid.y;
let workgroup_size = workgroup_size_x as usize;
for h in handles_inputs {
println!("h {:?}", h);
}
for o in self.info.scope.operations.iter() {
println!("O {o:?}");
}
let settings = CompilationSettings::default();
let factor = 1;
let reference_tensor = outputs[0];
let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
let workgroup = elemwise_workgroup(num_elems / factor, workgroup_size);
let output_infos = outputs.iter().enumerate().map(|(pos, tensor)| {
let size = calculate_num_elems_dyn_rank(&tensor.shape)
* self.info.outputs[pos].elem_size::<R>();
OutputRuntimeInfo::Array { size }
});
FusionKernel::new(
self.id.clone(),
self.info.clone(),
settings,
output_infos.collect(),
workgroup,
)
}
}
#[derive(new)]
pub struct GraphKernelFactory<R: Runtime> {
id: String,
info: Arc<CompilationInfo>,
grid: WorkgroupSize,
_runtime: PhantomData<R>,
}

View File

@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
#[derive(Default, Clone, Serialize, Deserialize)]
#[derive(Default, Clone, Serialize, Deserialize, Debug)]
pub struct Scalars {
pub(crate) num_float: usize,
pub(crate) num_int: usize,

View File

@ -49,6 +49,7 @@ impl TraceBuilder {
/// Create a variable from an input [tensor description](TensorDescription).
pub fn input(&mut self, tensor: &TensorDescription, position: Variable) -> gpu::Variable {
println!("Input {tensor:?}");
let already_exists = self.tensors.contains_key(&tensor.id);
let elem = tensor.dtype.into();
@ -86,6 +87,7 @@ impl TraceBuilder {
/// Create a variable from an output [tensor description](TensorDescription).
pub fn output(&mut self, tensor: &TensorDescription, position: Variable) -> gpu::Variable {
println!("output {tensor:?}");
let elem = tensor.dtype.into();
// Update the tensor description to the new version.
self.tensors
@ -137,7 +139,9 @@ impl TraceBuilder {
}
/// Build the [trace](Trace).
pub fn build(self) -> Trace {
pub fn build(mut self) -> Trace {
self.scope.unroll_lazy();
let inputs = self.input_descriptions();
let outputs = self.output_descriptions();
let locals = outputs

View File

@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
/// captured [tensor operations](burn_tensor::repr::OperationDescription).
///
/// A trace should be built using a [builder](super::TraceBuilder).
#[derive(new, Clone, Serialize, Deserialize)]
#[derive(new, Clone, Serialize, Deserialize, Debug)]
pub struct Trace {
inputs: Vec<(TensorDescription, gpu::Elem, gpu::Variable)>,
output_writes: Vec<(TensorDescription, gpu::Elem, gpu::Variable)>,

View File

@ -90,6 +90,7 @@ where
}
let source = kernel.compile().source;
println!("Source {source}");
let pipeline = self.compile_source(&source);
self.pipelines.insert(kernel_id.clone(), pipeline.clone());

21
neovide_backtraces.log Normal file
View File

@ -0,0 +1,21 @@
2024-05-16 13:52:10 - Neovide panicked with the message 'Could not parse event from neovim: invalid event format for event 'hl_attr_define' - [Integer(PosInt(2335)), Map([(String(Utf8String { s: Ok("bold") }), Boolean(true)), (String(Utf8String { s: Ok("italic") }), Boolean(true)), (String(Utf8String { s: Ok("underline") }), Boolean(true)), (String(Utf8String { s: Ok("foreground") }), Integer(PosInt(9601908))), (String(Utf8String { s: Ok("special") }), Integer(PosInt(10935295)))]), Map([(String(Utf8String { s: Ok("bold") }), Boolean(true)), (String(Utf8String { s: Ok("italic") }), Boolean(true)), (String(Utf8String { s: Ok("underline") }), Boolean(true)), (String(Utf8String { s: Ok("foreground") }), Integer(PosInt(245)))]), Array([Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("@variable") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(56)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("GruvboxGreenBold") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(696)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("GruvboxGreenBold") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(696)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("Comment") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(738)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Err(([68, 105, 97, 103, 110, 111, 115, 116, 105, 99, 85, 110, 100, 101, 114, 108, 105, 110, 148, 1, 205, 45, 134], Utf8Error { valid_up_to: 18, error_len: Some(1) })) })), (Nil, Integer(PosInt(5)))])])] - invalid event format Invalid highlight info format'. (File: src/error_handling.rs; Line: 20, Column: 5)
0: <unknown>
1: <unknown>
2: <unknown>
3: <unknown>
4: <unknown>
5: <unknown>
6: <unknown>
7: <unknown>
8: <unknown>
9: <unknown>
10: <unknown>
11: <unknown>
12: <unknown>
13: <unknown>
14: <unknown>
15: <unknown>
16: <unknown>
17: <unknown>
18: <unknown>