it compiles

This commit is contained in:
nathaniel 2024-09-17 14:09:37 -04:00
parent 049ceff0e2
commit 931a72e4b7
12 changed files with 1270 additions and 100 deletions

View File

@ -18,3 +18,6 @@ pub mod network;
/// Parallel utilities.
pub mod parallel;
/// Streaming utilities.
pub mod stream;

View File

@ -0,0 +1,64 @@
/// The stream id.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct StreamId {
#[cfg(feature = "std")]
value: std::thread::ThreadId,
#[cfg(not(feature = "std"))]
value: (),
}
impl StreamId {
/// Get the current stream id.
pub fn current() -> Self {
Self {
#[cfg(feature = "std")]
value: Self::id(),
#[cfg(not(feature = "std"))]
value: (),
}
}
#[cfg(feature = "std")]
fn id() -> std::thread::ThreadId {
std::thread_local! {
static ID: std::cell::OnceCell::<std::thread::ThreadId> = const { std::cell::OnceCell::new() };
};
// Getting the current thread is expensive, so we cache the value into a thread local
// variable, which is very fast.
ID.with(|cell| *cell.get_or_init(|| std::thread::current().id()))
}
}
impl core::fmt::Display for StreamId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("StreamID({:?})", self.value))
}
}
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
#[test]
fn stream_id_from_different_threads() {
let current = StreamId::current();
let thread1 = std::thread::spawn(|| (StreamId::current(), StreamId::current()));
let thread2 = std::thread::spawn(StreamId::current);
let (stream_1, stream_11) = thread1.join().unwrap();
let stream_2 = thread2.join().unwrap();
assert_ne!(current, stream_1, "Should be different from thread 1");
assert_ne!(current, stream_2, "Should be different from thread 2");
assert_ne!(
stream_1, stream_2,
"Should be different from different threads"
);
assert_eq!(
stream_1, stream_11,
"Should be the same, since same thread."
);
}
}

View File

@ -23,44 +23,6 @@ impl<R: FusionRuntime> Default for OperationQueue<R> {
}
}
/// The stream id.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct StreamId {
#[cfg(feature = "std")]
value: std::thread::ThreadId,
#[cfg(not(feature = "std"))]
value: (),
}
impl StreamId {
/// Get the current stream id.
pub fn current() -> Self {
Self {
#[cfg(feature = "std")]
value: Self::id(),
#[cfg(not(feature = "std"))]
value: (),
}
}
#[cfg(feature = "std")]
fn id() -> std::thread::ThreadId {
std::thread_local! {
static ID: std::cell::OnceCell::<std::thread::ThreadId> = const { std::cell::OnceCell::new() };
};
// Getting the current thread is expensive, so we cache the value into a thread local
// variable, which is very fast.
ID.with(|cell| *cell.get_or_init(|| std::thread::current().id()))
}
}
impl core::fmt::Display for StreamId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("StreamID({:?})", self.value))
}
}
impl<R: FusionRuntime> OperationQueue<R> {
/// Create a new empty queue.
pub fn new() -> Self {
@ -94,30 +56,3 @@ impl<R: FusionRuntime> OperationQueue<R> {
self.len() == 0
}
}
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
#[test]
fn stream_id_from_different_threads() {
let current = StreamId::current();
let thread1 = std::thread::spawn(|| (StreamId::current(), StreamId::current()));
let thread2 = std::thread::spawn(StreamId::current);
let (stream_1, stream_11) = thread1.join().unwrap();
let stream_2 = thread2.join().unwrap();
assert_ne!(current, stream_1, "Should be different from thread 1");
assert_ne!(current, stream_2, "Should be different from thread 2");
assert_ne!(
stream_1, stream_2,
"Should be different from different threads"
);
assert_eq!(
stream_1, stream_11,
"Should be the same, since same thread."
);
}
}

View File

@ -8,3 +8,5 @@ mod multi;
pub use base::*;
pub use context::*;
pub use multi::*;
pub use burn_common::stream::StreamId;

View File

