mirror of https://github.com/tracel-ai/burn.git
So wip
This commit is contained in:
parent d1742a1be6
commit 5879c31a96
@@ -69,7 +69,7 @@ pub enum ReadingStrategy {
}

impl Scope {
    fn unroll_lazy(&mut self) {
    pub fn unroll_lazy(&mut self) {
        let map = self.map.map.clone();
        let map = map.read().unwrap();
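Reviewer note (not part of the diff): `unroll_lazy` clones the inner map handle and then takes a read lock, which suggests the lazy-procedure map is shared behind something like `Arc<RwLock<...>>`. A small stand-alone sketch of that clone-then-read pattern; the `Arc<RwLock<HashMap>>` layout is an assumption, not burn's actual type:

    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};

    fn main() {
        // Shared map of lazily registered procedures (here just names).
        let shared: Arc<RwLock<HashMap<u32, String>>> = Arc::new(RwLock::new(HashMap::new()));
        shared.write().unwrap().insert(0, "lazy add".to_string());

        // Clone the Arc handle first, then take a read guard on the clone,
        // mirroring `let map = self.map.map.clone(); let map = map.read().unwrap();`.
        let map = shared.clone();
        let map = map.read().unwrap();
        for (id, name) in map.iter() {
            println!("procedure {id}: {name}");
        }
    }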
@@ -1,7 +1,10 @@
use super::{ElementWise, ElementWiseState};
use super::{ElementWise, ElementWiseState, GraphOptimization};
use crate::{
    element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement,
    IntElement, JitBackend, Runtime,
    element::JitElement,
    fusion::{ElementWiseBuilder, GraphBuilder},
    kernel,
    tensor::JitTensor,
    FloatElement, IntElement, JitBackend, Runtime,
};
use burn_compute::client::ComputeClient;
use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
@@ -16,6 +19,7 @@ use serde::{Deserialize, Serialize};
pub enum JitOptimization<R: Runtime> {
    /// Element wise optimization.
    ElementWise(ElementWise<R>),
    Graph(GraphOptimization<R>),
}

/// Fusion optimization state type for JIT.
@@ -34,18 +38,21 @@ where
    fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle<R>>) {
        match self {
            Self::ElementWise(op) => op.execute(context),
            Self::Graph(op) => op.execute(context),
        }
    }

    fn len(&self) -> usize {
        match self {
            Self::ElementWise(op) => op.len(),
            Self::Graph(op) => op.len(),
        }
    }

    fn to_state(&self) -> JitOptimizationState {
        match self {
            Self::ElementWise(value) => JitOptimizationState::ElementWise(value.to_state()),
            Self::Graph(value) => todo!(),
        }
    }
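Reviewer note (not part of the diff): the new `Graph` variant is dispatched exactly like `ElementWise`, by matching on the enum and delegating to the inner optimization; `to_state` is still a `todo!()` for the graph case. A minimal, self-contained sketch of the dispatch pattern with toy stand-in types (not burn's real API):

    // Toy stand-ins for the two optimization kinds.
    struct ElementWise { ops: usize }
    struct Graph { ops: usize }

    impl ElementWise { fn len(&self) -> usize { self.ops } }
    impl Graph { fn len(&self) -> usize { self.ops } }

    enum Optimization {
        ElementWise(ElementWise),
        Graph(Graph),
    }

    impl Optimization {
        // Every method forwards to the variant's own implementation.
        fn len(&self) -> usize {
            match self {
                Self::ElementWise(op) => op.len(),
                Self::Graph(op) => op.len(),
            }
        }
    }

    fn main() {
        let opt = Optimization::Graph(Graph { ops: 3 });
        assert_eq!(opt.len(), 3);
    }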
@@ -111,7 +118,7 @@ impl<R: Runtime> FusionRuntime for FusionJitRuntime<R> {
    fn optimizations(
        device: R::Device,
    ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
        vec![Box::new(ElementWiseBuilder::<R>::new(device))]
        vec![Box::new(GraphBuilder::<R>::new(device))]
    }
}
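Aside (not in the diff): the runtime now registers a `GraphBuilder` instead of the `ElementWiseBuilder` in its list of optimization builders. A tiny stand-alone sketch of this boxed-builder registry pattern, with a toy trait rather than burn's `OptimizationBuilder`:

    trait Builder {
        fn name(&self) -> &'static str;
    }

    struct ElementWiseBuilder;
    struct GraphBuilder;

    impl Builder for ElementWiseBuilder {
        fn name(&self) -> &'static str { "element-wise" }
    }
    impl Builder for GraphBuilder {
        fn name(&self) -> &'static str { "graph" }
    }

    // The diff swaps which builder goes into the vector; both could also coexist.
    fn optimizations() -> Vec<Box<dyn Builder>> {
        vec![Box::new(GraphBuilder)]
    }

    fn main() {
        for builder in optimizations() {
            println!("registered builder: {}", builder.name());
        }
    }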
@@ -1,11 +1,16 @@
use std::sync::Arc;

use crate::{
    fusion::{tracing::TraceBuilder, JitOptimization},
    gpu::{gpu, LazyProcedure, Scope, Variable, WorkgroupSize},
    Runtime,
};
use burn_common::id::IdGenerator;
use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus};
use burn_tensor::repr::{BinaryOperationDescription, OperationDescription, TensorId};

use super::{GraphKernelFactory, GraphOptimization};

#[derive(Clone)]
pub struct LaunchSettings {
    workgroup_size: WorkgroupSizeSettings,
@@ -61,7 +66,7 @@ pub trait Node: Send + Sync {
    fn any_input_visited(&self) -> bool;
    fn any_output_visited(&self) -> bool;
    fn launch_settings(&self) -> LaunchSettings;
    fn register(&self, builder: &mut TraceBuilder);
    fn trace(&self, builder: &mut TraceBuilder);
}

pub enum MergeSubGraphResult {
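Reviewer note (not part of the diff): `Node::register` is renamed to `trace`, matching the `trace` method added to `SubGraph` below; each node writes its inputs and outputs into a shared `TraceBuilder`. A self-contained sketch of that visitor-style pattern with toy types (the string-based builder is an assumption, not burn's `TraceBuilder`):

    // Toy trace builder: nodes append a description of themselves.
    #[derive(Default)]
    struct TraceBuilder {
        steps: Vec<String>,
    }

    trait Node: Send + Sync {
        fn trace(&self, builder: &mut TraceBuilder);
    }

    struct AddOp {
        lhs: u32,
        rhs: u32,
        out: u32,
    }

    impl Node for AddOp {
        fn trace(&self, builder: &mut TraceBuilder) {
            let step = format!("t{} = t{} + t{}", self.out, self.lhs, self.rhs);
            builder.steps.push(step);
        }
    }

    fn main() {
        let nodes: Vec<Box<dyn Node>> = vec![Box::new(AddOp { lhs: 0, rhs: 1, out: 2 })];
        let mut builder = TraceBuilder::default();
        for node in &nodes {
            node.trace(&mut builder);
        }
        println!("{:?}", builder.steps);
    }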
@@ -72,6 +77,7 @@ pub enum MergeSubGraphResult {
pub trait SubGraph: Send + Sync {
    fn register(self: Box<Self>, node: NodeBoxed) -> MergingResult;
    fn merge(self: Box<Self>, other: SubGraphBoxed) -> MergeSubGraphResult;
    fn trace(&self, builder: &mut TraceBuilder);
}

#[derive(new)]
@@ -122,6 +128,11 @@ impl SubGraph for TwoNodesSubGraph {
            ),
        }
    }

    fn trace(&self, builder: &mut TraceBuilder) {
        self.graph_1.trace(builder);
        self.graph_2.trace(builder);
    }
}

impl SubGraph for SingleNodeSubGraph {
@@ -160,6 +171,10 @@ impl SubGraph for SingleNodeSubGraph {
            }
        }
    }

    fn trace(&self, builder: &mut TraceBuilder) {
        self.node.trace(builder);
    }
}

pub enum MergingResult {
@@ -238,7 +253,8 @@ impl Node for FloatAddOp {
        self.settings.clone()
    }

    fn register(&self, builder: &mut TraceBuilder) {
    fn trace(&self, builder: &mut TraceBuilder) {
        println!("FloatAddOp registering");
        let lhs = builder.input(&self.desc.lhs, Variable::Id);
        let rhs = builder.input(&self.desc.rhs, Variable::Id);
        let out = builder.output(&self.desc.out, Variable::Id);
@@ -253,6 +269,10 @@
    fn expand(&self, scope: &mut Scope, position: Option<Variable>) -> Variable {
        let position = position.unwrap_or(Variable::Id);

        println!("FloatAddOp expand lazy");
        // let lhs = self.lhs;
        // let rhs = self.rhs;
        // let out = self.out;
        let lhs_input = self.lhs;
        let rhs_input = self.rhs;
@@ -260,10 +280,14 @@
        let rhs = scope.create_local(self.rhs.item());
        let out = self.out;

        println!("INPUT {lhs_input:?}");
        // Is local but should not be.
        gpu!(scope, lhs = lhs_input[position]);
        gpu!(scope, rhs = rhs_input[position]);
        gpu!(scope, out = lhs + rhs);

        gpu!(scope, out = lhs + rhs);

        out
    }
}
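Reviewer note (not part of the diff): `expand` emits the body of an element-wise kernel: read `lhs` and `rhs` at the current position, add them, and write `out`. The scalar CPU analogue of that generated body, as a plain-Rust sketch independent of the `gpu!` macro:

    // CPU analogue of the generated kernel body: one "position" per element.
    fn elementwise_add(lhs: &[f32], rhs: &[f32], out: &mut [f32]) {
        for position in 0..out.len() {
            let l = lhs[position];
            let r = rhs[position];
            out[position] = l + r;
        }
    }

    fn main() {
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [10.0, 20.0, 30.0];
        let mut out = [0.0; 3];
        elementwise_add(&lhs, &rhs, &mut out);
        assert_eq!(out, [11.0, 22.0, 33.0]);
    }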
@@ -277,16 +301,26 @@ pub type NodeBoxed = Box<dyn Node>;

pub struct GraphBuilder<R: Runtime> {
    graphs: Vec<SubGraphBoxed>,
    launch_settings: LaunchSettings,
    launch_settings: Option<LaunchSettings>,
    status: OptimizationStatus,
    device: R::Device,
    size: usize,
}

impl<R: Runtime> GraphBuilder<R> {
    pub fn new(device: R::Device) -> Self {
        Self {
            graphs: Vec::new(),
            launch_settings: None,
            status: OptimizationStatus::Open,
            device,
            size: 0,
        }
    }
    fn add_node(&mut self, node: Box<dyn Node>) {
        if self.graphs.is_empty() {
            let settings = node.launch_settings();
            println!("First node");
            self.graphs = vec![Box::new(SingleNodeSubGraph::new(node, settings))];
            return;
        }
@@ -332,13 +366,25 @@ impl<R: Runtime> GraphBuilder<R> {
        if let Some(node) = node_current.take() {
            let node_settings = node.launch_settings();

            if self.launch_settings.is_compatible_with(&node_settings) {
            if let Some(launch_settings) = &self.launch_settings {
                if launch_settings.is_compatible_with(&node_settings) {
                    let launch_settings =
                        LaunchSettings::most_restrictive(launch_settings.clone(), node_settings);

                    self.graphs.push(Box::new(SingleNodeSubGraph::new(
                        node,
                        launch_settings.clone(),
                    )));

                    self.launch_settings = Some(launch_settings);
                    self.size += 1;
                }
            } else {
                self.graphs.push(Box::new(SingleNodeSubGraph::new(
                    node,
                    self.launch_settings.clone(),
                    node_settings.clone(),
                )));
                self.launch_settings =
                    LaunchSettings::most_restrictive(self.launch_settings.clone(), node_settings);
                self.launch_settings = Some(node_settings);
                self.size += 1;
            }
        } else {
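Aside (not in the diff): the builder now tracks `launch_settings` as an `Option`, starting at `None` and folding in each node's settings through a compatibility check plus `most_restrictive`. A self-contained sketch of that folding idea with a toy `LaunchSettings`; the `Any`/`Fixed` variants mirror `WorkgroupSizeSettings` from the diff, but the method bodies here are assumptions:

    #[derive(Clone, Debug, PartialEq)]
    enum LaunchSettings {
        Any,
        Fixed(u32),
    }

    impl LaunchSettings {
        fn is_compatible_with(&self, other: &LaunchSettings) -> bool {
            match (self, other) {
                (LaunchSettings::Fixed(a), LaunchSettings::Fixed(b)) => a == b,
                _ => true, // `Any` is compatible with everything.
            }
        }

        // Keep the more constrained of the two settings.
        fn most_restrictive(self, other: LaunchSettings) -> LaunchSettings {
            match (self, other) {
                (LaunchSettings::Fixed(a), _) => LaunchSettings::Fixed(a),
                (LaunchSettings::Any, other) => other,
            }
        }
    }

    fn main() {
        let mut acc: Option<LaunchSettings> = None;
        for node in [LaunchSettings::Any, LaunchSettings::Fixed(256), LaunchSettings::Any] {
            acc = match acc {
                None => Some(node),
                Some(cur) if cur.is_compatible_with(&node) => Some(cur.most_restrictive(node)),
                // Incompatible: the real builder would stop growing the graph here.
                Some(cur) => Some(cur),
            };
        }
        assert_eq!(acc, Some(LaunchSettings::Fixed(256)));
    }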
@@ -391,11 +437,30 @@ impl<R: Runtime> OptimizationBuilder<JitOptimization<R>> for GraphBuilder<R> {
            }
        };

        println!("Registering op {operation:?}");
        self.add_node(node);
    }

    fn build(&self) -> JitOptimization<R> {
        todo!()
        let mut builder = TraceBuilder::new();

        for graph in self.graphs.iter() {
            println!("Trace graph");
            graph.trace(&mut builder);
        }
        println!("Builder");

        let trace = builder.build();
        let info = Arc::new(trace.compiling());

        let grid = match self.launch_settings.as_ref().unwrap().workgroup_size {
            WorkgroupSizeSettings::Any => WorkgroupSize::default(),
            WorkgroupSizeSettings::Fixed(wk) => wk,
        };
        let factory = GraphKernelFactory::new(IdGenerator::generate(), info.clone(), grid);
        let optim = GraphOptimization::new(trace, self.size, self.device.clone(), factory);

        JitOptimization::Graph(optim)
    }

    fn reset(&mut self) {
@@ -409,7 +474,7 @@ impl<R: Runtime> OptimizationBuilder<JitOptimization<R>> for GraphBuilder<R> {

    fn properties(&self) -> OptimizationProperties {
        OptimizationProperties {
            ready: true,
            ready: self.size > 0,
            score: self.size as u64,
        }
    }
@@ -1,8 +1,105 @@
use crate::{fusion::tracing::Trace, Runtime};
use std::{marker::PhantomData, sync::Arc};

use burn_compute::client::ComputeClient;
use burn_fusion::stream::Context;

use crate::{
    codegen::{calculate_num_elems_dyn_rank, CompilationInfo, CompilationSettings},
    fusion::{
        kernel::{FusionKernel, FusionKernelFactory, OutputRuntimeInfo},
        tracing::Trace,
        JitFusionHandle,
    },
    gpu::WorkgroupSize,
    kernel::elemwise_workgroup,
    Runtime,
};

#[derive(new)]
pub struct SubGraphOptimization<R: Runtime> {
pub struct GraphOptimization<R: Runtime> {
    pub(super) trace: Trace,
    pub(super) num_operations: usize,
    pub(super) device: R::Device,
    factory: GraphKernelFactory<R>,
}

impl<R: Runtime> GraphOptimization<R> {
    pub(crate) fn execute(&mut self, context: &mut Context<'_, JitFusionHandle<R>>) {
        let client = R::client(&self.device);

        self.run_kernel(context, client, 0)
    }

    fn run_kernel(
        &mut self,
        context: &mut Context<'_, JitFusionHandle<R>>,
        client: ComputeClient<R::Server, R::Channel>,
        fastest_set_index: usize,
    ) {
        let info = self.trace.running();
        let kernel_set = &self.factory;

        let kernel = FusionKernel::create(
            kernel_set,
            &info,
            context,
            self.device.clone(),
            client,
            true,
        );

        kernel.execute();
    }
    pub(crate) fn len(&self) -> usize {
        self.num_operations
    }
}

impl<R: Runtime> FusionKernelFactory<R> for GraphKernelFactory<R> {
    fn create(
        &self,
        handles_inputs: &[JitFusionHandle<R>],
        inputs: &[&burn_tensor::repr::TensorDescription],
        outputs: &[&burn_tensor::repr::TensorDescription],
        stateful: bool, // Should be set to false when running autotune.
    ) -> crate::fusion::kernel::FusionKernel<R> {
        let workgroup_size_x = self.grid.x;
        let workgroup_size_y = self.grid.y;
        let workgroup_size = workgroup_size_x as usize;
        for h in handles_inputs {
            println!("h {:?}", h);
        }

        for o in self.info.scope.operations.iter() {
            println!("O {o:?}");
        }

        let settings = CompilationSettings::default();
        let factor = 1;

        let reference_tensor = outputs[0];
        let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
        let workgroup = elemwise_workgroup(num_elems / factor, workgroup_size);
        let output_infos = outputs.iter().enumerate().map(|(pos, tensor)| {
            let size = calculate_num_elems_dyn_rank(&tensor.shape)
                * self.info.outputs[pos].elem_size::<R>();
            OutputRuntimeInfo::Array { size }
        });

        FusionKernel::new(
            self.id.clone(),
            self.info.clone(),
            settings,
            output_infos.collect(),
            workgroup,
        )
    }
}

#[derive(new)]
pub struct GraphKernelFactory<R: Runtime> {
    id: String,
    info: Arc<CompilationInfo>,
    grid: WorkgroupSize,
    _runtime: PhantomData<R>,
}
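Reviewer note (not part of the diff): `elemwise_workgroup(num_elems / factor, workgroup_size)` turns an element count into a workgroup grid. A stand-alone sketch of the usual round-up division behind such a helper; this is an assumed simplification, not burn's exact implementation (which builds a multi-dimensional workgroup):

    // Split `num_elems` items into workgroups of `workgroup_size` invocations each,
    // rounding up so every element is covered.
    fn elemwise_workgroup_count(num_elems: usize, workgroup_size: usize) -> usize {
        (num_elems + workgroup_size - 1) / workgroup_size
    }

    fn main() {
        // 1_000_000 elements at 256 invocations per workgroup -> 3907 workgroups.
        assert_eq!(elemwise_workgroup_count(1_000_000, 256), 3907);
        // Exact multiples do not get an extra workgroup.
        assert_eq!(elemwise_workgroup_count(1024, 256), 4);
    }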
@@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
#[derive(Default, Clone, Serialize, Deserialize)]
#[derive(Default, Clone, Serialize, Deserialize, Debug)]
pub struct Scalars {
    pub(crate) num_float: usize,
    pub(crate) num_int: usize,
@@ -49,6 +49,7 @@ impl TraceBuilder {

    /// Create a variable from an input [tensor description](TensorDescription).
    pub fn input(&mut self, tensor: &TensorDescription, position: Variable) -> gpu::Variable {
        println!("Input {tensor:?}");
        let already_exists = self.tensors.contains_key(&tensor.id);
        let elem = tensor.dtype.into();
@@ -86,6 +87,7 @@ impl TraceBuilder {

    /// Create a variable from an output [tensor description](TensorDescription).
    pub fn output(&mut self, tensor: &TensorDescription, position: Variable) -> gpu::Variable {
        println!("output {tensor:?}");
        let elem = tensor.dtype.into();
        // Update the tensor description to the new version.
        self.tensors
@@ -137,7 +139,9 @@ impl TraceBuilder {
    }

    /// Build the [trace](Trace).
    pub fn build(self) -> Trace {
    pub fn build(mut self) -> Trace {
        self.scope.unroll_lazy();

        let inputs = self.input_descriptions();
        let outputs = self.output_descriptions();
        let locals = outputs
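Reviewer note (not part of the diff): this is where the newly public `Scope::unroll_lazy` from the first hunk is used; `build` now takes `mut self` so deferred procedures can be expanded before the trace is assembled. A toy, self-contained sketch of the "register now, expand at build time" idea (hypothetical types, not burn's actual `Scope`):

    // Toy scope: lazy procedures are stored as closures and expanded on demand.
    #[derive(Default)]
    struct Scope {
        operations: Vec<String>,
        lazy: Vec<Box<dyn Fn(&mut Vec<String>)>>,
    }

    impl Scope {
        fn register_lazy(&mut self, proc: impl Fn(&mut Vec<String>) + 'static) {
            self.lazy.push(Box::new(proc));
        }

        // Expand every deferred procedure into concrete operations.
        fn unroll_lazy(&mut self) {
            let lazy = std::mem::take(&mut self.lazy);
            for proc in lazy {
                proc(&mut self.operations);
            }
        }
    }

    fn main() {
        let mut scope = Scope::default();
        scope.register_lazy(|ops| ops.push("out = lhs + rhs".to_string()));
        scope.unroll_lazy();
        assert_eq!(scope.operations, vec!["out = lhs + rhs".to_string()]);
    }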
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
/// captured [tensor operations](burn_tensor::repr::OperationDescription).
///
/// A trace should be built using a [builder](super::TraceBuilder).
#[derive(new, Clone, Serialize, Deserialize)]
#[derive(new, Clone, Serialize, Deserialize, Debug)]
pub struct Trace {
    inputs: Vec<(TensorDescription, gpu::Elem, gpu::Variable)>,
    output_writes: Vec<(TensorDescription, gpu::Elem, gpu::Variable)>,
@@ -90,6 +90,7 @@ where
        }

        let source = kernel.compile().source;
        println!("Source {source}");
        let pipeline = self.compile_source(&source);
        self.pipelines.insert(kernel_id.clone(), pipeline.clone());
@@ -0,0 +1,21 @@
2024-05-16 13:52:10 - Neovide panicked with the message 'Could not parse event from neovim: invalid event format for event 'hl_attr_define' - [Integer(PosInt(2335)), Map([(String(Utf8String { s: Ok("bold") }), Boolean(true)), (String(Utf8String { s: Ok("italic") }), Boolean(true)), (String(Utf8String { s: Ok("underline") }), Boolean(true)), (String(Utf8String { s: Ok("foreground") }), Integer(PosInt(9601908))), (String(Utf8String { s: Ok("special") }), Integer(PosInt(10935295)))]), Map([(String(Utf8String { s: Ok("bold") }), Boolean(true)), (String(Utf8String { s: Ok("italic") }), Boolean(true)), (String(Utf8String { s: Ok("underline") }), Boolean(true)), (String(Utf8String { s: Ok("foreground") }), Integer(PosInt(245)))]), Array([Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("@variable") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(56)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("GruvboxGreenBold") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(696)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("GruvboxGreenBold") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(696)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Ok("Comment") })), (String(Utf8String { s: Ok("id") }), Integer(PosInt(738)))]), Map([(String(Utf8String { s: Ok("kind") }), String(Utf8String { s: Ok("syntax") })), (String(Utf8String { s: Ok("hi_name") }), String(Utf8String { s: Err(([68, 105, 97, 103, 110, 111, 115, 116, 105, 99, 85, 110, 100, 101, 114, 108, 105, 110, 148, 1, 205, 45, 134], Utf8Error { valid_up_to: 18, error_len: Some(1) })) })), (Nil, Integer(PosInt(5)))])])] - invalid event format Invalid highlight info format'. (File: src/error_handling.rs; Line: 20, Column: 5)
  0: <unknown>
  1: <unknown>
  2: <unknown>
  3: <unknown>
  4: <unknown>
  5: <unknown>
  6: <unknown>
  7: <unknown>
  8: <unknown>
  9: <unknown>
 10: <unknown>
 11: <unknown>
 12: <unknown>
 13: <unknown>
 14: <unknown>
 15: <unknown>
 16: <unknown>
 17: <unknown>
 18: <unknown>