diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs index 3d5032b56..878415081 100644 --- a/crates/burn-tensor/src/lib.rs +++ b/crates/burn-tensor/src/lib.rs @@ -11,6 +11,9 @@ extern crate derive_new; extern crate alloc; +/// Server module. +pub mod server; + mod tensor; /// Burn Tensor representaton diff --git a/crates/burn-tensor/src/server/backend.rs b/crates/burn-tensor/src/server/backend.rs new file mode 100644 index 000000000..a848ceb60 --- /dev/null +++ b/crates/burn-tensor/src/server/backend.rs @@ -0,0 +1,102 @@ +use core::marker::PhantomData; + +use crate::{ + backend::{Backend, BackendBridge, DeviceOps}, + quantization::QTensorPrimitive, + repr::{OperationDescription, TensorDescription}, +}; + +pub struct Server { + r: PhantomData, +} + +pub trait ServerBackend: Send + Sync + 'static { + type Runtime: ServerRuntime; +} + +pub trait ServerRuntime { + type Client: ServerClient; + type Device: DeviceOps; + type Bridge; +} + +pub struct ServerTensor { + desc: TensorDescription, + client: R::Client, +} + +impl core::fmt::Debug for Server { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("server")) + } +} + +impl core::fmt::Debug for ServerTensor { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("tensor")) + } +} + +impl Clone for ServerTensor { + fn clone(&self) -> Self { + Self { + desc: self.desc.clone(), + client: self.client.clone(), + } + } +} + +impl Clone for Server { + fn clone(&self) -> Self { + Self { r: PhantomData } + } +} + +impl Default for Server { + fn default() -> Self { + Self { r: PhantomData } + } +} + +impl QTensorPrimitive for ServerTensor { + fn scheme(&self) -> &crate::quantization::QuantizationScheme { + todo!() + } + + fn strategy(&self) -> crate::quantization::QuantizationStrategy { + todo!() + } +} + +pub trait ServerClient: Clone + Send + Sync { + fn execute(&self, op: OperationDescription); +} + +impl Backend for Server +where + ::Bridge: BackendBridge + 'static, +{ + type Device = ::Device; + + type FullPrecisionBridge = ::Bridge; + + type FloatTensorPrimitive = ServerTensor; + + type FloatElem = f32; + + type IntTensorPrimitive = ServerTensor; + + type IntElem = i32; + + type BoolTensorPrimitive = ServerTensor; + + type QuantizedTensorPrimitive = ServerTensor; + + fn name() -> String { + todo!() + } + + fn seed(seed: u64) { + todo!() + } +} diff --git a/crates/burn-tensor/src/server/mod.rs b/crates/burn-tensor/src/server/mod.rs new file mode 100644 index 000000000..2a5ef9bbd --- /dev/null +++ b/crates/burn-tensor/src/server/mod.rs @@ -0,0 +1,4 @@ +mod backend; +mod ops; + +pub use backend::*; diff --git a/crates/burn-tensor/src/server/ops/mod.rs b/crates/burn-tensor/src/server/ops/mod.rs new file mode 100644 index 000000000..3fe3545ee --- /dev/null +++ b/crates/burn-tensor/src/server/ops/mod.rs @@ -0,0 +1,6 @@ +mod op_activation; +mod op_bool; +mod op_float; +mod op_int; +mod op_module; +mod op_qfloat; diff --git a/crates/burn-tensor/src/server/ops/op_activation.rs b/crates/burn-tensor/src/server/ops/op_activation.rs new file mode 100644 index 000000000..e2c648720 --- /dev/null +++ b/crates/burn-tensor/src/server/ops/op_activation.rs @@ -0,0 +1,4 @@ +use crate::server::Server; +use crate::{ops::ActivationOps, server::ServerBackend}; + +impl ActivationOps for Server {} diff --git a/crates/burn-tensor/src/server/ops/op_bool.rs b/crates/burn-tensor/src/server/ops/op_bool.rs new file mode 100644 index 000000000..b3680e6a3 --- /dev/null +++ b/crates/burn-tensor/src/server/ops/op_bool.rs @@ -0,0 +1,119 @@ +use crate::{ + ops::BoolTensorOps, + server::{Server, ServerBackend}, +}; + +impl BoolTensorOps for Server { + fn bool_empty( + shape: crate::Shape, + device: &crate::Device, + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_shape(tensor: &crate::ops::BoolTensor) -> crate::Shape { + todo!() + } + + fn bool_into_data( + tensor: crate::ops::BoolTensor, + ) -> impl core::future::Future + Send { + todo!() + } + + fn bool_from_data( + data: crate::TensorData, + device: &crate::Device, + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_into_int( + tensor: crate::ops::BoolTensor, + ) -> crate::ops::IntTensor { + todo!() + } + + fn bool_into_float( + tensor: crate::ops::BoolTensor, + ) -> crate::ops::FloatTensor { + todo!() + } + + fn bool_device( + tensor: &crate::ops::BoolTensor, + ) -> crate::Device { + todo!() + } + + fn bool_to_device( + tensor: crate::ops::BoolTensor, + device: &crate::Device, + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_reshape( + tensor: crate::ops::BoolTensor, + shape: crate::Shape, + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_slice( + tensor: crate::ops::BoolTensor, + ranges: [core::ops::Range; D2], + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_slice_assign( + tensor: crate::ops::BoolTensor, + ranges: [core::ops::Range; D2], + value: crate::ops::BoolTensor, + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_equal( + lhs: crate::ops::BoolTensor, + rhs: crate::ops::BoolTensor, + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_not( + tensor: crate::ops::BoolTensor, + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_swap_dims( + tensor: crate::ops::BoolTensor, + dim1: usize, + dim2: usize, + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_permute( + tensor: crate::ops::BoolTensor, + axes: [usize; D], + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_flip( + tensor: crate::ops::BoolTensor, + axes: &[usize], + ) -> crate::ops::BoolTensor { + todo!() + } + + fn bool_expand( + tensor: crate::ops::BoolTensor, + shape: crate::Shape, + ) -> crate::ops::BoolTensor { + todo!() + } +} diff --git a/crates/burn-tensor/src/server/ops/op_float.rs b/crates/burn-tensor/src/server/ops/op_float.rs new file mode 100644 index 000000000..79bf54eea --- /dev/null +++ b/crates/burn-tensor/src/server/ops/op_float.rs @@ -0,0 +1,4 @@ +use crate::server::Server; +use crate::{ops::FloatTensorOps, server::ServerBackend}; + +impl FloatTensorOps for Server {} diff --git a/crates/burn-tensor/src/server/ops/op_int.rs b/crates/burn-tensor/src/server/ops/op_int.rs new file mode 100644 index 000000000..5cc7e08ac --- /dev/null +++ b/crates/burn-tensor/src/server/ops/op_int.rs @@ -0,0 +1,4 @@ +use crate::server::Server; +use crate::{ops::IntTensorOps, server::ServerBackend}; + +impl IntTensorOps for Server {} diff --git a/crates/burn-tensor/src/server/ops/op_module.rs b/crates/burn-tensor/src/server/ops/op_module.rs new file mode 100644 index 000000000..96276376e --- /dev/null +++ b/crates/burn-tensor/src/server/ops/op_module.rs @@ -0,0 +1,7 @@ +use crate::server::Server; +use crate::{ + ops::{ConvOptions, ConvTransposeOptions, FloatTensor, ModuleOps}, + server::ServerBackend, +}; + +impl ModuleOps for Server {} diff --git a/crates/burn-tensor/src/server/ops/op_qfloat.rs b/crates/burn-tensor/src/server/ops/op_qfloat.rs new file mode 100644 index 000000000..9695f7ecd --- /dev/null +++ b/crates/burn-tensor/src/server/ops/op_qfloat.rs @@ -0,0 +1,10 @@ +use core::future::Future; + +use crate::{ + ops::{FloatTensor, QTensorOps, QuantizedTensor}, + quantization::{QuantizationParametersPrimitive, QuantizationScheme}, + server::{Server, ServerBackend}, + Device, TensorData, +}; + +impl QTensorOps for Server {} diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index c22914a60..6d9a2d74f 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -60,8 +60,8 @@ pub trait Backend: + ActivationOps + QTensorOps + Clone - + Sized + Default + + Sized + Send + Sync + core::fmt::Debug