@ -1,28 +1,28 @@
use core::marker::PhantomData;
use core::{future::Future, marker::PhantomData};
use burn_common::stream::StreamId;
use crate::{
backend::{Backend, BackendBridge, DeviceOps},
quantization::QTensorPrimitive,
repr::{OperationDescription, TensorDescription},
TensorData,
};
use super::ServerTensor;
pub struct Server<B: ServerBackend> {
r: PhantomData<B>,
}
pub trait ServerBackend: Send + Sync + 'static {
pub trait ServerBackend: Send + Sync + 'static + Sized {
type Runtime: ServerRuntime;
type Bridge: BackendBridge<Server<Self>> + 'static;
}
pub trait ServerRuntime {
type Client: ServerClient;
type Device: DeviceOps;
type Bridge;
}
pub struct ServerTensor<R: ServerRuntime> {
desc: TensorDescription,
client: R::Client,
}
impl<B: ServerBackend> core::fmt::Debug for Server<B> {
@ -31,21 +31,6 @@ impl<B: ServerBackend> core::fmt::Debug for Server<B> {
}
}
impl<R: ServerRuntime> core::fmt::Debug for ServerTensor<R> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_fmt(format_args!("tensor"))
}
}
impl<R: ServerRuntime> Clone for ServerTensor<R> {
fn clone(&self) -> Self {
Self {
desc: self.desc.clone(),
client: self.client.clone(),
}
}
}
impl<B: ServerBackend> Clone for Server<B> {
fn clone(&self) -> Self {
Self { r: PhantomData }
@ -69,16 +54,20 @@ impl<R: ServerRuntime> QTensorPrimitive for ServerTensor<R> {
}
pub trait ServerClient: Clone + Send + Sync {
/// Execute an operation.
fn execute(&self, op: OperationDescription);
/// Read the values contained by a tensor.
fn read_tensor(
&self,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send;
}
impl<B: ServerBackend> Backend for Server<B>
where
<B::Runtime as ServerRuntime>::Bridge: BackendBridge<Self> + 'static,
{
impl<B: ServerBackend> Backend for Server<B> {
type Device = <B::Runtime as ServerRuntime>::Device;
type FullPrecisionBridge = <B::Runtime as ServerRuntime>::Bridge;
type FullPrecisionBridge = B::Bridge;
type FloatTensorPrimitive<const D: usize> = ServerTensor<B::Runtime>;

View File

@ -1,4 +1,6 @@
mod backend;
mod ops;
mod tensor;
pub use backend::*;
pub use tensor::*;

View File

@ -18,7 +18,7 @@ impl<B: ServerBackend> BoolTensorOps<Self> for Server<B> {
fn bool_into_data<const D: usize>(
tensor: crate::ops::BoolTensor<Self, D>,
) -> impl core::future::Future<Output = crate::TensorData> + Send {
todo!()
async { tensor.into_data().await }
}
fn bool_from_data<const D: usize>(

View File

@ -1,4 +1,455 @@
use crate::ops::{BoolTensor, FloatElem, FloatTensor, IntTensor};
use crate::server::Server;
use crate::{ops::FloatTensorOps, server::ServerBackend};
use crate::{Device, Distribution, Shape, TensorData};
use std::ops::Range;
impl<B: ServerBackend> FloatTensorOps<Self> for Server<B> {}
impl<B: ServerBackend> FloatTensorOps<Self> for Server<B> {
fn float_from_data<const D: usize>(
data: TensorData,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
todo!();
}
fn float_random<const D: usize>(
shape: Shape<D>,
distribution: Distribution,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
todo!()
}
fn float_ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
todo!()
}
fn float_full<const D: usize>(
shape: Shape<D>,
fill_value: FloatElem<Self>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_shape<const D: usize>(tensor: &FloatTensor<Self, D>) -> Shape<D> {
todo!()
}
async fn float_into_data<const D: usize>(tensor: FloatTensor<Self, D>) -> TensorData {
tensor.into_data().await
}
fn float_device<const D: usize>(tensor: &FloatTensor<Self, D>) -> Device<Self> {
todo!()
}
fn float_to_device<const D: usize>(
tensor: FloatTensor<Self, D>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_into_int<const D: usize>(tensor: FloatTensor<Self, D>) -> IntTensor<Self, D> {
todo!()
}
fn float_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
todo!()
}
fn float_add<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_add_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_clamp<const D: usize>(
tensor: FloatTensor<Self, D>,
min: FloatElem<Self>,
max: FloatElem<Self>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_sub<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_sub_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_mul<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_mul_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_div<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_div_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_remainder_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_matmul<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_swap_dims<const D: usize>(
tensor: FloatTensor<Self, D>,
dim1: usize,
dim2: usize,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_reshape<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> FloatTensor<Self, D2> {
todo!()
}
fn float_gather<const D: usize>(
dim: usize,
tensor: FloatTensor<Self, D>,
indices: IntTensor<Self, D>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_scatter<const D: usize>(
dim: usize,
tensor: FloatTensor<Self, D>,
indices: IntTensor<Self, D>,
value: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_select<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_select_assign<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
value: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_slice<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
ranges: [Range<usize>; D2],
) -> FloatTensor<Self, D1> {
todo!()
}
fn float_slice_assign<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
ranges: [Range<usize>; D2],
value: FloatTensor<Self, D1>,
) -> FloatTensor<Self, D1> {
todo!()
}
fn float_mask_where<const D: usize>(
tensor: FloatTensor<Self, D>,
mask: BoolTensor<Self, D>,
value: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_mask_fill<const D: usize>(
tensor: FloatTensor<Self, D>,
mask: BoolTensor<Self, D>,
value: FloatElem<Self>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
todo!()
}
fn float_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
todo!()
}
fn float_greater<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
todo!()
}
fn float_greater_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
todo!()
}
fn float_greater_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
todo!()
}
fn float_greater_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
todo!()
}
fn float_lower<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
todo!()
}
fn float_lower_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
todo!()
}
fn float_lower_equal<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> BoolTensor<Self, D> {
todo!()
}
fn float_lower_equal_elem<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> BoolTensor<Self, D> {
todo!()
}
fn float_sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
todo!()
}
fn float_sum_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_mean<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
todo!()
}
fn float_mean_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn float_log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn float_log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn float_powf_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: f32,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn float_abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn float_cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn float_sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn float_tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn float_recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn float_erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn float_cat<const D: usize>(
tensors: Vec<FloatTensor<Self, D>>,
dim: usize,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_argmax<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> IntTensor<Self, D> {
todo!()
}
fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
times: usize,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_argmin<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> IntTensor<Self, D> {
todo!()
}
fn float_max<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
todo!()
}
fn float_max_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_max_dim_with_indices<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> (FloatTensor<Self, D>, IntTensor<Self, D>) {
todo!()
}
fn float_min<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
todo!()
}
fn float_min_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_min_dim_with_indices<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
) -> (FloatTensor<Self, D>, IntTensor<Self, D>) {
todo!()
}
fn float_powf<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
todo!()
}
fn float_permute<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: [usize; D],
) -> FloatTensor<Self, D> {
todo!()
}
fn float_expand<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
) -> FloatTensor<Self, D2> {
todo!()
}
fn float_flip<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: &[usize],
) -> FloatTensor<Self, D> {
todo!()
}
}

View File

@ -1,4 +1,375 @@
use core::ops::Range;
use crate::ops::{BoolTensor, FloatTensor, IntElem, IntTensor};
use crate::server::Server;
use crate::{ops::IntTensorOps, server::ServerBackend};
use crate::{Device, Distribution, Shape, TensorData};
impl<B: ServerBackend> IntTensorOps<Self> for Server<B> {}
impl<B: ServerBackend> IntTensorOps<Self> for Server<B> {
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
todo!();
}
fn int_shape<const D: usize>(tensor: &IntTensor<Self, D>) -> Shape<D> {
todo!();
}
async fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> TensorData {
tensor.into_data().await
}
fn int_from_data<const D: usize>(
data: TensorData,
device: &Device<Self>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_device<const D: usize>(tensor: &IntTensor<Self, D>) -> Device<Self> {
todo!()
}
fn int_to_device<const D: usize>(
tensor: IntTensor<Self, D>,
device: &Device<Self>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_reshape<const D1: usize, const D2: usize>(
tensor: IntTensor<Self, D1>,
shape: Shape<D2>,
) -> IntTensor<Self, D2> {
todo!()
}
fn int_slice<const D1: usize, const D2: usize>(
tensor: IntTensor<Self, D1>,
ranges: [Range<usize>; D2],
) -> IntTensor<Self, D1> {
todo!()
}
fn int_slice_assign<const D1: usize, const D2: usize>(
tensor: IntTensor<Self, D1>,
ranges: [Range<usize>; D2],
value: IntTensor<Self, D1>,
) -> IntTensor<Self, D1> {
todo!()
}
fn int_mask_where<const D: usize>(
tensor: IntTensor<Self, D>,
mask: BoolTensor<Self, D>,
value: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_mask_fill<const D: usize>(
tensor: IntTensor<Self, D>,
mask: BoolTensor<Self, D>,
value: IntElem<Self>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_gather<const D: usize>(
dim: usize,
tensor: IntTensor<Self, D>,
indices: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_scatter<const D: usize>(
dim: usize,
tensor: IntTensor<Self, D>,
indices: IntTensor<Self, D>,
value: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_select<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_select_assign<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
indices: IntTensor<Self, 1>,
value: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_cat<const D: usize>(tensors: Vec<IntTensor<Self, D>>, dim: usize) -> IntTensor<Self, D> {
todo!()
}
fn int_equal<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> BoolTensor<Self, D> {
todo!()
}
fn int_equal_elem<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> BoolTensor<Self, D> {
todo!()
}
fn int_greater<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> BoolTensor<Self, D> {
todo!()
}
fn int_greater_elem<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> BoolTensor<Self, D> {
todo!()
}
fn int_greater_equal<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> BoolTensor<Self, D> {
todo!()
}
fn int_greater_equal_elem<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> BoolTensor<Self, D> {
todo!()
}
fn int_lower<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> BoolTensor<Self, D> {
todo!()
}
fn int_lower_elem<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> BoolTensor<Self, D> {
todo!()
}
fn int_lower_equal<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> BoolTensor<Self, D> {
todo!()
}
fn int_lower_equal_elem<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> BoolTensor<Self, D> {
todo!()
}
fn int_add<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_add_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_sub<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_sub_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_mul<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_mul_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_div<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_div_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_remainder_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
todo!()
}
fn int_ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
todo!()
}
fn int_sum<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
todo!()
}
fn int_sum_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
todo!()
}
fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
todo!()
}
fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
todo!()
}
fn int_mean<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
todo!()
}
fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
todo!()
}
fn int_argmax<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
todo!()
}
fn int_argmin<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
todo!()
}
fn int_clamp<const D: usize>(
tensor: IntTensor<Self, D>,
min: IntElem<Self>,
max: IntElem<Self>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
todo!()
}
fn int_into_float<const D: usize>(tensor: IntTensor<Self, D>) -> FloatTensor<Self, D> {
todo!()
}
fn int_swap_dims<const D: usize>(
tensor: IntTensor<Self, D>,
dim1: usize,
dim2: usize,
) -> IntTensor<Self, D> {
todo!()
}
fn int_max<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
todo!()
}
fn int_max_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
todo!()
}
fn int_max_dim_with_indices<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
) -> (IntTensor<Self, D>, IntTensor<Self, D>) {
todo!()
}
fn int_min<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
todo!()
}
fn int_min_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
todo!()
}
fn int_min_dim_with_indices<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
) -> (IntTensor<Self, D>, IntTensor<Self, D>) {
todo!()
}
fn int_random<const D: usize>(
shape: Shape<D>,
distribution: Distribution,
device: &Device<Self>,
) -> IntTensor<Self, D> {
todo!()
}
fn int_permute<const D: usize>(
tensor: IntTensor<Self, D>,
axes: [usize; D],
) -> IntTensor<Self, D> {
todo!()
}
fn int_expand<const D1: usize, const D2: usize>(
tensor: IntTensor<Self, D1>,
shape: Shape<D2>,
) -> IntTensor<Self, D2> {
todo!()
}
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
todo!()
}
fn int_repeat_dim<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
times: usize,
) -> IntTensor<Self, D> {
todo!()
}
}

View File

@ -1,7 +1,213 @@
use crate::ops::{
IntTensor, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
MaxPool2dWithIndices,
};
use crate::server::Server;
use crate::{
ops::{ConvOptions, ConvTransposeOptions, FloatTensor, ModuleOps},
server::ServerBackend,
};
impl<B: ServerBackend> ModuleOps<Self> for Server<B> {}
impl<B: ServerBackend> ModuleOps<Self> for Server<B> {
fn conv1d(
x: FloatTensor<Self, 3>,
weight: FloatTensor<Self, 3>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvOptions<1>,
) -> FloatTensor<Self, 3> {
todo!()
}
fn conv2d(
x: FloatTensor<Self, 4>,
weight: FloatTensor<Self, 4>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvOptions<2>,
) -> FloatTensor<Self, 4> {
todo!()
}
fn conv3d(
x: FloatTensor<Self, 5>,
weight: FloatTensor<Self, 5>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvOptions<3>,
) -> FloatTensor<Self, 5> {
todo!()
}
fn conv_transpose1d(
x: FloatTensor<Self, 3>,
weight: FloatTensor<Self, 3>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<Self, 3> {
todo!()
}
fn conv_transpose2d(
x: FloatTensor<Self, 4>,
weight: FloatTensor<Self, 4>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<Self, 4> {
todo!()
}
fn conv_transpose3d(
x: FloatTensor<Self, 5>,
weight: FloatTensor<Self, 5>,
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<3>,
) -> FloatTensor<Self, 5> {
todo!()
}
fn avg_pool1d(
x: FloatTensor<Self, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
count_include_pad: bool,
) -> FloatTensor<Self, 3> {
todo!()
}
fn avg_pool2d(
x: FloatTensor<Self, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
) -> FloatTensor<Self, 4> {
todo!()
}
fn avg_pool1d_backward(
x: FloatTensor<Self, 3>,
grad: FloatTensor<Self, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
count_include_pad: bool,
) -> FloatTensor<Self, 3> {
todo!()
}
fn avg_pool2d_backward(
x: FloatTensor<Self, 4>,
grad: FloatTensor<Self, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
) -> FloatTensor<Self, 4> {
todo!()
}
fn max_pool1d(
x: FloatTensor<Self, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> FloatTensor<Self, 3> {
todo!()
}
fn max_pool2d(
x: FloatTensor<Self, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> FloatTensor<Self, 4> {
todo!()
}
fn max_pool1d_with_indices(
x: FloatTensor<Self, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> MaxPool1dWithIndices<Self> {
todo!()
}
fn max_pool2d_with_indices(
x: FloatTensor<Self, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<Self> {
todo!()
}
fn max_pool1d_with_indices_backward(
x: FloatTensor<Self, 3>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
output_grad: FloatTensor<Self, 3>,
indices: IntTensor<Self, 3>,
) -> MaxPool1dBackward<Self> {
todo!()
}
fn max_pool2d_with_indices_backward(
x: FloatTensor<Self, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: FloatTensor<Self, 4>,
indices: IntTensor<Self, 4>,
) -> MaxPool2dBackward<Self> {
todo!()
}
fn adaptive_avg_pool1d(x: FloatTensor<Self, 3>, output_size: usize) -> FloatTensor<Self, 3> {
todo!()
}
fn adaptive_avg_pool2d(
x: FloatTensor<Self, 4>,
output_size: [usize; 2],
) -> FloatTensor<Self, 4> {
todo!()
}
fn adaptive_avg_pool1d_backward(
x: FloatTensor<Self, 3>,
grad: FloatTensor<Self, 3>,
) -> FloatTensor<Self, 3> {
todo!()
}
fn adaptive_avg_pool2d_backward(
x: FloatTensor<Self, 4>,
grad: FloatTensor<Self, 4>,
) -> FloatTensor<Self, 4> {
todo!()
}
fn interpolate(
x: FloatTensor<Self, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self, 4> {
todo!()
}
fn interpolate_backward(
x: FloatTensor<Self, 4>,
grad: FloatTensor<Self, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self, 4> {
todo!()
}
}

View File

@ -1,10 +1,118 @@
use core::future::Future;
use core::ops::Range;
use crate::{
ops::{FloatTensor, QTensorOps, QuantizedTensor},
backend::Backend,
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
server::{Server, ServerBackend},
Device, TensorData,
Device, Shape, TensorData,
};
impl<B: ServerBackend> QTensorOps<Self> for Server<B> {}
impl<B: ServerBackend> QTensorOps<Self> for Server<B> {
fn q_from_data<const D: usize>(
_data: TensorData,
_device: &Device<Self>,
) -> QuantizedTensor<Self, D> {
unimplemented!()
}
fn quantize<const D: usize>(
_tensor: <Self as Backend>::FloatTensorPrimitive<D>,
_scheme: &QuantizationScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> <Self as Backend>::QuantizedTensorPrimitive<D> {
unimplemented!()
}
fn quantize_dynamic<const D: usize>(
_tensor: <Self as Backend>::FloatTensorPrimitive<D>,
_scheme: &QuantizationScheme,
) -> QuantizedTensor<Self, D> {
unimplemented!()
}
fn dequantize<const D: usize>(
_tensor: <Self as Backend>::QuantizedTensorPrimitive<D>,
) -> <Self as Backend>::FloatTensorPrimitive<D> {
unimplemented!()
}
fn q_shape<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Shape<D> {
todo!()
}
fn q_device<const D: usize>(tensor: &QuantizedTensor<Self, D>) -> Device<Self> {
todo!()
}
fn q_to_device<const D: usize>(
_tensor: QuantizedTensor<Self, D>,
_device: &Device<Self>,
) -> QuantizedTensor<Self, D> {
unimplemented!()
}
fn q_reshape<const D1: usize, const D2: usize>(
_tensor: QuantizedTensor<Self, D1>,
_shape: Shape<D2>,
) -> QuantizedTensor<Self, D2> {
unimplemented!()
}
async fn q_into_data<const D: usize>(_tensor: QuantizedTensor<Self, D>) -> TensorData {
unimplemented!()
}
fn q_swap_dims<const D: usize>(
_tensor: QuantizedTensor<Self, D>,
_dim1: usize,
_dim2: usize,
) -> QuantizedTensor<Self, D> {
unimplemented!()
}
fn q_permute<const D: usize>(
_tensor: QuantizedTensor<Self, D>,
_axes: [usize; D],
) -> QuantizedTensor<Self, D> {
unimplemented!()
}
fn q_flip<const D: usize>(
_tensor: QuantizedTensor<Self, D>,
_axes: &[usize],
) -> QuantizedTensor<Self, D> {
unimplemented!()
}
fn q_gather<const D: usize>(
_dim: usize,
_tensor: QuantizedTensor<Self, D>,
_indices: IntTensor<Self, D>,
) -> QuantizedTensor<Self, D> {
unimplemented!()
}
fn q_select<const D: usize>(
_tensor: QuantizedTensor<Self, D>,
_dim: usize,
_indices: IntTensor<Self, 1>,
) -> QuantizedTensor<Self, D> {
unimplemented!()
}
fn q_slice<const D1: usize, const D2: usize>(
_tensor: QuantizedTensor<Self, D1>,
_ranges: [Range<usize>; D2],
) -> QuantizedTensor<Self, D1> {
unimplemented!()
}
fn q_expand<const D1: usize, const D2: usize>(
_tensor: QuantizedTensor<Self, D1>,
_shape: Shape<D2>,
) -> QuantizedTensor<Self, D2> {
unimplemented!()
}
}

View File

@ -0,0 +1,39 @@
use burn_common::stream::StreamId;
use super::{ServerClient, ServerRuntime};
use crate::{repr::TensorDescription, TensorData};
pub struct ServerTensor<R: ServerRuntime> {
pub(crate) desc: TensorDescription,
pub(crate) client: R::Client,
pub(crate) stream: StreamId,
}
impl<R: ServerRuntime> ServerTensor<R> {
pub(crate) async fn into_data(self) -> TensorData {
let id = self.stream;
let desc = self.desc;
self.client.read_tensor(desc, id).await
}
pub fn into_description(self) -> TensorDescription {
self.desc
}
}
impl<R: ServerRuntime> core::fmt::Debug for ServerTensor<R> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_fmt(format_args!("tensor"))
}
}
impl<R: ServerRuntime> Clone for ServerTensor<R> {
fn clone(&self) -> Self {
Self {
desc: self.desc.clone(),
client: self.client.clone(),
stream: self.stream,
}
}
}