diff --git a/Cargo.lock b/Cargo.lock index 207afc58b..8d0fc04af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -706,6 +706,18 @@ dependencies = [ "serde", ] +[[package]] +name = "burn-router" +version = "0.15.0" +dependencies = [ + "burn-autodiff", + "burn-ndarray", + "burn-tensor", + "burn-wgpu", + "hashbrown 0.14.5", + "spin", +] + [[package]] name = "burn-tch" version = "0.15.0" @@ -732,6 +744,7 @@ dependencies = [ "half", "hashbrown 0.14.5", "num-traits", + "portable-atomic-util", "rand", "rand_distr", "serde", diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index 8bed96b27..42c0e1f32 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -4,8 +4,8 @@ use crate::{ }; use burn_tensor::{ backend::{Backend, DeviceOps}, - ops::FloatTensor, - repr::{OperationDescription, ReprBackend}, + ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, + repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle}, Device, }; use serde::{de::DeserializeOwned, Serialize}; @@ -166,3 +166,43 @@ pub trait FusionBackend: /// Pointer to the full precision fusion backend. type FullPrecisionBackend: FusionBackend; } + +// Fusion implements `ReprBackend` to enable router backend usage. +impl ReprBackend for Fusion { + type Handle = FusionTensor; + + fn float_tensor(handle: TensorHandle) -> FloatTensor { + handle.handle + } + + fn int_tensor(handle: TensorHandle) -> IntTensor { + handle.handle + } + + fn bool_tensor(handle: TensorHandle) -> BoolTensor { + handle.handle + } + + fn quantized_tensor( + _handles: QuantizedKind>, + _scheme: burn_tensor::quantization::QuantizationScheme, + ) -> QuantizedTensor { + todo!() // not as simple + } + + fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { + tensor + } + + fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { + tensor + } + + fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { + tensor + } + + fn quantized_tensor_handle(_tensor: QuantizedTensor) -> QuantizedKind { + todo!() // not as simple + } +} diff --git a/crates/burn-fusion/src/ops/binary.rs b/crates/burn-fusion/src/ops/binary.rs index 8ce7bb4ac..4f2dbb986 100644 --- a/crates/burn-fusion/src/ops/binary.rs +++ b/crates/burn-fusion/src/ops/binary.rs @@ -73,16 +73,6 @@ macro_rules! binary_int_cmp_ops { }; } -pub(crate) fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec { - let mut shape_out = Vec::with_capacity(lhs.len()); - - for (l, r) in lhs.iter().zip(rhs.iter()) { - shape_out.push(usize::max(*l, *r)); - } - - shape_out -} - #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! 
binary_int_ops { diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index a863355dc..510c0254b 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -1,5 +1,5 @@ use burn_tensor::{ - ops::{FloatTensor, IntTensor}, + ops::{binary_ops_shape, FloatTensor, IntTensor}, DType, Element, TensorData, }; use std::marker::PhantomData; @@ -7,7 +7,6 @@ use std::marker::PhantomData; use crate::{ client::FusionClient, get_client, - ops::binary::binary_ops_shape, stream::{execution::Operation, StreamId}, Fusion, FusionBackend, }; diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 0351830d2..1be7cd884 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -1,14 +1,12 @@ use crate::{ binary_float_cmp_ops, binary_float_ops, client::FusionClient, - get_client, - ops::binary::binary_ops_shape, - scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, + get_client, scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, stream::{execution::Operation, StreamId}, unary_float_ops, Fusion, FusionBackend, }; use burn_tensor::{ - ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, + ops::{binary_ops_shape, BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, repr::*, DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, }; diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 48352e2db..d8facd1e3 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -1,14 +1,12 @@ use crate::{ binary_int_cmp_ops, binary_int_ops, client::FusionClient, - get_client, - ops::binary::binary_ops_shape, - scalar_int_cmp_ops, scalar_int_ops, + get_client, scalar_int_cmp_ops, scalar_int_ops, stream::{execution::Operation, StreamId}, unary_int_ops, Fusion, FusionBackend, }; use burn_tensor::{ - ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, + ops::{binary_ops_shape, BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, repr::{self, *}, DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, }; diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index a6a51ce3a..1f5f5e494 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -6,6 +6,7 @@ use burn_tensor::{ repr::{ DequantizeOperationDescription, FloatOperationDescription, HandleContainer, OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription, + QuantizedKind, }, DType, Device, Element, Shape, TensorData, }; @@ -25,19 +26,17 @@ impl QTensorOps for Fusion { let tensor = B::q_from_data(data, device); let shape = B::q_shape(&tensor); - let mut handles = B::quantized_tensor_handle(tensor); + let handles = B::quantized_tensor_handle(tensor); let qparams = match strategy { QuantizationStrategy::PerTensorAffineInt8(_) => { - let num_handles = handles.len(); - assert_eq!( - num_handles, 3, - "Expected 3 handles for quantized tensor, got {num_handles}" - ); - let offset = handles.pop().unwrap(); - let scale = handles.pop().unwrap(); + let offset = if let Some(offset) = handles.offset { + offset + } else { + panic!("Expected offset for quantized tensor."); + }; FusionQuantizationParameters { scale: client.register_tensor( - scale, + handles.scale, vec![1], StreamId::current(), B::FloatElem::dtype(), @@ -51,15 +50,13 @@ impl QTensorOps for Fusion { } } 
QuantizationStrategy::PerTensorSymmetricInt8(_) => { - let num_handles = handles.len(); - assert_eq!( - num_handles, 2, - "Expected 2 handles for quantized tensor, got {num_handles}" + assert!( + handles.offset.is_none(), + "Offset should not be provided for symmetric quantization." ); - let scale = handles.pop().unwrap(); FusionQuantizationParameters { scale: client.register_tensor( - scale, + handles.scale, vec![1], StreamId::current(), B::FloatElem::dtype(), @@ -69,7 +66,7 @@ impl QTensorOps for Fusion { } }; let qtensor = client.register_tensor( - handles.pop().unwrap(), + handles.tensor, shape.dims, StreamId::current(), B::QuantizedEncoding::dtype(), @@ -111,17 +108,20 @@ impl QTensorOps for Fusion { let qparams = QuantizationParametersPrimitive { scale, offset }; let output = B::quantize(tensor, &self.desc.scheme, qparams); - if let Some(offset) = &self.desc.qparams.offset { - handles.register_quantized_tensor::( - &[&self.desc.out.id, &self.desc.qparams.scale.id, &offset.id], - output, - ); + let q_ids = if let Some(offset) = &self.desc.qparams.offset { + QuantizedKind { + tensor: self.desc.out.id, + scale: self.desc.qparams.scale.id, + offset: Some(offset.id), + } } else { - handles.register_quantized_tensor::( - &[&self.desc.out.id, &self.desc.qparams.scale.id], - output, - ); - } + QuantizedKind { + tensor: self.desc.out.id, + scale: self.desc.qparams.scale.id, + offset: None, + } + }; + handles.register_quantized_tensor::(&q_ids, output); } } diff --git a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index 58cb7f960..ce44d51c4 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -3,7 +3,8 @@ use crate::{ FusionBackend, FusionRuntime, }; use burn_tensor::repr::{ - HandleContainer, OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId, + HandleContainer, OperationDescription, QuantizedKind, QuantizedTensorDescription, + TensorDescription, TensorId, }; use std::sync::Arc; @@ -183,18 +184,28 @@ where let scale_id = server_device.create_empty_handle(); let offset_id = server_device.create_empty_handle(); + let q_ids = QuantizedKind { + tensor: *tensor_id, + scale: *scale_id, + offset: Some(*offset_id), + }; server_device .handles - .register_quantized_tensor::(&[&tensor_id, &scale_id, &offset_id], tensor); + .register_quantized_tensor::(&q_ids, tensor); vec![tensor_id, scale_id, offset_id] } else { let tensor_id = server_device.create_empty_handle(); let scale_id = server_device.create_empty_handle(); + let q_ids = QuantizedKind { + tensor: *tensor_id, + scale: *scale_id, + offset: None, + }; server_device .handles - .register_quantized_tensor::(&[&tensor_id, &scale_id], tensor); + .register_quantized_tensor::(&q_ids, tensor); vec![tensor_id, scale_id] } diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 571a1254b..28293c926 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -968,6 +968,9 @@ impl RelativeOps for BaseOperationDescription { out: desc.out.to_relative(converter), }) } + BaseOperationDescription::Empty(desc) => { + BaseOperationDescription::Empty(desc.to_relative(converter)) + } } } } diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 5218feaf9..c6463baa3 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -31,7 +31,7 @@ template = [] burn-common = { path = "../burn-common", version = "0.15.0" } burn-fusion = { path = "../burn-fusion", 
version = "0.15.0", optional = true } burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = [ - "cubecl", + "cubecl", "repr", ] } cubecl = { workspace = true, features = ["linalg"] } diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index 46c520e2c..b23479075 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -7,6 +7,13 @@ use cubecl::server::ComputeServer; use rand::{rngs::StdRng, SeedableRng}; use std::{marker::PhantomData, sync::Mutex}; +#[cfg(not(feature = "fusion"))] +use burn_tensor::{ + ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, + quantization::QuantizationScheme, + repr::{HandleKind, QuantizedKind, ReprBackend, TensorHandle}, +}; + pub(crate) static SEED: Mutex> = Mutex::new(None); /// Generic tensor backend that can be compiled just-in-time to any shader runtime @@ -82,3 +89,62 @@ where type JitDevice = R::Device; type JitServer = R::Server; } + +#[cfg(not(feature = "fusion"))] +impl ReprBackend for JitBackend { + type Handle = HandleKind; + + fn float_tensor(handle: TensorHandle) -> FloatTensor { + match handle.handle { + HandleKind::Float(handle) => handle, + _ => panic!("Expected float handle, got {}", handle.handle.name()), + } + } + + fn int_tensor(handle: TensorHandle) -> IntTensor { + match handle.handle { + HandleKind::Int(handle) => handle, + _ => panic!("Expected int handle, got {}", handle.handle.name()), + } + } + + fn bool_tensor(handle: TensorHandle) -> BoolTensor { + match handle.handle { + HandleKind::Bool(handle) => handle, + _ => panic!("Expected bool handle, got {}", handle.handle.name()), + } + } + + fn quantized_tensor( + handles: QuantizedKind>, + _scheme: QuantizationScheme, + ) -> QuantizedTensor { + let handle = handles.tensor.handle; + match handle { + HandleKind::Quantized(handle) => handle, + _ => panic!("Expected quantized handle, got {}", handle.name()), + } + } + + fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { + HandleKind::Float(tensor) + } + + fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { + HandleKind::Int(tensor) + } + + fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { + HandleKind::Bool(tensor) + } + + fn quantized_tensor_handle(tensor: QuantizedTensor) -> QuantizedKind { + QuantizedKind { + tensor: HandleKind::Quantized(tensor), + // The quantized tensor primitive already encapsulates the required quantization + // parameters so we set the scale as an empty handle (unused). 
+ scale: HandleKind::Empty, + offset: None, + } + } +} diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 2b58692f3..527c4bd98 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -7,7 +7,7 @@ use crate::{ }; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; use burn_tensor::quantization::QuantizationScheme; -use burn_tensor::repr::TensorHandle; +use burn_tensor::repr::{QuantizedKind, TensorHandle}; use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; use cubecl::client::ComputeClient; @@ -79,41 +79,22 @@ impl ReprBackend for JitBackend>, + handles: QuantizedKind>, scheme: QuantizationScheme, ) -> burn_tensor::ops::QuantizedTensor { - match handles.len() { - // NOTE: the order of the handles is known [qtensor, scale, ] - 3 => { - let mut handles = handles; - let offset = handles.pop().unwrap(); - let scale = handles.pop().unwrap(); - let qtensor = handles.pop().unwrap(); - QJitTensor { - qtensor: qtensor.handle.into_tensor(qtensor.shape), - scheme, - qparams: JitQuantizationParameters { - scale: scale.handle.into_tensor(scale.shape), - offset: Some(offset.handle.into_tensor(offset.shape)), - }, - } - } - 2 => { - let mut handles = handles; - let scale = handles.pop().unwrap(); - let qtensor = handles.pop().unwrap(); - QJitTensor { - qtensor: qtensor.handle.into_tensor(qtensor.shape), - scheme, - qparams: JitQuantizationParameters { - scale: scale.handle.into_tensor(scale.shape), - offset: None, - }, - } - } - _ => { - panic!("Expected handles for the quantized tensor and its quantization parameters.") - } + let qtensor = handles.tensor.handle.into_tensor(handles.tensor.shape); + let scale = handles.scale.handle.into_tensor(handles.scale.shape); + let offset = handles.offset; + + let qparams = JitQuantizationParameters { + scale, + offset: offset.map(|h| h.handle.into_tensor(h.shape)), + }; + + QJitTensor { + qtensor, + scheme, + qparams, } } @@ -131,14 +112,14 @@ impl ReprBackend for JitBackend, - ) -> Vec { + ) -> QuantizedKind { let qtensor: JitFusionHandle = tensor.qtensor.into(); let scale: JitFusionHandle = tensor.qparams.scale.into(); - if let Some(offset) = tensor.qparams.offset { - let offset: JitFusionHandle = offset.into(); - vec![qtensor, scale, offset] - } else { - vec![qtensor, scale] + + QuantizedKind { + tensor: qtensor, + scale, + offset: tensor.qparams.offset.map(|offset| offset.into()), } } } diff --git a/crates/burn-ndarray/Cargo.toml b/crates/burn-ndarray/Cargo.toml index ef5016d8c..d41c37ed5 100644 --- a/crates/burn-ndarray/Cargo.toml +++ b/crates/burn-ndarray/Cargo.toml @@ -45,7 +45,7 @@ blas-openblas-system = [ burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", optional = true } burn-common = { path = "../burn-common", version = "0.15.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false, features = ["repr"] } atomic_float = { workspace = true } blas-src = { workspace = true, default-features = false, optional = true } # no-std compatible diff --git a/crates/burn-ndarray/src/backend.rs b/crates/burn-ndarray/src/backend.rs index 671b999ed..0ca2c89a1 100644 --- a/crates/burn-ndarray/src/backend.rs +++ b/crates/burn-ndarray/src/backend.rs @@ -1,9 +1,12 @@ -use crate::element::{FloatNdArrayElement, QuantElement}; +use crate::element::{FloatNdArrayElement, IntNdArrayElement, 
QuantElement}; use crate::PrecisionBridge; use crate::{NdArrayQTensor, NdArrayTensor}; use alloc::string::String; use burn_common::stub::Mutex; use burn_tensor::backend::{Backend, DeviceId, DeviceOps}; +use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; +use burn_tensor::quantization::QuantizationScheme; +use burn_tensor::repr::{HandleKind, QuantizedKind, ReprBackend, TensorHandle}; use core::marker::PhantomData; use rand::{rngs::StdRng, SeedableRng}; @@ -35,20 +38,21 @@ impl Default for NdArrayDevice { /// This backend is compatible with CPUs and can be compiled for almost any platform, including /// `wasm`, `arm`, and `x86`. #[derive(Clone, Copy, Default, Debug)] -pub struct NdArray { +pub struct NdArray { _e: PhantomData, + _i: PhantomData, _q: PhantomData, } -impl Backend for NdArray { +impl Backend for NdArray { type Device = NdArrayDevice; type FullPrecisionBridge = PrecisionBridge; type FloatTensorPrimitive = NdArrayTensor; type FloatElem = E; - type IntTensorPrimitive = NdArrayTensor; - type IntElem = i64; + type IntTensorPrimitive = NdArrayTensor; + type IntElem = I; type BoolTensorPrimitive = NdArrayTensor; @@ -69,3 +73,63 @@ impl Backend for NdArray { *seed = Some(rng); } } + +impl ReprBackend + for NdArray +{ + type Handle = HandleKind; + + fn float_tensor(handle: TensorHandle) -> FloatTensor { + match handle.handle { + HandleKind::Float(handle) => handle, + _ => panic!("Expected float handle, got {}", handle.handle.name()), + } + } + + fn int_tensor(handle: TensorHandle) -> IntTensor { + match handle.handle { + HandleKind::Int(handle) => handle, + _ => panic!("Expected int handle, got {}", handle.handle.name()), + } + } + + fn bool_tensor(handle: TensorHandle) -> BoolTensor { + match handle.handle { + HandleKind::Bool(handle) => handle, + _ => panic!("Expected bool handle, got {}", handle.handle.name()), + } + } + + fn quantized_tensor( + handles: QuantizedKind>, + _scheme: QuantizationScheme, + ) -> QuantizedTensor { + let handle = handles.tensor.handle; + match handle { + HandleKind::Quantized(handle) => handle, + _ => panic!("Expected quantized handle, got {}", handle.name()), + } + } + + fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { + HandleKind::Float(tensor) + } + + fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { + HandleKind::Int(tensor) + } + + fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { + HandleKind::Bool(tensor) + } + + fn quantized_tensor_handle(tensor: QuantizedTensor) -> QuantizedKind { + QuantizedKind { + tensor: HandleKind::Quantized(tensor), + // The quantized tensor primitive already encapsulates the required quantization + // parameters so we set the scale as an empty handle (unused). 
+ scale: HandleKind::Empty, + offset: None, + } + } +} diff --git a/crates/burn-ndarray/src/bridge.rs b/crates/burn-ndarray/src/bridge.rs index 15803ffc5..27f993979 100644 --- a/crates/burn-ndarray/src/bridge.rs +++ b/crates/burn-ndarray/src/bridge.rs @@ -1,4 +1,7 @@ -use crate::{element::QuantElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor}; +use crate::{ + element::{IntNdArrayElement, QuantElement}, + FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor, +}; use burn_tensor::{backend::BackendBridge, ops::FloatTensor}; use core::marker::PhantomData; @@ -8,13 +11,15 @@ pub struct PrecisionBridge { _e: PhantomData, } -impl BackendBridge> for PrecisionBridge +impl BackendBridge> + for PrecisionBridge where TElem: FloatNdArrayElement, OElem: FloatNdArrayElement, QElem: QuantElement, + IntElem: IntNdArrayElement, { - type Target = NdArray; + type Target = NdArray; fn into_target( tensor: FloatTensor>, diff --git a/crates/burn-ndarray/src/element.rs b/crates/burn-ndarray/src/element.rs index bf87d5477..21c3b5dde 100644 --- a/crates/burn-ndarray/src/element.rs +++ b/crates/burn-ndarray/src/element.rs @@ -16,6 +16,8 @@ where { } +pub trait IntNdArrayElement: NdArrayElement + core::ops::Rem + Signed {} + /// A general element for ndarray backend. pub trait NdArrayElement: Element @@ -49,6 +51,9 @@ impl QuantElement for i8 {} impl FloatNdArrayElement for f64 {} impl FloatNdArrayElement for f32 {} +impl IntNdArrayElement for i64 {} +impl IntNdArrayElement for i32 {} + macro_rules! make_elem { ( double diff --git a/crates/burn-ndarray/src/ops/activations.rs b/crates/burn-ndarray/src/ops/activations.rs index 2d8b32302..aa194e44c 100644 --- a/crates/burn-ndarray/src/ops/activations.rs +++ b/crates/burn-ndarray/src/ops/activations.rs @@ -1,11 +1,13 @@ use crate::{ - element::{FloatNdArrayElement, QuantElement}, + element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, tensor::NdArrayTensor, NdArray, }; use burn_tensor::{ops::ActivationOps, ElementConversion}; -impl ActivationOps for NdArray { +impl ActivationOps + for NdArray +{ fn relu(tensor: NdArrayTensor) -> NdArrayTensor { let zero = 0.elem(); let array = tensor diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index ab38e808d..7b94ea799 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -261,10 +261,10 @@ where } } - pub fn gather( + pub fn gather( dim: usize, mut tensor: NdArrayTensor, - mut indices: NdArrayTensor, + mut indices: NdArrayTensor, ) -> NdArrayTensor { let ndims = tensor.shape().num_dims(); if dim != ndims - 1 { @@ -284,7 +284,7 @@ where let indices = indices.slice(s!(b, ..)); for (i, index) in indices.iter().enumerate() { - output[[b, i]] = tensor[[b, *index as usize]]; + output[[b, i]] = tensor[[b, index.elem::() as usize]]; } } @@ -300,10 +300,10 @@ where output } - pub fn scatter( + pub fn scatter( dim: usize, mut tensor: NdArrayTensor, - mut indices: NdArrayTensor, + mut indices: NdArrayTensor, mut value: NdArrayTensor, ) -> NdArrayTensor { let ndims = tensor.shape().num_dims(); @@ -338,7 +338,7 @@ where let indices = indices.slice(s!(b, ..)); for (i, index) in indices.iter().enumerate() { - let index = *index as usize; + let index = index.elem::() as usize; tensor[[b, index]] += value[[b, i]]; } } @@ -403,33 +403,33 @@ where batch_size } - pub fn select( + pub fn select( tensor: NdArrayTensor, dim: usize, - indices: NdArrayTensor, + indices: NdArrayTensor, ) -> NdArrayTensor { let array = tensor.array.select( 
Axis(dim), &indices .array .into_iter() - .map(|i| i as usize) + .map(|i| i.elem::() as usize) .collect::>(), ); NdArrayTensor::new(array.into_shared()) } - pub fn select_assign( + pub fn select_assign( tensor: NdArrayTensor, dim: usize, - indices: NdArrayTensor, + indices: NdArrayTensor, value: NdArrayTensor, ) -> NdArrayTensor { let mut output_array = tensor.array.into_owned(); for (index_value, index) in indices.array.into_iter().enumerate() { - let mut view = output_array.index_axis_mut(Axis(dim), index as usize); + let mut view = output_array.index_axis_mut(Axis(dim), index.elem::() as usize); let value = value.array.index_axis(Axis(dim), index_value); view.zip_mut_with(&value, |a, b| *a += *b); @@ -437,11 +437,11 @@ where NdArrayTensor::new(output_array.into_shared()) } - pub fn argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + pub fn argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { arg(tensor, dim, CmpType::Max) } - pub fn argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + pub fn argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { arg(tensor, dim, CmpType::Min) } @@ -523,11 +523,11 @@ enum CmpType { Max, } -fn arg( +fn arg( tensor: NdArrayTensor, dim: usize, cmp: CmpType, -) -> NdArrayTensor { +) -> NdArrayTensor { let mut reshape = tensor.array.shape().to_vec(); reshape[dim] = 1; @@ -546,7 +546,7 @@ fn arg( } }); - idx as i64 + (idx as i64).elem() }); let output = output.to_shape(Dim(reshape.as_slice())).unwrap(); diff --git a/crates/burn-ndarray/src/ops/bool_tensor.rs b/crates/burn-ndarray/src/ops/bool_tensor.rs index e36223235..6dbd02e30 100644 --- a/crates/burn-ndarray/src/ops/bool_tensor.rs +++ b/crates/burn-ndarray/src/ops/bool_tensor.rs @@ -7,7 +7,7 @@ use core::ops::Range; use ndarray::{IntoDimension, Zip}; // Current crate -use crate::element::{FloatNdArrayElement, QuantElement}; +use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}; use crate::NdArrayDevice; use crate::{tensor::NdArrayTensor, NdArray}; @@ -16,7 +16,9 @@ use burn_tensor::{backend::Backend, Shape, TensorData}; use super::NdArrayOps; -impl BoolTensorOps for NdArray { +impl BoolTensorOps + for NdArray +{ fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor { NdArrayTensor::from_data(data) } @@ -43,11 +45,11 @@ impl BoolTensorOps for NdArray) -> NdArrayTensor { + fn bool_into_int(tensor: NdArrayTensor) -> NdArrayTensor { let shape = tensor.shape(); let values = tensor.array.into_iter().collect(); - NdArray::::int_from_data( - TensorData::new(values, shape).convert::(), + NdArray::::int_from_data( + TensorData::new(values, shape).convert::(), &NdArrayDevice::Cpu, ) } diff --git a/crates/burn-ndarray/src/ops/conv.rs b/crates/burn-ndarray/src/ops/conv.rs index 23e36238d..ff8483473 100644 --- a/crates/burn-ndarray/src/ops/conv.rs +++ b/crates/burn-ndarray/src/ops/conv.rs @@ -11,7 +11,7 @@ use ndarray::{ }; use crate::{ - element::{FloatNdArrayElement, QuantElement}, + element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, ops::padding::{apply_padding_4d, apply_padding_5d}, sharing::UnsafeSharedRef, tensor::NdArrayTensor, @@ -98,7 +98,7 @@ fn conv3d_mad_inner( } } -pub(crate) fn conv2d( +pub(crate) fn conv2d( x: NdArrayTensor, weight: NdArrayTensor, bias: Option>, @@ -126,7 +126,7 @@ pub(crate) fn conv2d( in_width, ); - let x = apply_padding_4d::(x, options.padding, 0i32.elem()).array; + let x = apply_padding_4d::(x, options.padding, 0i32.elem()).array; // Convert inputs from dynamic indexes to static to improve 
perf. let x = x.into_dimensionality::().unwrap(); @@ -310,7 +310,7 @@ pub(crate) fn conv_transpose2d( NdArrayTensor::new(output.into_dyn().into_shared()) } -pub(crate) fn conv3d( +pub(crate) fn conv3d( x: NdArrayTensor, weight: NdArrayTensor, bias: Option>, @@ -345,7 +345,7 @@ pub(crate) fn conv3d( in_width, ); - let x = apply_padding_5d::(x, options.padding, 0i32.elem()).array; + let x = apply_padding_5d::(x, options.padding, 0i32.elem()).array; // Convert inputs from dynamic indexes to static to improve perf. let x = x.into_dimensionality::().unwrap(); diff --git a/crates/burn-ndarray/src/ops/deform_conv.rs b/crates/burn-ndarray/src/ops/deform_conv.rs index 041a8ae22..2a5f8d2db 100644 --- a/crates/burn-ndarray/src/ops/deform_conv.rs +++ b/crates/burn-ndarray/src/ops/deform_conv.rs @@ -252,7 +252,7 @@ pub mod backward { #[cfg(target_has_atomic = "32")] use core::sync::atomic::Ordering; - use crate::NdArray; + use crate::{element::IntNdArrayElement, NdArray}; use atomic_float::AtomicF32; use burn_tensor::ops::DeformConv2dBackward; use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4}; @@ -260,7 +260,11 @@ pub mod backward { use super::*; /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. - pub(crate) fn deform_conv2d_backward( + pub(crate) fn deform_conv2d_backward< + F: FloatNdArrayElement, + I: IntNdArrayElement, + Q: QuantElement, + >( input: NdArrayTensor, offset: NdArrayTensor, weight: NdArrayTensor, @@ -268,7 +272,7 @@ pub mod backward { bias: Option>, out_grad: NdArrayTensor, args: DeformConvOptions<2>, - ) -> DeformConv2dBackward> { + ) -> DeformConv2dBackward> { let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims(); let [_, _, kernel_h, kernel_w] = weight.shape().dims(); let groups = args.weight_groups; diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index 887bda526..4242ad1ca 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -11,8 +11,8 @@ use ndarray::IntoDimension; use ndarray::Zip; // Current crate -use crate::element::ExpElement; use crate::element::FloatNdArrayElement; +use crate::element::IntNdArrayElement; use crate::element::QuantElement; use crate::{tensor::NdArrayTensor, NdArray}; use crate::{NdArrayDevice, SEED}; @@ -22,71 +22,73 @@ use burn_tensor::{backend::Backend, Shape, TensorData}; use super::{NdArrayMathOps, NdArrayOps}; -impl IntTensorOps for NdArray { - fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor { +impl IntTensorOps + for NdArray +{ + fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor { NdArrayTensor::from_data(data) } - fn int_shape(tensor: &NdArrayTensor) -> Shape { + fn int_shape(tensor: &NdArrayTensor) -> Shape { tensor.shape() } - async fn int_into_data(tensor: NdArrayTensor) -> TensorData { + async fn int_into_data(tensor: NdArrayTensor) -> TensorData { let shape = tensor.shape(); let values = tensor.array.into_iter().collect(); TensorData::new(values, shape) } - fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor { + fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor { tensor } - fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { + fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { NdArrayOps::reshape(tensor, shape) } - fn int_slice(tensor: NdArrayTensor, ranges: &[Range]) -> NdArrayTensor { + fn 
int_slice(tensor: NdArrayTensor, ranges: &[Range]) -> NdArrayTensor { NdArrayOps::slice(tensor, ranges) } - fn int_device(_tensor: &NdArrayTensor) -> as Backend>::Device { + fn int_device(_tensor: &NdArrayTensor) -> as Backend>::Device { NdArrayDevice::Cpu } - fn int_empty(shape: Shape, _device: & as Backend>::Device) -> NdArrayTensor { + fn int_empty(shape: Shape, _device: & as Backend>::Device) -> NdArrayTensor { let values = vec![0; shape.num_elements()]; NdArrayTensor::from_data(TensorData::new(values, shape)) } fn int_mask_where( - tensor: NdArrayTensor, + tensor: NdArrayTensor, mask: NdArrayTensor, - source: NdArrayTensor, - ) -> NdArrayTensor { + source: NdArrayTensor, + ) -> NdArrayTensor { NdArrayMathOps::mask_where(tensor, mask, source) } fn int_mask_fill( - tensor: NdArrayTensor, + tensor: NdArrayTensor, mask: NdArrayTensor, - value: i64, - ) -> NdArrayTensor { + value: I, + ) -> NdArrayTensor { NdArrayMathOps::mask_fill(tensor, mask, value) } fn int_slice_assign( - tensor: NdArrayTensor, + tensor: NdArrayTensor, ranges: &[Range], - value: NdArrayTensor, - ) -> NdArrayTensor { + value: NdArrayTensor, + ) -> NdArrayTensor { NdArrayOps::slice_assign(tensor, ranges, value) } - fn int_cat(tensors: Vec>, dim: usize) -> NdArrayTensor { + fn int_cat(tensors: Vec>, dim: usize) -> NdArrayTensor { NdArrayOps::cat(tensors, dim) } - fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { let output = Zip::from(&lhs.array) .and(&rhs.array) .map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val)) @@ -94,196 +96,196 @@ impl IntTensorOps for NdArray, rhs: i64) -> NdArrayTensor { + fn int_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { let array = lhs.array.mapv(|a| a == rhs).into_shared(); NdArrayTensor { array } } - fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { let tensor = Self::int_sub(lhs, rhs); - Self::int_greater_elem(tensor, 0) + Self::int_greater_elem(tensor, 0.elem()) } - fn int_greater_elem(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + fn int_greater_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { let array = lhs.array.mapv(|a| a > rhs).into_shared(); NdArrayTensor::new(array) } - fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { let tensor = Self::int_sub(lhs, rhs); - Self::int_greater_equal_elem(tensor, 0) + Self::int_greater_equal_elem(tensor, 0.elem()) } - fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { let array = lhs.array.mapv(|a| a >= rhs).into_shared(); NdArrayTensor::new(array) } - fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { let tensor = Self::int_sub(lhs, rhs); - Self::int_lower_elem(tensor, 0) + Self::int_lower_elem(tensor, 0.elem()) } - fn int_lower_elem(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + fn int_lower_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { let array = lhs.array.mapv(|a| a < rhs).into_shared(); NdArrayTensor::new(array) } - fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { let tensor = Self::int_sub(lhs, rhs); - 
Self::int_lower_equal_elem(tensor, 0) + Self::int_lower_equal_elem(tensor, 0.elem()) } - fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { let array = lhs.array.mapv(|a| a <= rhs).into_shared(); NdArrayTensor::new(array) } - fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::add(lhs, rhs) } - fn int_add_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + fn int_add_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { NdArrayMathOps::add_scalar(lhs, rhs) } - fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::sub(lhs, rhs) } - fn int_sub_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + fn int_sub_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { NdArrayMathOps::sub_scalar(lhs, rhs) } - fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::mul(lhs, rhs) } - fn int_mul_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + fn int_mul_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { NdArrayMathOps::mul_scalar(lhs, rhs) } - fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::div(lhs, rhs) } - fn int_div_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + fn int_div_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { NdArrayMathOps::div_scalar(lhs, rhs) } - fn int_remainder_scalar(lhs: NdArrayTensor, rhs: i64) -> NdArrayTensor { + fn int_remainder_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { NdArrayMathOps::remainder_scalar(lhs, rhs) } - fn int_neg(tensor: NdArrayTensor) -> NdArrayTensor { - Self::int_mul_scalar(tensor, -1) + fn int_neg(tensor: NdArrayTensor) -> NdArrayTensor { + Self::int_mul_scalar(tensor, (-1).elem()) } - fn int_zeros(shape: Shape, device: & as Backend>::Device) -> NdArrayTensor { + fn int_zeros(shape: Shape, device: & as Backend>::Device) -> NdArrayTensor { Self::int_from_data(TensorData::zeros::(shape), device) } - fn int_ones(shape: Shape, device: & as Backend>::Device) -> NdArrayTensor { + fn int_ones(shape: Shape, device: & as Backend>::Device) -> NdArrayTensor { Self::int_from_data(TensorData::ones::(shape), device) } fn int_full( shape: Shape, - fill_value: i64, + fill_value: I, device: & as Backend>::Device, - ) -> NdArrayTensor { + ) -> NdArrayTensor { Self::int_from_data(TensorData::full(shape, fill_value), device) } - fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor { + fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::sum(tensor) } - fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { NdArrayMathOps::sum_dim(tensor, dim) } - fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor { + fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::prod(tensor) } - fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { NdArrayMathOps::prod_dim(tensor, dim) } - fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor { + fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::mean(tensor) } - fn 
int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { NdArrayMathOps::mean_dim(tensor, dim) } fn int_gather( dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - ) -> NdArrayTensor { + tensor: NdArrayTensor, + indices: NdArrayTensor, + ) -> NdArrayTensor { NdArrayMathOps::gather(dim, tensor, indices) } fn int_scatter( dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { + tensor: NdArrayTensor, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { NdArrayMathOps::scatter(dim, tensor, indices, value) } fn int_select( - tensor: NdArrayTensor, + tensor: NdArrayTensor, dim: usize, - indices: NdArrayTensor, - ) -> NdArrayTensor { + indices: NdArrayTensor, + ) -> NdArrayTensor { NdArrayMathOps::select(tensor, dim, indices) } fn int_select_assign( - tensor: NdArrayTensor, + tensor: NdArrayTensor, dim: usize, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { NdArrayMathOps::select_assign(tensor, dim, indices, value) } - fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { NdArrayMathOps::argmax(tensor, dim) } - fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { NdArrayMathOps::argmin(tensor, dim) } - fn int_clamp_min(tensor: NdArrayTensor, min: i64) -> NdArrayTensor { + fn int_clamp_min(tensor: NdArrayTensor, min: I) -> NdArrayTensor { NdArrayMathOps::clamp_min(tensor, min) } - fn int_clamp_max(tensor: NdArrayTensor, max: i64) -> NdArrayTensor { + fn int_clamp_max(tensor: NdArrayTensor, max: I) -> NdArrayTensor { NdArrayMathOps::clamp_max(tensor, max) } - fn int_clamp(tensor: NdArrayTensor, min: i64, max: i64) -> NdArrayTensor { + fn int_clamp(tensor: NdArrayTensor, min: I, max: I) -> NdArrayTensor { NdArrayMathOps::clamp(tensor, min, max) } - fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor { + fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor { let array = tensor.array.mapv_into(|a| a.int_abs_elem()).into_shared(); NdArrayTensor::new(array) } - fn int_into_float(tensor: NdArrayTensor) -> as Backend>::FloatTensorPrimitive { + fn int_into_float(tensor: NdArrayTensor) -> as Backend>::FloatTensorPrimitive { let array = tensor.array.mapv(|a| a.elem()).into_shared(); NdArrayTensor { array } } - fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor { + fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor { NdArrayOps::swap_dims(tensor, dim1, dim2) } @@ -291,7 +293,7 @@ impl IntTensorOps for NdArray NdArrayTensor { + ) -> NdArrayTensor { let mut seed = SEED.lock().unwrap(); let mut rng = if let Some(rng_seeded) = seed.as_ref() { rng_seeded.clone() @@ -313,32 +315,36 @@ impl IntTensorOps for NdArray, rhs: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::elementwise_op(lhs, rhs, |a: &i64, b: &i64| a.pow(*b as u32)) + fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::().pow(b.elem::())).elem() + }) } - fn int_powf(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - NdArrayMathOps::elementwise_op(lhs, rhs, |a: &i64, b: &E| a.pow(b.elem::())) + fn int_powf(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor 
{ + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &E| { + (a.elem::().pow(b.elem::())).elem() + }) } - fn int_powf_scalar(lhs: NdArrayTensor, rhs: f32) -> NdArrayTensor { - NdArrayMathOps::elementwise_op_scalar(lhs, |a: i64| a.pow(rhs as u32)) + fn int_powf_scalar(lhs: NdArrayTensor, rhs: f32) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| (a.elem::().pow(rhs as u32)).elem()) } - fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { + fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { let array = tensor.array.permuted_axes(axes.into_dimension()); NdArrayTensor { array } } - fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { + fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { NdArrayOps::flip(tensor, axes) } - fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor { + fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::sign_op(tensor) } - fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { + fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { NdArrayOps::expand(tensor, shape) } } diff --git a/crates/burn-ndarray/src/ops/maxpool.rs b/crates/burn-ndarray/src/ops/maxpool.rs index dcd5e006d..09db5fec2 100644 --- a/crates/burn-ndarray/src/ops/maxpool.rs +++ b/crates/burn-ndarray/src/ops/maxpool.rs @@ -1,5 +1,5 @@ use crate::{ - element::{FloatNdArrayElement, QuantElement}, + element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, ops::padding::apply_padding_4d, sharing::UnsafeSharedRef, tensor::NdArrayTensor, @@ -9,7 +9,7 @@ use burn_common::{iter_range_par, run_par}; use burn_tensor::ElementConversion; use ndarray::Array4; -pub(crate) fn max_pool2d( +pub(crate) fn max_pool2d( x: NdArrayTensor, kernel_size: [usize; 2], stride: [usize; 2], @@ -30,7 +30,7 @@ pub(crate) fn max_pool2d( / stride_width) + 1; - let x = apply_padding_4d::(x, padding, inf).array; + let x = apply_padding_4d::(x, padding, inf).array; let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); @@ -69,13 +69,17 @@ pub(crate) fn max_pool2d( NdArrayTensor::new(output.into_dyn().into_shared()) } -pub(crate) fn max_pool2d_with_indices( +pub(crate) fn max_pool2d_with_indices< + E: FloatNdArrayElement, + I: IntNdArrayElement, + Q: QuantElement, +>( x: NdArrayTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], -) -> (NdArrayTensor, NdArrayTensor) { +) -> (NdArrayTensor, NdArrayTensor) { let [kernel_height, kernel_width] = kernel_size; let [padding_height, padding_width] = padding; let [stride_height, stride_width] = stride; @@ -90,10 +94,10 @@ pub(crate) fn max_pool2d_with_indices( / stride_width) + 1; - let x = apply_padding_4d::(x, padding, inf).array; + let x = apply_padding_4d::(x, padding, inf).array; let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); - let mut indices = Array4::::zeros((batch_size, channels, out_height, out_width)); + let mut indices = Array4::::zeros((batch_size, channels, out_height, out_width)); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices); @@ -130,7 +134,7 @@ pub(crate) fn max_pool2d_with_indices( } output[[b, c, oh, ow]] = max_val; - indices[[b, c, oh, ow]] = index; + indices[[b, c, oh, ow]] = index.elem(); } } }) @@ -142,14 +146,14 @@ pub(crate) fn max_pool2d_with_indices( (output, indices) } 
-pub(crate) fn max_pool2d_backward( +pub(crate) fn max_pool2d_backward( x: NdArrayTensor, _kernel_size: [usize; 2], _stride: [usize; 2], _padding: [usize; 2], _dilation: [usize; 2], output_grad: NdArrayTensor, - indices: NdArrayTensor, + indices: NdArrayTensor, ) -> NdArrayTensor { let [_batch_size, _channels, height, width] = output_grad.shape().dims(); let [batch_size, channels, height_x, width_x] = x.shape().dims(); @@ -170,7 +174,7 @@ pub(crate) fn max_pool2d_backward( for h in 0..height { for w in 0..width { - let index = indices[[b, c, h, w]]; + let index = indices[[b, c, h, w]].elem::(); let grad = output_grad[[b, c, h, w]]; let index_h = index as usize / width_x; diff --git a/crates/burn-ndarray/src/ops/module.rs b/crates/burn-ndarray/src/ops/module.rs index c8d351b31..9b2fc63b5 100644 --- a/crates/burn-ndarray/src/ops/module.rs +++ b/crates/burn-ndarray/src/ops/module.rs @@ -7,17 +7,22 @@ use super::{ maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices}, }; use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray}; -use crate::{element::QuantElement, ops::interpolate::nearest_interpolate_backward}; +use crate::{ + element::{IntNdArrayElement, QuantElement}, + ops::interpolate::nearest_interpolate_backward, +}; use burn_tensor::ops::*; -impl ModuleOps for NdArray { +impl ModuleOps + for NdArray +{ fn conv2d( x: NdArrayTensor, weight: NdArrayTensor, bias: Option>, options: ConvOptions<2>, ) -> NdArrayTensor { - conv2d::(x, weight, bias, options) + conv2d::(x, weight, bias, options) } fn deform_conv2d( @@ -80,7 +85,7 @@ impl ModuleOps for NdArray padding: [usize; 2], dilation: [usize; 2], ) -> NdArrayTensor { - max_pool2d::(x, kernel_size, stride, padding, dilation) + max_pool2d::(x, kernel_size, stride, padding, dilation) } fn max_pool2d_with_indices( @@ -89,9 +94,9 @@ impl ModuleOps for NdArray stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], - ) -> MaxPool2dWithIndices> { + ) -> MaxPool2dWithIndices> { let (output, indices) = - max_pool2d_with_indices::(x, kernel_size, stride, padding, dilation); + max_pool2d_with_indices::(x, kernel_size, stride, padding, dilation); MaxPool2dWithIndices::new(output, indices) } @@ -103,8 +108,8 @@ impl ModuleOps for NdArray padding: [usize; 2], dilation: [usize; 2], output_grad: NdArrayTensor, - indices: NdArrayTensor, - ) -> MaxPool2dBackward> { + indices: NdArrayTensor, + ) -> MaxPool2dBackward> { MaxPool2dBackward::new(max_pool2d_backward( x, kernel_size, @@ -162,7 +167,7 @@ impl ModuleOps for NdArray bias: Option>, options: ConvOptions<3>, ) -> NdArrayTensor { - conv3d::(x, weight, bias, options) + conv3d::(x, weight, bias, options) } fn conv_transpose3d( diff --git a/crates/burn-ndarray/src/ops/padding.rs b/crates/burn-ndarray/src/ops/padding.rs index c7af8ad9d..c5879c045 100644 --- a/crates/burn-ndarray/src/ops/padding.rs +++ b/crates/burn-ndarray/src/ops/padding.rs @@ -1,12 +1,12 @@ use crate::{ - element::{FloatNdArrayElement, QuantElement}, + element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, tensor::NdArrayTensor, NdArray, }; use burn_tensor::ops::FloatTensorOps; use ndarray::{Array4, Array5}; -pub(crate) fn apply_padding_4d( +pub(crate) fn apply_padding_4d( x: NdArrayTensor, padding: [usize; 2], elem: E, @@ -22,7 +22,7 @@ pub(crate) fn apply_padding_4d( ); let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn()); - x_new = NdArray::::float_slice_assign( + x_new = NdArray::::float_slice_assign( x_new, &[ 0..batch_size, @@ -36,7 +36,7 @@ pub(crate) fn apply_padding_4d( 
x_new } -pub(crate) fn apply_padding_5d( +pub(crate) fn apply_padding_5d( x: NdArrayTensor, padding: [usize; 3], elem: E, @@ -59,7 +59,7 @@ pub(crate) fn apply_padding_5d( ); let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn()); - x_new = NdArray::::float_slice_assign( + x_new = NdArray::::float_slice_assign( x_new, &[ 0..batch_size, diff --git a/crates/burn-ndarray/src/ops/qtensor.rs b/crates/burn-ndarray/src/ops/qtensor.rs index 91bd58a3d..c597fae87 100644 --- a/crates/burn-ndarray/src/ops/qtensor.rs +++ b/crates/burn-ndarray/src/ops/qtensor.rs @@ -10,7 +10,7 @@ use burn_tensor::{ }; use crate::{ - element::{NdArrayElement, QuantElement}, + element::{IntNdArrayElement, NdArrayElement, QuantElement}, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, }; @@ -22,7 +22,9 @@ fn into_data(tensor: NdArrayTensor) -> TensorData { TensorData::new(values, shape) } -impl QTensorOps for NdArray { +impl QTensorOps + for NdArray +{ fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor { match data.dtype { DType::QFloat(strategy) => match strategy { diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index 7f7ef43fd..49d038d7b 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -5,7 +5,7 @@ use ndarray::Zip; // Current crate use super::{matmul::matmul, NdArrayMathOps, NdArrayOps}; -use crate::element::{FloatNdArrayElement, QuantElement}; +use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}; use crate::{tensor::NdArrayTensor, NdArray}; use crate::{NdArrayDevice, SEED}; @@ -20,7 +20,9 @@ use num_traits::Float; use libm::erf; -impl FloatTensorOps for NdArray { +impl FloatTensorOps + for NdArray +{ fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor { NdArrayTensor::from_data(data) } @@ -125,7 +127,7 @@ impl FloatTensorOps for NdArray, - indices: NdArrayTensor, + indices: NdArrayTensor, ) -> NdArrayTensor { NdArrayMathOps::gather(dim, tensor, indices) } @@ -133,7 +135,7 @@ impl FloatTensorOps for NdArray, - indices: NdArrayTensor, + indices: NdArrayTensor, value: NdArrayTensor, ) -> NdArrayTensor { NdArrayMathOps::scatter(dim, tensor, indices, value) @@ -142,7 +144,7 @@ impl FloatTensorOps for NdArray, dim: usize, - indices: NdArrayTensor, + indices: NdArrayTensor, ) -> NdArrayTensor { NdArrayMathOps::select(tensor, dim, indices) } @@ -150,7 +152,7 @@ impl FloatTensorOps for NdArray, dim: usize, - indices: NdArrayTensor, + indices: NdArrayTensor, value: NdArrayTensor, ) -> NdArrayTensor { NdArrayMathOps::select_assign(tensor, dim, indices, value) @@ -266,11 +268,11 @@ impl FloatTensorOps for NdArray, dim: usize) -> NdArrayTensor { + fn float_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { NdArrayMathOps::argmax(tensor, dim) } - fn float_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + fn float_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { NdArrayMathOps::argmin(tensor, dim) } @@ -374,7 +376,7 @@ impl FloatTensorOps for NdArray) -> as Backend>::IntTensorPrimitive { + fn float_into_int(tensor: NdArrayTensor) -> NdArrayTensor { let array = tensor.array.mapv(|a| a.elem()).into_shared(); NdArrayTensor { array } } diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index e59d73529..22141d14d 100644 --- a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -212,7 +212,7 @@ mod tests { let scheme = 
QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); let qparams = QuantizationParametersPrimitive { scale: NdArrayTensor::from_data(TensorData::from([0.009_019_608])), - offset: Some(NdArrayTensor::from_data(TensorData::from([72]))), + offset: Some(NdArrayTensor::::from_data(TensorData::from([72]))), }; let qtensor: NdArrayQTensor = NdArray::quantize(tensor, &scheme, qparams); diff --git a/crates/burn-router/Cargo.toml b/crates/burn-router/Cargo.toml new file mode 100644 index 000000000..aa93faf1f --- /dev/null +++ b/crates/burn-router/Cargo.toml @@ -0,0 +1,39 @@ +[package] +authors = ["guillaumelagrange ", "nathanielsimard "] +categories = ["science"] +description = "Multi-backend router decorator for the Burn framework" +edition.workspace = true +keywords = ["deep-learning", "machine-learning", "data"] +license.workspace = true +name = "burn-router" +readme.workspace = true +repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router" +documentation = "https://docs.rs/burn-router" +version.workspace = true + +[features] +default = ["std"] +std = [] +doc = ["default"] + +[dependencies] +burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false, features = ["repr"]} +hashbrown = { workspace = true } +spin = { workspace = true } + + +[dev-dependencies] +burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", default-features = false, features = [ + "export_tests", +] } +burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false, features = [ + "export_tests", +] } + +burn-ndarray = { path = "../burn-ndarray", version = "0.15.0" } +burn-wgpu = { path = "../burn-wgpu", version = "0.15.0" } + + +[package.metadata.docs.rs] +features = ["doc"] +rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/burn-router/README.md b/crates/burn-router/README.md new file mode 100644 index 000000000..be3e69a83 --- /dev/null +++ b/crates/burn-router/README.md @@ -0,0 +1,3 @@ +# Burn Router + +A multi-backend extension that forwards the tensor operations to the appropriate backend. diff --git a/crates/burn-router/src/backend.rs b/crates/burn-router/src/backend.rs new file mode 100644 index 000000000..03f555edc --- /dev/null +++ b/crates/burn-router/src/backend.rs @@ -0,0 +1,125 @@ +use alloc::{format, string::String}; +use core::marker::PhantomData; + +use burn_tensor::{ + backend::{Backend, BackendBridge}, + ops::FloatTensor, + quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy}, + repr::{BaseOperationDescription, OperationDescription, UnaryOperationDescription}, + Device, +}; + +use super::{get_client, set_seed, RouterTensor, RunnerChannel, RunnerClient}; + +/// A backend that forwards the tensor operations to the appropriate backend (given multiple backends).
+pub struct BackendRouter { + r: PhantomData, +} + +impl core::fmt::Debug for BackendRouter { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("router")) + } +} + +impl Clone for BackendRouter { + fn clone(&self) -> Self { + Self { r: PhantomData } + } +} + +impl Default for BackendRouter { + fn default() -> Self { + Self { r: PhantomData } + } +} + +// TODO: quantization tensor primitive (w/ qparams) +impl QTensorPrimitive for RouterTensor { + fn scheme(&self) -> &QuantizationScheme { + todo!() + } + + fn strategy(&self) -> QuantizationStrategy { + todo!() + } +} + +impl Backend for BackendRouter { + type Device = R::Device; + + type FullPrecisionBridge = PrecisionBridge; + + type FloatTensorPrimitive = RouterTensor; + + type FloatElem = R::FloatElem; + + type IntTensorPrimitive = RouterTensor; + + type IntElem = R::IntElem; + + type BoolTensorPrimitive = RouterTensor; + + type QuantizedTensorPrimitive = RouterTensor; + + type QuantizedEncoding = u32; + + fn name() -> String { + format!("router<{}>", R::name()) + } + + fn seed(seed: u64) { + set_seed(seed) + } + + fn sync(device: &Self::Device) { + let client = get_client::(device); + client.sync(); + } +} + +/// Handle precision conversion. +#[derive(Debug)] +pub struct PrecisionBridge {} + +impl BackendBridge> for PrecisionBridge { + type Target = BackendRouter; + + fn into_target( + tensor: FloatTensor>, + _device: Option>, + ) -> FloatTensor { + let client = tensor.client.clone(); + let out = client.register_float_tensor(tensor.shape.clone(), true); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Cast(desc), + )); + + out + } + + fn from_target( + tensor: FloatTensor, + _device: Option>>, + ) -> FloatTensor> { + let client = tensor.client.clone(); + let out = client.register_float_tensor(tensor.shape.clone(), false); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Cast(desc), + )); + + out + } +} diff --git a/crates/burn-router/src/bridge/base.rs b/crates/burn-router/src/bridge/base.rs new file mode 100644 index 000000000..d5f338c76 --- /dev/null +++ b/crates/burn-router/src/bridge/base.rs @@ -0,0 +1,32 @@ +use burn_tensor::{backend::DeviceOps, Shape}; + +/// Allows tensors to be transferred between multiple backends. +pub trait MultiBackendBridge: Send + Sync + 'static { + /// The type that can be used to point to a tensor of any kind. + type TensorHandle; + /// Device type used by the backends. + type Device: DeviceOps; + + /// Change the backend of the given float tensor. + fn change_backend_float( + tensor: Self::TensorHandle, + shape: Shape, + target_device: &Self::Device, + ) -> Self::TensorHandle; + + /// Change the backend of the given int tensor. + fn change_backend_int( + tensor: Self::TensorHandle, + shape: Shape, + target_device: &Self::Device, + ) -> Self::TensorHandle; + + /// Change the backend of the given bool tensor. 
+ fn change_backend_bool( + tensor: Self::TensorHandle, + shape: Shape, + target_device: &Self::Device, + ) -> Self::TensorHandle; + + // TODO: change_backend_quantized +} diff --git a/crates/burn-router/src/bridge/byte.rs b/crates/burn-router/src/bridge/byte.rs new file mode 100644 index 000000000..94e3432a6 --- /dev/null +++ b/crates/burn-router/src/bridge/byte.rs @@ -0,0 +1,171 @@ +use core::marker::PhantomData; + +use burn_tensor::{ + repr::{ReprBackend, TensorHandle}, + try_read_sync, Shape, +}; + +use super::base::MultiBackendBridge; +use crate::{MultiDevice2, TensorHandle2}; + +/// Simply transfers tensors between backends via the underlying [tensor data](burn_tensor::TensorData). +pub struct ByteBridge { + backends: PhantomData, +} + +impl MultiBackendBridge for ByteBridge<(B1, B2)> { + type TensorHandle = TensorHandle2; + type Device = MultiDevice2; + + fn change_backend_float( + tensor: Self::TensorHandle, + shape: Shape, + target_device: &Self::Device, + ) -> Self::TensorHandle { + let msg = "Failed to read tensor data synchronously. +This can happen on platforms that don't support blocking futures like WASM."; + match tensor { + TensorHandle2::Handle1(handle) => match target_device { + MultiDevice2::Device1(device) => { + // Same backend + let tensor = B1::float_tensor(TensorHandle { handle, shape }); + let tensor = B1::float_to_device(tensor, device); + let handle = B1::float_tensor_handle(tensor); + TensorHandle2::Handle1(handle) + } + MultiDevice2::Device2(device) => { + let tensor = B1::float_tensor(TensorHandle { handle, shape }); + let data = try_read_sync(B1::float_into_data(tensor)).expect(msg); + let tensor = B2::float_from_data(data, device); + let handle = B2::float_tensor_handle(tensor); + TensorHandle2::Handle2(handle) + } + }, + TensorHandle2::Handle2(handle) => match target_device { + MultiDevice2::Device1(device) => { + let tensor = B2::float_tensor(TensorHandle { handle, shape }); + let data = try_read_sync(B2::float_into_data(tensor)).expect(msg); + let tensor = B1::float_from_data(data, device); + let handle = B1::float_tensor_handle(tensor); + TensorHandle2::Handle1(handle) + } + MultiDevice2::Device2(device) => { + // Same backend + let tensor = B2::float_tensor(TensorHandle { handle, shape }); + let tensor = B2::float_to_device(tensor, device); + let handle = B2::float_tensor_handle(tensor); + TensorHandle2::Handle2(handle) + } + }, + } + } + + fn change_backend_int( + tensor: Self::TensorHandle, + shape: Shape, + target_device: &Self::Device, + ) -> Self::TensorHandle { + let msg = "Failed to read tensor data synchronously. 
+This can happen on platforms that don't support blocking futures like WASM."; + match tensor { + TensorHandle2::Handle1(handle) => match target_device { + MultiDevice2::Device1(device) => { + // Same backend + let tensor = B1::int_tensor(TensorHandle { handle, shape }); + let tensor = B1::int_to_device(tensor, device); + let handle = B1::int_tensor_handle(tensor); + TensorHandle2::Handle1(handle) + } + MultiDevice2::Device2(device) => { + let tensor = B1::int_tensor(TensorHandle { handle, shape }); + let data = try_read_sync(B1::int_into_data(tensor)).expect(msg); + let tensor = B2::int_from_data(data, device); + let handle = B2::int_tensor_handle(tensor); + TensorHandle2::Handle2(handle) + } + }, + TensorHandle2::Handle2(handle) => match target_device { + MultiDevice2::Device1(device) => { + let tensor = B2::int_tensor(TensorHandle { handle, shape }); + let data = try_read_sync(B2::int_into_data(tensor)).expect(msg); + let tensor = B1::int_from_data(data, device); + let handle = B1::int_tensor_handle(tensor); + TensorHandle2::Handle1(handle) + } + MultiDevice2::Device2(device) => { + // Same backend + let tensor = B2::int_tensor(TensorHandle { handle, shape }); + let tensor = B2::int_to_device(tensor, device); + let handle = B2::int_tensor_handle(tensor); + TensorHandle2::Handle2(handle) + } + }, + } + } + + fn change_backend_bool( + tensor: Self::TensorHandle, + shape: Shape, + target_device: &Self::Device, + ) -> Self::TensorHandle { + let msg = "Failed to read tensor data synchronously. + This can happen on platforms that don't support blocking futures like WASM."; + match tensor { + TensorHandle2::Handle1(handle) => match target_device { + MultiDevice2::Device1(device) => { + // Same backend + let tensor = B1::bool_tensor(TensorHandle { handle, shape }); + let tensor = B1::bool_to_device(tensor, device); + let handle = B1::bool_tensor_handle(tensor); + TensorHandle2::Handle1(handle) + } + MultiDevice2::Device2(device) => { + let tensor = B1::bool_tensor(TensorHandle { handle, shape }); + let data = try_read_sync(B1::bool_into_data(tensor)).expect(msg); + let tensor = B2::bool_from_data(data, device); + let handle = B2::bool_tensor_handle(tensor); + TensorHandle2::Handle2(handle) + } + }, + TensorHandle2::Handle2(handle) => match target_device { + MultiDevice2::Device1(device) => { + let tensor = B2::bool_tensor(TensorHandle { handle, shape }); + let data = try_read_sync(B2::bool_into_data(tensor)).expect(msg); + let tensor = B1::bool_from_data(data, device); + let handle = B1::bool_tensor_handle(tensor); + TensorHandle2::Handle1(handle) + } + MultiDevice2::Device2(device) => { + // Same backend + let tensor = B2::bool_tensor(TensorHandle { handle, shape }); + let tensor = B2::bool_to_device(tensor, device); + let handle = B2::bool_tensor_handle(tensor); + TensorHandle2::Handle2(handle) + } + }, + } + } +} + +#[cfg(not(target_os = "windows"))] +#[cfg(test)] +mod tests { + use burn_tensor::{backend::Backend, Tensor}; + + use super::*; + use crate::tests::{TestBackend, TestBackend1, TestBackend2}; + + #[test] + fn should_support_dual_byte_bridge() { + let device1 = MultiDevice2::Device1(::Device::default()); + let device2 = MultiDevice2::Device2(::Device::default()); + let tensor1 = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0], &device1); + let tensor2 = Tensor::::from_floats([5.0, 6.0, 7.0, 8.0], &device2); + + let tensor1_2 = tensor1.clone().to_device(&device2); + tensor1.into_data().assert_eq(&tensor1_2.into_data(), true); + + let tensor2_1 = tensor2.clone().to_device(&device1); + 
tensor2.into_data().assert_eq(&tensor2_1.into_data(), true); + } +} diff --git a/crates/burn-router/src/bridge/mod.rs b/crates/burn-router/src/bridge/mod.rs new file mode 100644 index 000000000..d43da0ed0 --- /dev/null +++ b/crates/burn-router/src/bridge/mod.rs @@ -0,0 +1,5 @@ +mod base; +mod byte; + +pub use base::*; +pub use byte::*; diff --git a/crates/burn-router/src/channel/base.rs b/crates/burn-router/src/channel/base.rs new file mode 100644 index 000000000..876d273f6 --- /dev/null +++ b/crates/burn-router/src/channel/base.rs @@ -0,0 +1,68 @@ +use alloc::{string::String, vec::Vec}; +use burn_tensor::{backend::DeviceOps, repr::TensorDescription, DType, Element}; + +use crate::{get_client, MultiBackendBridge, RouterTensor, RunnerClient}; + +/// Type alias for `
<Br as MultiBackendBridge>::TensorHandle`. +pub type TensorHandle<Br> = <Br as MultiBackendBridge>
::TensorHandle; + +/// Defines the connection channel and operations for a setup with multiple backend runner clients. +pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized { + /// Device type. + type Device: DeviceOps; + /// A bridge that can transfer tensors between multiple backends. + type Bridge: MultiBackendBridge; + /// Client type. + type Client: RunnerClient; + /// Float element type. + type FloatElem: Element; + /// Int element type. + type IntElem: Element; + + /// Name of the channel. + fn name() -> String; + + /// Initialize a new client for the given device. + fn init_client(device: &Self::Device) -> Self::Client; + + /// Get the tensor handle corresponding to the [tensor description](TensorDescription). + fn get_tensor_handle( + tensor: &TensorDescription, + client: &Self::Client, + ) -> TensorHandle; + + // TODO: get quantized tensor handle from QuantizedTensorDescription + + /// Create a tensor with the given handle and shape. + fn register_tensor( + client: &Self::Client, + handle: TensorHandle, + shape: Vec, + dtype: DType, + ) -> RouterTensor; + + /// Change the tensor to a different client backend. + fn change_client_backend( + tensor: RouterTensor, + device: &Self::Device, // target device + ) -> RouterTensor { + // Get tensor handle from current client + let original_client = tensor.client.clone(); + let desc = tensor.into_description(); + let mut handle = Self::get_tensor_handle(&desc, &original_client); + + if desc.dtype.is_float() { + handle = Self::Bridge::change_backend_float(handle, desc.shape.clone().into(), device); + } else if desc.dtype.is_int() { + handle = Self::Bridge::change_backend_int(handle, desc.shape.clone().into(), device); + } else if desc.dtype.is_bool() { + handle = Self::Bridge::change_backend_bool(handle, desc.shape.clone().into(), device); + } else { + unimplemented!() + } + + // Register tensor handle on target client + let target_client = get_client::(device); + Self::register_tensor(&target_client, handle, desc.shape, desc.dtype) + } +} diff --git a/crates/burn-router/src/channel/direct.rs b/crates/burn-router/src/channel/direct.rs new file mode 100644 index 000000000..1aa7c3bfd --- /dev/null +++ b/crates/burn-router/src/channel/direct.rs @@ -0,0 +1,271 @@ +use alloc::{format, string::String, sync::Arc, vec::Vec}; +use core::marker::PhantomData; + +use burn_tensor::{ + backend::{Backend, BackendBridge, DeviceId, DeviceOps}, + repr::{OperationDescription, ReprBackend, TensorDescription, TensorId}, + DType, TensorData, +}; + +use super::{RunnerChannel, TensorHandle}; +use crate::{MultiBackendBridge, RouterTensor, Runner, RunnerClient}; + +/// A local channel with direct connection to the backend runner clients. 
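As a rough illustration of the `change_client_backend` flow defined above: the tensor's handle is taken from its current client, the bridge converts it according to the tensor kind, and the result is registered with the target device's client. The sketch below uses toy types only; nothing here is the burn-router API.

```rust
// Toy stand-ins; not part of this patch.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug)]
enum DType { F32, I32, Bool }

#[derive(Debug)]
struct Handle { device: &'static str, data: Vec<u8> }

// A byte-bridge-style move: read the data back on the source, re-upload on the target.
fn change_backend_float(h: Handle, target: &'static str) -> Handle {
    Handle { device: target, data: h.data }
}

fn change_client_backend(handle: Handle, dtype: DType, target: &'static str) -> Handle {
    match dtype {
        DType::F32 => change_backend_float(handle, target),
        // Int and bool tensors go through their own bridge methods in the real trait;
        // quantized tensors are not supported yet (see the TODO above).
        _ => unimplemented!(),
    }
}

fn main() {
    let h = Handle { device: "device-1", data: vec![0, 0, 128, 63] };
    let moved = change_client_backend(h, DType::F32, "device-2");
    assert_eq!(moved.device, "device-2");
}
```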
+pub struct DirectChannel { + backends: PhantomData, + bridge: PhantomData, +} + +impl Clone for DirectChannel { + fn clone(&self) -> Self { + Self { + backends: self.backends, + bridge: self.bridge, + } + } +} + +impl RunnerChannel for DirectChannel<(B1, B2), Br> +where + B1: ReprBackend, + B2: ReprBackend, + Br: MultiBackendBridge, Device = MultiDevice2>, + // Restrict full precision backend handle to be the same + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, +{ + type Device = Br::Device; + + type Bridge = Br; + + type FloatElem = B1::FloatElem; + type IntElem = B1::IntElem; + + type Client = MultiRunnerClient2; + + fn init_client(device: &Self::Device) -> Self::Client { + match device { + MultiDevice2::Device1(device) => { + MultiRunnerClient2::RunnerClient1(Runner::new(device.clone())) + } + MultiDevice2::Device2(device) => { + MultiRunnerClient2::RunnerClient2(Runner::new(device.clone())) + } + } + } + + fn get_tensor_handle( + tensor: &TensorDescription, + client: &Self::Client, + ) -> TensorHandle { + match client { + MultiRunnerClient2::RunnerClient1(runner) => { + TensorHandle2::Handle1(runner.get_tensor_handle(tensor)) + } + MultiRunnerClient2::RunnerClient2(runner) => { + TensorHandle2::Handle2(runner.get_tensor_handle(tensor)) + } + } + } + + fn register_tensor( + client: &Self::Client, + handle: TensorHandle, + shape: Vec, + dtype: DType, + ) -> RouterTensor { + match client { + MultiRunnerClient2::RunnerClient1(runner) => match handle { + TensorHandle2::Handle1(handle) => { + runner.register_tensor(handle, shape, dtype, client.clone()) + } + TensorHandle2::Handle2(_) => { + unreachable!("Can't register tensor handle for another backend.") + } + }, + MultiRunnerClient2::RunnerClient2(runner) => match handle { + TensorHandle2::Handle1(_) => { + unreachable!("Can't register tensor handle for another backend.") + } + TensorHandle2::Handle2(handle) => { + runner.register_tensor(handle, shape, dtype, client.clone()) + } + }, + } + } + + fn name() -> String { + format!("direct<({}, {})>", B1::name(), B2::name()) + } +} + +// TODO: generate this for different number of backends (up to 4?) + +/// Handle type to interact with two backends. +pub enum TensorHandle2 { + /// Handle for the first backend. + Handle1(B1::Handle), + /// Handle for the second backend. + Handle2(B2::Handle), +} + +/// Device type to interact with two backends. +#[derive(Clone, Debug)] +pub enum MultiDevice2 { + /// Device for the first backend. + Device1(B1::Device), + /// Device for the second backend. + Device2(B2::Device), +} + +impl PartialEq for MultiDevice2 { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Device1(lhs), Self::Device1(rhs)) => lhs == rhs, + (Self::Device2(lhs), Self::Device2(rhs)) => lhs == rhs, + _ => false, + } + } +} + +impl Eq for MultiDevice2 {} + +impl Default for MultiDevice2 { + fn default() -> Self { + Self::Device1(B1::Device::default()) + } +} + +impl DeviceOps for MultiDevice2 { + fn id(&self) -> DeviceId { + match self { + MultiDevice2::Device1(device) => device.id(), + MultiDevice2::Device2(device) => device.id(), + } + } +} + +/// Local [`RunnerClient`] with two backends. +#[derive(Clone)] +pub enum MultiRunnerClient2 { + /// Client for the first backend runner. + RunnerClient1(Runner), + /// Client for the second backend runner. 
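Aside: `MultiRunnerClient2` is plain enum dispatch, one variant per backend runner, with every trait method matched through to the wrapped client. A self-contained sketch of the pattern with hypothetical client types:

```rust
// Hypothetical types; only the dispatch pattern matters here.
trait RunnerSync {
    fn sync(&self);
}

enum EitherClient<A, B> {
    First(A),
    Second(B),
}

impl<A: RunnerSync, B: RunnerSync> RunnerSync for EitherClient<A, B> {
    fn sync(&self) {
        match self {
            EitherClient::First(client) => client.sync(),
            EitherClient::Second(client) => client.sync(),
        }
    }
}

struct CpuClient;
struct GpuClient;
impl RunnerSync for CpuClient { fn sync(&self) { println!("cpu: sync"); } }
impl RunnerSync for GpuClient { fn sync(&self) { println!("gpu: sync"); } }

fn main() {
    let clients = [
        EitherClient::<CpuClient, GpuClient>::First(CpuClient),
        EitherClient::Second(GpuClient),
    ];
    for client in &clients {
        client.sync();
    }
}
```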
+ RunnerClient2(Runner), +} + +impl RunnerClient for MultiRunnerClient2 +where + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, +{ + type Device = MultiDevice2; + + fn register(&self, op: OperationDescription) { + match self { + MultiRunnerClient2::RunnerClient1(runner) => runner.register(op), + MultiRunnerClient2::RunnerClient2(runner) => runner.register(op), + } + } + + async fn read_tensor(&self, tensor: TensorDescription) -> TensorData { + match self { + MultiRunnerClient2::RunnerClient1(runner) => runner.read_tensor(tensor).await, + MultiRunnerClient2::RunnerClient2(runner) => runner.read_tensor(tensor).await, + } + } + + fn register_tensor_data(&self, data: TensorData) -> RouterTensor { + match self { + MultiRunnerClient2::RunnerClient1(runner) => { + let desc = runner.register_tensor_data_desc(data); + RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) + } + MultiRunnerClient2::RunnerClient2(runner) => { + let desc = runner.register_tensor_data_desc(data); + RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) + } + } + } + + fn register_empty_tensor(&self, shape: Vec, dtype: DType) -> RouterTensor { + match self { + MultiRunnerClient2::RunnerClient1(runner) => { + let desc = runner.register_empty_tensor_desc(shape, dtype); + RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) + } + MultiRunnerClient2::RunnerClient2(runner) => { + let desc = runner.register_empty_tensor_desc(shape, dtype); + RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) + } + } + } + + fn register_float_tensor(&self, shape: Vec, full_precision: bool) -> RouterTensor { + match self { + MultiRunnerClient2::RunnerClient1(runner) => { + let desc = runner.register_float_tensor_desc(shape, full_precision); + RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) + } + MultiRunnerClient2::RunnerClient2(runner) => { + let desc = runner.register_float_tensor_desc(shape, full_precision); + RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) + } + } + } + + fn device(&self) -> Self::Device { + match self { + MultiRunnerClient2::RunnerClient1(runner) => MultiDevice2::Device1(runner.device()), + MultiRunnerClient2::RunnerClient2(runner) => MultiDevice2::Device2(runner.device()), + } + } + + fn register_orphan(&self, id: &TensorId) { + match self { + MultiRunnerClient2::RunnerClient1(runner) => runner.register_orphan(id), + MultiRunnerClient2::RunnerClient2(runner) => runner.register_orphan(id), + } + } + + fn sync(&self) { + match self { + MultiRunnerClient2::RunnerClient1(runner) => runner.sync(), + MultiRunnerClient2::RunnerClient2(runner) => runner.sync(), + } + } + + fn seed(&self, seed: u64) { + match self { + MultiRunnerClient2::RunnerClient1(runner) => runner.seed(seed), + MultiRunnerClient2::RunnerClient2(runner) => runner.seed(seed), + } + } +} + +// NOTE: conflicting implementations because B1 and B2 cannot be differentiated (could be the same type) +// impl From>> +// for RouterTensor> +// { +// fn from(value: RouterTensor>) -> Self { +// RouterTensor { +// desc: value.desc, +// client: MultiRunnerClient2::RunnerClient1(value.client), +// } +// } +// } + +// impl From>> +// for RouterTensor> +// { +// fn from(value: RouterTensor>) -> Self { +// RouterTensor { +// desc: value.desc, +// client: MultiRunnerClient2::RunnerClient2(value.client), +// } +// } +// } diff --git 
a/crates/burn-router/src/channel/mod.rs b/crates/burn-router/src/channel/mod.rs new file mode 100644 index 000000000..5617df40f --- /dev/null +++ b/crates/burn-router/src/channel/mod.rs @@ -0,0 +1,5 @@ +mod base; +mod direct; + +pub use base::*; +pub use direct::*; diff --git a/crates/burn-router/src/client/base.rs b/crates/burn-router/src/client/base.rs new file mode 100644 index 000000000..47781996f --- /dev/null +++ b/crates/burn-router/src/client/base.rs @@ -0,0 +1,138 @@ +use alloc::{boxed::Box, vec::Vec}; +use core::{ + future::Future, + ops::DerefMut, + sync::atomic::{AtomicBool, AtomicU64, Ordering}, +}; +use hashbrown::HashMap; + +use spin::Mutex; + +use burn_tensor::{ + backend::{DeviceId, DeviceOps}, + repr::{OperationDescription, TensorDescription, TensorId}, + DType, TensorData, +}; + +use crate::{RouterTensor, RunnerChannel}; + +/// Type alias for `::Client`. +pub type Client = ::Client; +pub(crate) static CLIENTS: RunnerClientLocator = RunnerClientLocator::new(); +static SEED_SET: AtomicBool = AtomicBool::new(false); +static SEED: AtomicU64 = AtomicU64::new(0); +type Key = (core::any::TypeId, DeviceId); + +/// Define how to interact with the runner. +pub trait RunnerClient: Clone + Send + Sync + Sized { + /// Device type. + type Device: DeviceOps; + + /// Register a new tensor operation to be executed by the (runner) server. + fn register(&self, op: OperationDescription); + /// Read the values contained by a tensor. + fn read_tensor(&self, tensor: TensorDescription) -> impl Future + Send; + /// Create a new [RouterTensor] from the tensor data. + fn register_tensor_data(&self, data: TensorData) -> RouterTensor; + /// Create a new [RouterTensor] with no resources associated. + fn register_empty_tensor(&self, shape: Vec, dtype: DType) -> RouterTensor; + /// Create a new float [RouterTensor] with no resources associated. + fn register_float_tensor(&self, shape: Vec, full_precision: bool) -> RouterTensor; + /// Get the current device used by all operations handled by this client. + fn device(&self) -> Self::Device; + /// Drop the tensor with the given [tensor id](TensorId). + fn register_orphan(&self, id: &TensorId); + /// Sync the runner, ensure that all computations are finished. + fn sync(&self); + /// Seed the runner. + fn seed(&self, seed: u64); +} + +pub(crate) struct RunnerClientLocator { + clients: Mutex>>>, +} + +pub(crate) fn get_client(device: &R::Device) -> Client { + CLIENTS.client::(device) +} + +pub(crate) fn set_seed(seed: u64) { + SEED_SET.store(true, Ordering::Relaxed); + SEED.store(seed, Ordering::Relaxed); +} + +fn get_seed() -> Option { + if SEED_SET.load(Ordering::Relaxed) { + Some(SEED.load(Ordering::Relaxed)) + } else { + None + } +} + +/// Initialize a new client for the given device. +/// +/// If a (global) seed was previously set, the client seed is set. +fn new_client(device: &R::Device) -> Client { + let client = R::init_client(device); + if let Some(seed) = get_seed() { + client.seed(seed) + } + client +} + +impl RunnerClientLocator { + /// Create a new client locator. + pub const fn new() -> Self { + Self { + clients: Mutex::new(None), + } + } + + /// Get the runner client for the given device. + /// + /// If a client isn't already initialized, it is created. 
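The locator caches one client per (channel type, device id) key and creates it lazily. A self-contained sketch of the same pattern using std types; the real code uses `spin::Mutex` and `hashbrown` to stay no_std-friendly, and the helper name here is hypothetical.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Mutex;

// (channel type, device id) -> type-erased client
type Key = (TypeId, u32);

static CLIENTS: Mutex<Option<HashMap<Key, Box<dyn Any + Send>>>> = Mutex::new(None);

fn client_for<C: Clone + Send + 'static>(device_id: u32, init: impl FnOnce() -> C) -> C {
    let mut guard = CLIENTS.lock().unwrap();
    let map = guard.get_or_insert_with(HashMap::new);
    let key = (TypeId::of::<C>(), device_id);

    if let Some(existing) = map.get(&key) {
        // Already initialized: hand back a clone of the cached client.
        return existing.downcast_ref::<C>().unwrap().clone();
    }

    let client = init();
    map.insert(key, Box::new(client.clone()));
    client
}

fn main() {
    let a: String = client_for(0, || "client-0".to_string());
    // Second lookup hits the cache; `init` is never called.
    let b: String = client_for(0, || unreachable!());
    assert_eq!(a, b);
}
```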
+ pub fn client(&self, device: &R::Device) -> Client { + let device_id = device.id(); + let client_id = (core::any::TypeId::of::(), device_id); + let mut clients = self.clients.lock(); + + if clients.is_none() { + let client = new_client::(device); + Self::register_inner::(client_id, client, &mut clients); + } + + match clients.deref_mut() { + Some(clients) => match clients.get(&client_id) { + Some(client) => { + let client: &Client = client.downcast_ref().unwrap(); + client.clone() + } + None => { + let client = new_client::(device); + let any = Box::new(client.clone()); + clients.insert(client_id, any); + client + } + }, + _ => unreachable!(), + } + } + + fn register_inner( + key: Key, + client: Client, + clients: &mut Option>>, + ) { + if clients.is_none() { + *clients = Some(HashMap::new()); + } + + if let Some(clients) = clients { + if clients.contains_key(&key) { + panic!("Client already created for device {:?}", key); + } + + clients.insert(key, Box::new(client)); + } + } +} diff --git a/crates/burn-router/src/client/mod.rs b/crates/burn-router/src/client/mod.rs new file mode 100644 index 000000000..cbcb6ac7e --- /dev/null +++ b/crates/burn-router/src/client/mod.rs @@ -0,0 +1,3 @@ +mod base; + +pub use base::*; diff --git a/crates/burn-router/src/lib.rs b/crates/burn-router/src/lib.rs new file mode 100644 index 000000000..279278685 --- /dev/null +++ b/crates/burn-router/src/lib.rs @@ -0,0 +1,50 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![warn(missing_docs)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] + +//! Burn multi-backend router. + +mod backend; +mod bridge; +mod channel; +mod client; +mod ops; +mod runner; +mod tensor; + +pub use backend::*; +pub use bridge::*; +pub use channel::*; +pub use client::*; +pub use runner::*; +pub use tensor::*; + +extern crate alloc; + +#[cfg(test)] +mod tests { + use alloc::format; + use alloc::vec; + + use crate::BackendRouter; + use crate::ByteBridge; + use crate::DirectChannel; + + type DirectByteChannel = DirectChannel>; + + pub type TestBackend1 = burn_ndarray::NdArray; + pub type TestBackend2 = burn_wgpu::Wgpu; + pub type TestBackend = BackendRouter>; + + pub type TestTensor = burn_tensor::Tensor; + pub type TestTensorInt = burn_tensor::Tensor; + pub type TestTensorBool = + burn_tensor::Tensor; + + burn_tensor::testgen_all!(); + // TODO: add support for quantization + // burn_tensor::testgen_quantization!(); + + #[cfg(feature = "std")] + burn_autodiff::testgen_all!(); +} diff --git a/crates/burn-router/src/ops/binary.rs b/crates/burn-router/src/ops/binary.rs new file mode 100644 index 000000000..534c978bd --- /dev/null +++ b/crates/burn-router/src/ops/binary.rs @@ -0,0 +1,55 @@ +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! binary_float_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_float_tensor::(&$desc.lhs); + let rhs = $handles.get_float_tensor::(&$desc.rhs); + let output = $ops(lhs, rhs); + + $handles.register_float_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! binary_float_cmp_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_float_tensor::(&$desc.lhs); + let rhs = $handles.get_float_tensor::(&$desc.rhs); + let output = $ops(lhs, rhs); + + $handles.register_bool_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! 
binary_int_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_int_tensor::(&$desc.lhs); + let rhs = $handles.get_int_tensor::(&$desc.rhs); + let output = $ops(lhs, rhs); + + $handles.register_int_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! binary_int_cmp_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_int_tensor::(&$desc.lhs); + let rhs = $handles.get_int_tensor::(&$desc.rhs); + let output = $ops(lhs, rhs); + + $handles.register_bool_tensor::(&$desc.out.id, output); + }}; +} diff --git a/crates/burn-router/src/ops/mod.rs b/crates/burn-router/src/ops/mod.rs new file mode 100644 index 000000000..2cfc73b85 --- /dev/null +++ b/crates/burn-router/src/ops/mod.rs @@ -0,0 +1,8 @@ +mod binary; +mod op_activation; +mod op_bool; +mod op_float; +mod op_int; +mod op_module; +mod op_qfloat; +mod unary; diff --git a/crates/burn-router/src/ops/op_activation.rs b/crates/burn-router/src/ops/op_activation.rs new file mode 100644 index 000000000..09c99aa14 --- /dev/null +++ b/crates/burn-router/src/ops/op_activation.rs @@ -0,0 +1,4 @@ +use crate::{BackendRouter, RunnerChannel}; +use burn_tensor::ops::ActivationOps; + +impl ActivationOps for BackendRouter {} diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs new file mode 100644 index 000000000..e58263130 --- /dev/null +++ b/crates/burn-router/src/ops/op_bool.rs @@ -0,0 +1,302 @@ +use alloc::vec::Vec; + +use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntElem, IntTensor}; +use burn_tensor::repr::{ + BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, + CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, + OperationDescription, PermuteOperationDescription, RepeatDimOperationDescription, + ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, + SwapDimsDescription, UnaryOperationDescription, +}; +use burn_tensor::{DType, Device, Element, Shape, TensorData}; + +use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient}; + +impl BoolTensorOps for BackendRouter { + fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + // Get the runtime client on which to register the operation for execution. 
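Every op in the modules that follow uses the same deferred pattern visible in `bool_empty` here: allocate an output descriptor, build an operation description, and register it with the client; nothing executes until the runner drains the stream. A toy sketch of that registration flow (all types are illustrative):

```rust
// Illustrative types only; not the burn-router API.
#[derive(Debug, Clone)]
struct TensorDesc { id: usize, shape: Vec<usize> }

#[derive(Debug)]
enum OpDesc {
    Add { lhs: TensorDesc, rhs: TensorDesc, out: TensorDesc },
}

#[derive(Default)]
struct Client { next_id: usize, queue: Vec<OpDesc> }

impl Client {
    fn register_empty_tensor(&mut self, shape: Vec<usize>) -> TensorDesc {
        self.next_id += 1;
        TensorDesc { id: self.next_id, shape }
    }
    fn register(&mut self, op: OpDesc) {
        self.queue.push(op); // deferred: executed later by the runner
    }
}

fn main() {
    let mut client = Client::default();
    let lhs = client.register_empty_tensor(vec![2, 2]);
    let rhs = client.register_empty_tensor(vec![2, 2]);
    let out = client.register_empty_tensor(vec![2, 2]);
    client.register(OpDesc::Add { lhs, rhs, out });
    // One queued description, no computation has happened yet.
    println!("{:?}", client.queue);
}
```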
+ let client = get_client::(device); + let out = client.register_empty_tensor(shape.into(), DType::Bool); + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::Empty(out.to_description_out()), + )); + + out + } + + fn bool_shape(tensor: &BoolTensor) -> Shape { + Shape::from(tensor.shape.clone()) + } + + async fn bool_into_data(tensor: BoolTensor) -> TensorData { + tensor.into_data().await + } + + fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { + let client = get_client::(device); + client.register_tensor_data(data.convert::()) + } + + fn bool_into_int(tensor: BoolTensor) -> IntTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(tensor.shape.clone(), IntElem::::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Bool( + BoolOperationDescription::IntoInt(desc), + )); + + out + } + + fn bool_into_float(tensor: BoolTensor) -> FloatTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(tensor.shape.clone(), FloatElem::::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Bool( + BoolOperationDescription::IntoFloat(desc), + )); + + out + } + + fn bool_device(tensor: &BoolTensor) -> Device { + tensor.client.device() + } + + fn bool_to_device(tensor: BoolTensor, device: &Device) -> BoolTensor { + if &tensor.client.device() == device { + return tensor; + } + R::change_client_backend(tensor, device) + } + + fn bool_reshape(tensor: BoolTensor, shape: Shape) -> BoolTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(shape.into(), tensor.dtype); + + let desc = ReshapeDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::Reshape(desc), + )); + + out + } + + fn bool_slice( + tensor: BoolTensor, + ranges: &[core::ops::Range], + ) -> BoolTensor { + let client = tensor.client.clone(); + let ndims = tensor.shape().num_dims(); + let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); + + for i in shape.len()..ndims { + shape.push(tensor.shape[i]); + } + + let out = client.register_empty_tensor(shape, tensor.dtype); + + let desc = SliceOperationDescription { + tensor: tensor.into_description(), + ranges: ranges.to_vec(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::Slice(desc), + )); + + out + } + + fn bool_slice_assign( + tensor: BoolTensor, + ranges: &[core::ops::Range], + value: BoolTensor, + ) -> BoolTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype); + + let desc = SliceAssignOperationDescription { + tensor: tensor.into_description(), + ranges: ranges.to_vec(), + value: value.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::SliceAssign(desc), + )); + + out + } + + fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { + let client = lhs.client.clone(); + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: 
out.to_description_out(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::Equal(desc), + )); + + out + } + + fn bool_not(tensor: BoolTensor) -> BoolTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Bool(BoolOperationDescription::Not( + desc, + ))); + + out + } + + fn bool_swap_dims(tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { + let client = tensor.client.clone(); + let mut shape = tensor.shape.clone(); + shape[dim1] = tensor.shape[dim2]; + shape[dim2] = tensor.shape[dim1]; + let out = client.register_empty_tensor(shape, tensor.dtype); + + let desc = SwapDimsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + dim1, + dim2, + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::SwapDims(desc), + )); + + out + } + + fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { + let client = tensor.client.clone(); + // Change the shape of the tensor to match the new axes + let shape = axes.iter().map(|x| tensor.shape[*x]).collect(); + let out = client.register_empty_tensor(shape, tensor.dtype); + + let desc = PermuteOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + axes: axes.to_vec(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::Permute(desc), + )); + + out + } + + fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype); + + let desc = FlipOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + axes: axes.to_vec(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::Flip(desc), + )); + + out + } + + fn bool_expand(tensor: BoolTensor, shape: Shape) -> BoolTensor { + let client = tensor.client.clone(); + let shape: Vec<_> = shape.into(); + let out = client.register_empty_tensor(shape.clone(), tensor.dtype); + + let desc = ExpandOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + shape, + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::Expand(desc), + )); + + out + } + + fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor { + let tensor_first = tensors.first().unwrap(); + let client = tensor_first.client.clone(); + let dtype = tensor_first.dtype; + + // Calculate the output shape + let mut shape = tensor_first.shape.clone(); + shape[dim] = 0; + for tensor in tensors.iter() { + shape[dim] += tensor.shape[dim]; + } + let out = client.register_empty_tensor(shape, dtype); + + let desc = CatOperationDescription { + tensors: tensors.into_iter().map(|t| t.into_description()).collect(), + dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::Cat(desc), + )); + + out + } + + fn bool_repeat_dim(tensor: BoolTensor, dim: usize, times: usize) -> BoolTensor { + let client = tensor.client.clone(); + let mut shape = tensor.shape.clone(); + shape[dim] *= times; + let out = client.register_empty_tensor(shape, tensor.dtype); + + let desc = RepeatDimOperationDescription { + tensor: tensor.into_description(), + dim, + times, + out: 
out.to_description_out(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::RepeatDim(desc), + )); + + out + } +} diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs new file mode 100644 index 000000000..1859e525d --- /dev/null +++ b/crates/burn-router/src/ops/op_float.rs @@ -0,0 +1,1407 @@ +use alloc::{vec, vec::Vec}; +use burn_tensor::backend::Backend; +use core::ops::Range; + +use burn_tensor::ops::{ + binary_ops_shape, BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntElem, IntTensor, +}; +use burn_tensor::repr::{ + BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, + ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, + FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription, + MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, + PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, + RepeatDimOperationDescription, ReshapeDescription, ScalarOperationDescription, + ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, + SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, + UnaryOperationDescription, +}; +use burn_tensor::{DType, Device, Distribution, Element, ElementConversion, Shape, TensorData}; + +use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient}; + +impl FloatTensorOps for BackendRouter { + fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { + let client = get_client::(device); + client.register_tensor_data(data.convert::<::FloatElem>()) + } + + fn float_random( + shape: Shape, + distribution: Distribution, + device: &Device, + ) -> FloatTensor { + // Get the runtime client on which to register the operation for execution. + let client = get_client::(device); + let dtype = FloatElem::::dtype(); + let out = client.register_empty_tensor(shape.into(), dtype); + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Random(RandomOperationDescription { + out: out.to_description_out(), + distribution, + }), + )); + + out + } + + fn float_zeros(shape: Shape, device: &Device) -> FloatTensor { + // Get the runtime client on which to register the operation for execution. + let client = get_client::(device); + let dtype = FloatElem::::dtype(); + let out = client.register_empty_tensor(shape.into(), dtype); + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Zeros(out.to_description_out()), + )); + + out + } + + fn float_ones(shape: Shape, device: &Device) -> FloatTensor { + // Get the runtime client on which to register the operation for execution. + let client = get_client::(device); + let dtype = FloatElem::::dtype(); + let out = client.register_empty_tensor(shape.into(), dtype); + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Ones(out.to_description_out()), + )); + + out + } + + fn float_full( + shape: Shape, + fill_value: FloatElem, + device: &Device, + ) -> FloatTensor { + // Get the runtime client on which to register the operation for execution. 
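Because execution is deferred, the float ops below compute output shapes eagerly when building each description: element-wise ops take the per-dimension broadcast of the operand shapes (via `binary_ops_shape`), reductions set the reduced dimension to 1, and `float_matmul` keeps the lhs rows and rhs columns. A small sketch of that bookkeeping with hypothetical helper names:

```rust
// Per-dimension broadcast for equal-rank shapes (illustrative helper).
fn elementwise_out_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    lhs.iter().zip(rhs).map(|(l, r)| (*l).max(*r)).collect()
}

// Matmul keeps lhs rows and rhs columns, broadcasting the batch dimensions.
fn matmul_out_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    let mut shape = elementwise_out_shape(lhs, rhs);
    let n = shape.len();
    shape[n - 2] = lhs[n - 2];
    shape[n - 1] = rhs[n - 1];
    shape
}

fn main() {
    // Batched matmul: [2, 3, 4] x [2, 4, 5] -> [2, 3, 5]
    assert_eq!(matmul_out_shape(&[2, 3, 4], &[2, 4, 5]), vec![2, 3, 5]);
    assert_eq!(elementwise_out_shape(&[2, 1, 4], &[1, 3, 4]), vec![2, 3, 4]);
}
```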
+ let client = get_client::(device); + let dtype = FloatElem::::dtype(); + let out = client.register_empty_tensor(shape.into(), dtype); + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Full((out.to_description_out(), fill_value.elem())), + )); + + out + } + + fn float_shape(tensor: &FloatTensor) -> Shape { + tensor.shape() + } + + async fn float_into_data(tensor: FloatTensor) -> TensorData { + tensor + .into_data() + .await + // Since underlying backends can have different data types, we convert to the current elem + .convert::<::FloatElem>() + } + + fn float_device(tensor: &FloatTensor) -> Device { + tensor.client.device() + } + + fn float_to_device(tensor: FloatTensor, device: &Device) -> FloatTensor { + if &tensor.client.device() == device { + return tensor; + } + R::change_client_backend(tensor, device) + } + + fn float_into_int(tensor: FloatTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), IntElem::::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::IntoInt(desc), + )); + + out + } + + fn float_empty(shape: Shape, device: &Device) -> FloatTensor { + // Get the runtime client on which to register the operation for execution. + let client = get_client::(device); + let out = client.register_empty_tensor(shape.into(), FloatElem::::dtype()); + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Empty(out.to_description_out()), + )); + + out + } + + fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Add(desc), + )); + + out + } + + fn float_add_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::AddScalar(desc), + )); + + out + } + + fn float_clamp( + tensor: FloatTensor, + min: FloatElem, + max: FloatElem, + ) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = ClampOperationDescription { + tensor: tensor.into_description(), + min: min.elem(), + max: max.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Clamp(desc), + )); + + out + } + + fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + 
client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Sub(desc), + )); + + out + } + + fn float_sub_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::SubScalar(desc), + )); + + out + } + + fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Mul(desc), + )); + + out + } + + fn float_mul_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::MulScalar(desc), + )); + + out + } + + fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Div(desc), + )); + + out + } + + fn float_div_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::DivScalar(desc), + )); + + out + } + + fn float_remainder_scalar(lhs: FloatTensor, rhs: FloatElem) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::RemScalar(desc), + )); + + out + } + + fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + + let mut shape = binary_ops_shape(&lhs.shape, &rhs.shape); + let ndims = lhs.shape().num_dims(); + + shape[ndims - 2] = lhs.shape[ndims - 2]; + shape[ndims - 1] = rhs.shape[ndims - 1]; + let out = client.register_empty_tensor(shape, dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + 
FloatOperationDescription::Matmul(desc), + )); + + out + } + + fn float_swap_dims(tensor: FloatTensor, dim1: usize, dim2: usize) -> FloatTensor { + let client = tensor.client.clone(); + let mut shape = tensor.shape.clone(); + shape[dim1] = tensor.shape[dim2]; + shape[dim2] = tensor.shape[dim1]; + let out = client.register_empty_tensor(shape, tensor.dtype); + + let desc = SwapDimsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + dim1, + dim2, + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::SwapDims(desc), + )); + + out + } + + fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(shape.into(), tensor.dtype); + + let desc = ReshapeDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Reshape(desc), + )); + + out + } + + fn float_gather( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + ) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(indices.shape.clone(), dtype); + + let desc = GatherOperationDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Gather(desc), + )); + + out + } + + fn float_scatter( + dim: usize, + tensor: FloatTensor, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = ScatterOperationDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Scatter(desc), + )); + + out + } + + fn float_select( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + ) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = indices.shape[0]; + let out = client.register_empty_tensor(shape, dtype); + + let desc = SelectOperationDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Select(desc), + )); + + out + } + + fn float_select_assign( + tensor: FloatTensor, + dim: usize, + indices: IntTensor, + value: FloatTensor, + ) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = SelectAssignOperationDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::SelectAssign(desc), + )); + + out + } + + fn float_slice(tensor: FloatTensor, ranges: &[Range]) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + + let ndims = tensor.shape().num_dims(); + let mut shape: Vec = 
ranges.iter().map(|range| range.end - range.start).collect(); + + for i in shape.len()..ndims { + shape.push(tensor.shape[i]); + } + + let out = client.register_empty_tensor(shape, dtype); + + let desc = SliceOperationDescription { + tensor: tensor.into_description(), + ranges: ranges.to_vec(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Slice(desc), + )); + + out + } + + fn float_slice_assign( + tensor: FloatTensor, + ranges: &[Range], + value: FloatTensor, + ) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = SliceAssignOperationDescription { + tensor: tensor.into_description(), + ranges: ranges.to_vec(), + value: value.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::SliceAssign(desc), + )); + + out + } + + fn float_mask_where( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatTensor, + ) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let shape = binary_ops_shape(&tensor.shape, &mask.shape); + let out = client.register_empty_tensor(shape, dtype); + + let desc = MaskWhereOperationDescription { + tensor: tensor.into_description(), + mask: mask.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::MaskWhere(desc), + )); + + out + } + + fn float_mask_fill( + tensor: FloatTensor, + mask: BoolTensor, + value: FloatElem, + ) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = MaskFillOperationDescription { + tensor: tensor.into_description(), + mask: mask.into_description(), + value: value.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::MaskFill(desc), + )); + + out + } + + fn float_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { + let client = lhs.client.clone(); + let out = + client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Equal(desc), + )); + + out + } + + fn float_equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::EqualElem(desc), + )); + + out + } + + fn float_greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = + client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + 
NumericOperationDescription::Greater(desc), + )); + + out + } + + fn float_greater_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::GreaterElem(desc), + )); + + out + } + + fn float_greater_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = + client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::GreaterEqual(desc), + )); + + out + } + + fn float_greater_equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::GreaterEqualElem(desc), + )); + + out + } + + fn float_lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = + client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Lower(desc), + )); + + out + } + + fn float_lower_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::LowerElem(desc), + )); + + out + } + + fn float_lower_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = + client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::LowerEqual(desc), + )); + + out + } + + fn float_lower_equal_elem(lhs: FloatTensor, rhs: FloatElem) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::LowerEqualElem(desc), + )); + + out + } + + fn float_sum(tensor: FloatTensor) -> 
FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(vec![1], dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Sum(desc), + )); + + out + } + + fn float_sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::SumDim(desc), + )); + + out + } + + fn float_prod(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(vec![1], dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Prod(desc), + )); + + out + } + + fn float_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::ProdDim(desc), + )); + + out + } + + fn float_mean(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(vec![1], dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Mean(desc), + )); + + out + } + + fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::MeanDim(desc), + )); + + out + } + + fn float_exp(lhs: FloatTensor) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: lhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Exp(desc), + )); + + out + } + + fn float_log(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + 
FloatOperationDescription::Log(desc), + )); + + out + } + + fn float_log1p(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Log1p(desc), + )); + + out + } + + fn float_powf_scalar(lhs: FloatTensor, rhs: f32) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::PowfScalar(desc), + )); + + out + } + + fn float_sqrt(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Sqrt(desc), + )); + + out + } + + fn float_abs(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Abs(desc), + )); + + out + } + + fn float_cos(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Cos(desc), + )); + + out + } + + fn float_sin(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Sin(desc), + )); + + out + } + + fn float_tanh(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Tanh(desc), + )); + + out + } + + fn float_recip(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Recip(desc), + )); + + out + } + + fn float_erf(tensor: FloatTensor) -> FloatTensor { + let 
client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Erf(desc), + )); + + out + } + + fn float_cat(tensors: Vec>, dim: usize) -> FloatTensor { + let tensor_first = tensors.first().unwrap(); + let client = tensor_first.client.clone(); + + // Calculate the output shape + let mut shape = tensor_first.shape.clone(); + shape[dim] = 0; + for tensor in tensors.iter() { + shape[dim] += tensor.shape[dim]; + } + let out = client.register_empty_tensor(shape, tensor_first.dtype); + + let desc = CatOperationDescription { + tensors: tensors + .into_iter() + .map(|tensor| tensor.into_description()) + .collect(), + dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Cat(desc), + )); + + out + } + + fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, IntElem::::dtype()); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::ArgMax(desc), + )); + + out + } + + fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { + let client = tensor.client.clone(); + let mut shape = tensor.shape.clone(); + shape[dim] *= times; + let out = client.register_empty_tensor(shape, tensor.dtype); + + let desc = RepeatDimOperationDescription { + tensor: tensor.into_description(), + dim, + times, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::RepeatDim(desc), + )); + + out + } + + fn float_argmin(tensor: FloatTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, IntElem::::dtype()); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::ArgMin(desc), + )); + + out + } + + fn float_max(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(vec![1], dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Max(desc), + )); + + out + } + + fn float_max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::MaxDim(desc), + )); + + out + } + + fn float_max_dim_with_indices( + tensor: 
FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape.clone(), dtype); + let out_indices = client.register_empty_tensor(shape, IntElem::::dtype()); + + let desc = ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::MaxDimWithIndices(desc), + )); + + (out, out_indices) + } + + fn float_min(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(vec![1], dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Min(desc), + )); + + out + } + + fn float_min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::MinDim(desc), + )); + + out + } + + fn float_min_dim_with_indices( + tensor: FloatTensor, + dim: usize, + ) -> (FloatTensor, IntTensor) { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape.clone(), dtype); + let out_indices = client.register_empty_tensor(shape, IntElem::::dtype()); + + let desc = ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::MinDimWithIndices(desc), + )); + + (out, out_indices) + } + + fn float_powf(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericFloat( + dtype, + NumericOperationDescription::Powf(desc), + )); + + out + } + + fn float_permute(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { + let client = tensor.client.clone(); + // Change the shape of the tensor to match the new axes + let shape = axes.iter().map(|x| tensor.shape[*x]).collect(); + let out = client.register_empty_tensor(shape, tensor.dtype); + + let desc = PermuteOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + axes: axes.to_vec(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Permute(desc), + )); + + out + } + + fn float_expand(tensor: FloatTensor, shape: Shape) -> FloatTensor { + let client = tensor.client.clone(); + let shape: Vec<_> = shape.into(); + let out = client.register_empty_tensor(shape.clone(), tensor.dtype); 
+ + let desc = ExpandOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + shape, + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Expand(desc), + )); + + out + } + + fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype); + + let desc = FlipOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + axes: axes.to_vec(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::Flip(desc), + )); + + out + } +} diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs new file mode 100644 index 000000000..4fefa942b --- /dev/null +++ b/crates/burn-router/src/ops/op_int.rs @@ -0,0 +1,1159 @@ +use alloc::{vec, vec::Vec}; +use burn_tensor::backend::Backend; +use core::ops::Range; + +use burn_tensor::ops::{ + binary_ops_shape, BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, IntTensorOps, +}; +use burn_tensor::repr::{ + BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, + ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, + GatherOperationDescription, IntOperationDescription, MaskFillOperationDescription, + MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, + PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, + RepeatDimOperationDescription, ReshapeDescription, ScalarOperationDescription, + ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, + SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, + UnaryOperationDescription, +}; +use burn_tensor::{DType, Device, Distribution, Element, ElementConversion, Shape, TensorData}; + +use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient}; + +impl IntTensorOps for BackendRouter { + fn int_empty(shape: Shape, device: &Device) -> IntTensor { + // Get the runtime client on which to register the operation for execution. 
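+        // Note: `register_empty_tensor` only creates the output descriptor; the actual
+        // allocation is performed by whichever runner executes the registered `Empty` operation.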
+ let client = get_client::(device); + let out = client.register_empty_tensor(shape.into(), IntElem::::dtype()); + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::Empty(out.to_description_out()), + )); + + out + } + + fn int_shape(tensor: &IntTensor) -> Shape { + tensor.shape() + } + + async fn int_into_data(tensor: IntTensor) -> TensorData { + tensor + .into_data() + .await + // Since underlying backends can have different data types, we convert to the current elem + .convert::<::IntElem>() + } + + fn int_from_data(data: TensorData, device: &Device) -> IntTensor { + let client = get_client::(device); + client.register_tensor_data(data.convert::<::IntElem>()) + } + + fn int_device(tensor: &IntTensor) -> Device { + tensor.client.device() + } + + fn int_to_device(tensor: IntTensor, device: &Device) -> IntTensor { + if &tensor.client.device() == device { + return tensor; + } + R::change_client_backend(tensor, device) + } + + fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(shape.into(), tensor.dtype); + + let desc = ReshapeDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::Reshape(desc), + )); + + out + } + + fn int_slice(tensor: IntTensor, ranges: &[Range]) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + + let ndims = tensor.shape().num_dims(); + let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); + + for i in shape.len()..ndims { + shape.push(tensor.shape[i]); + } + + let out = client.register_empty_tensor(shape, dtype); + + let desc = SliceOperationDescription { + tensor: tensor.into_description(), + ranges: ranges.to_vec(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::Slice(desc), + )); + + out + } + + fn int_slice_assign( + tensor: IntTensor, + ranges: &[Range], + value: IntTensor, + ) -> IntTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype); + + let desc = SliceAssignOperationDescription { + tensor: tensor.into_description(), + ranges: ranges.to_vec(), + value: value.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::SliceAssign(desc), + )); + + out + } + + fn int_mask_where( + tensor: IntTensor, + mask: BoolTensor, + value: IntTensor, + ) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let shape = binary_ops_shape(&tensor.shape, &mask.shape); + let out = client.register_empty_tensor(shape, dtype); + + let desc = MaskWhereOperationDescription { + tensor: tensor.into_description(), + mask: mask.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::MaskWhere(desc), + )); + + out + } + + fn int_mask_fill( + tensor: IntTensor, + mask: BoolTensor, + value: IntElem, + ) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = MaskFillOperationDescription { + tensor: tensor.into_description(), + mask: mask.into_description(), + value: value.elem(), + out: out.to_description_out(), + }; + + 
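+        // The output keeps the input tensor's shape; `value` is written wherever the mask is true.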
client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::MaskFill(desc), + )); + + out + } + + fn int_gather( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + ) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(indices.shape.clone(), dtype); + + let desc = GatherOperationDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Gather(desc), + )); + + out + } + + fn int_scatter( + dim: usize, + tensor: IntTensor, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = ScatterOperationDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Scatter(desc), + )); + + out + } + + fn int_select( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + ) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = indices.shape[0]; + let out = client.register_empty_tensor(shape, dtype); + + let desc = SelectOperationDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Select(desc), + )); + + out + } + + fn int_select_assign( + tensor: IntTensor, + dim: usize, + indices: IntTensor, + value: IntTensor, + ) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = SelectAssignOperationDescription { + tensor: tensor.into_description(), + dim, + indices: indices.into_description(), + value: value.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::SelectAssign(desc), + )); + + out + } + + fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { + let tensor_first = tensors.first().unwrap(); + let client = tensor_first.client.clone(); + let dtype = tensor_first.dtype; + + // Calculate the output shape + let mut shape = tensor_first.shape.clone(); + shape[dim] = 0; + for tensor in tensors.iter() { + shape[dim] += tensor.shape[dim]; + } + let out = client.register_empty_tensor(shape, dtype); + + let desc = CatOperationDescription { + tensors: tensors.into_iter().map(|t| t.into_description()).collect(), + dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::Cat(desc), + )); + + out + } + + fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + let client = lhs.client.clone(); + let out = + client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::Equal(desc), + )); + + out + } + + fn 
int_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::EqualElem(desc), + )); + + out + } + + fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = + client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Greater(desc), + )); + + out + } + + fn int_greater_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::GreaterElem(desc), + )); + + out + } + + fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = + client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::GreaterEqual(desc), + )); + + out + } + + fn int_greater_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::GreaterEqualElem(desc), + )); + + out + } + + fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = + client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Lower(desc), + )); + + out + } + + fn int_lower_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::LowerElem(desc), + )); + + out + } + + fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = + 
client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), DType::Bool); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::LowerEqual(desc), + )); + + out + } + + fn int_lower_equal_elem(lhs: IntTensor, rhs: IntElem) -> BoolTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::LowerEqualElem(desc), + )); + + out + } + + fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Add(desc), + )); + + out + } + + fn int_add_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::AddScalar(desc), + )); + + out + } + + fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Sub(desc), + )); + + out + } + + fn int_sub_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::SubScalar(desc), + )); + + out + } + + fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Mul(desc), + )); + + out + } + + fn int_mul_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + 
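+        // The scalar operand is converted with `elem()` and stored inline in the operation description.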
client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::MulScalar(desc), + )); + + out + } + + fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Div(desc), + )); + + out + } + + fn int_div_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::DivScalar(desc), + )); + + out + } + + fn int_remainder_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::RemScalar(desc), + )); + + out + } + + fn int_zeros(shape: Shape, device: &Device) -> IntTensor { + // Get the runtime client on which to register the operation for execution. + let client = get_client::(device); + let dtype = IntElem::::dtype(); + let out = client.register_empty_tensor(shape.dims.to_vec(), dtype); + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Zeros(out.to_description_out()), + )); + + out + } + + fn int_ones(shape: Shape, device: &Device) -> IntTensor { + // Get the runtime client on which to register the operation for execution. 
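+        // As with `int_zeros` above, there is no input tensor: only the output descriptor is attached to the `Ones` operation.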
+ let client = get_client::(device); + let dtype = IntElem::::dtype(); + let out = client.register_empty_tensor(shape.into(), dtype); + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Ones(out.to_description_out()), + )); + + out + } + + fn int_sum(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(vec![1], dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Sum(desc), + )); + + out + } + + fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::SumDim(desc), + )); + + out + } + + fn int_prod(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(vec![1], dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Prod(desc), + )); + + out + } + + fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::ProdDim(desc), + )); + + out + } + + fn int_mean(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(vec![1], dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Mean(desc), + )); + + out + } + + fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::MeanDim(desc), + )); + + out + } + + fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, IntElem::::dtype()); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::ArgMax(desc), + )); + + out + } + + fn 
int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, IntElem::::dtype()); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::ArgMin(desc), + )); + + out + } + + fn int_clamp( + tensor: IntTensor, + min: IntElem, + max: IntElem, + ) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = ClampOperationDescription { + tensor: tensor.into_description(), + min: min.elem(), + max: max.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Clamp(desc), + )); + + out + } + + fn int_abs(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Abs(desc), + )); + + out + } + + fn int_into_float(tensor: IntTensor) -> FloatTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(tensor.shape.clone(), FloatElem::::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::IntoFloat(desc), + )); + + out + } + + fn int_swap_dims(tensor: IntTensor, dim1: usize, dim2: usize) -> IntTensor { + let client = tensor.client.clone(); + let mut shape = tensor.shape.clone(); + shape[dim1] = tensor.shape[dim2]; + shape[dim2] = tensor.shape[dim1]; + let out = client.register_empty_tensor(shape, tensor.dtype); + + let desc = SwapDimsDescription { + input: tensor.into_description(), + out: out.to_description_out(), + dim1, + dim2, + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::SwapDims(desc), + )); + + out + } + + fn int_max(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(vec![1], dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Max(desc), + )); + + out + } + + fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::MaxDim(desc), + )); + + out + } + + fn int_max_dim_with_indices( + tensor: IntTensor, + dim: usize, + ) -> (IntTensor, IntTensor) { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = 
client.register_empty_tensor(shape.clone(), dtype); + let out_indices = client.register_empty_tensor(shape, IntElem::::dtype()); + + let desc = ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::MaxDimWithIndices(desc), + )); + + (out, out_indices) + } + + fn int_min(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(vec![1], dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::Min(desc), + )); + + out + } + + fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape, dtype); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::MinDim(desc), + )); + + out + } + + fn int_min_dim_with_indices( + tensor: IntTensor, + dim: usize, + ) -> (IntTensor, IntTensor) { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim] = 1; + let out = client.register_empty_tensor(shape.clone(), dtype); + let out_indices = client.register_empty_tensor(shape, IntElem::::dtype()); + + let desc = ReduceDimWithIndicesDescription { + tensor: tensor.into_description(), + dim, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }; + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::MinDimWithIndices(desc), + )); + + (out, out_indices) + } + + fn int_random( + shape: Shape, + distribution: Distribution, + device: &Device, + ) -> IntTensor { + // Get the runtime client on which to register the operation for execution. 
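+        // The requested distribution travels inside the operation description; the executing backend draws the samples.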
+ let client = get_client::(device); + let dtype = IntElem::::dtype(); + let out = client.register_empty_tensor(shape.into(), dtype); + + client.register(OperationDescription::NumericInt( + dtype, + NumericOperationDescription::IntRandom(RandomOperationDescription { + out: out.to_description_out(), + distribution, + }), + )); + + out + } + + fn int_permute(tensor: IntTensor, axes: &[usize]) -> IntTensor { + let client = tensor.client.clone(); + // Change the shape of the tensor to match the new axes + let shape = axes.iter().map(|x| tensor.shape[*x]).collect(); + let out = client.register_empty_tensor(shape, tensor.dtype); + + let desc = PermuteOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + axes: axes.to_vec(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::Permute(desc), + )); + + out + } + + fn int_expand(tensor: IntTensor, shape: Shape) -> IntTensor { + let client = tensor.client.clone(); + let shape: Vec<_> = shape.into(); + let out = client.register_empty_tensor(shape.clone(), tensor.dtype); + + let desc = ExpandOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + shape, + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::Expand(desc), + )); + + out + } + + fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { + let client = tensor.client.clone(); + let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype); + + let desc = FlipOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + axes: axes.to_vec(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::Flip(desc), + )); + + out + } + + fn int_repeat_dim(tensor: IntTensor, dim: usize, times: usize) -> IntTensor { + let client = tensor.client.clone(); + let mut shape = tensor.shape.clone(); + shape[dim] *= times; + let out = client.register_empty_tensor(shape, tensor.dtype); + + let desc = RepeatDimOperationDescription { + tensor: tensor.into_description(), + dim, + times, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::RepeatDim(desc), + )); + + out + } +} diff --git a/crates/burn-router/src/ops/op_module.rs b/crates/burn-router/src/ops/op_module.rs new file mode 100644 index 000000000..d95ff7714 --- /dev/null +++ b/crates/burn-router/src/ops/op_module.rs @@ -0,0 +1,821 @@ +use alloc::{boxed::Box, vec}; + +use burn_tensor::ops::conv::{ + calculate_conv_output_size, calculate_conv_transpose_output_size, calculate_pool_output_size, +}; +use burn_tensor::ops::{ + ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor, + IntElem, ModuleOps, +}; +use burn_tensor::ops::{ + IntTensor, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, + MaxPool2dWithIndices, +}; +use burn_tensor::repr::{ + AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription, + AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription, + AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription, + AvgPool2dDescription, Conv1dDescription, Conv2dDescription, Conv3dDescription, + ConvTranspose1dDescription, ConvTranspose2dDescription, ConvTranspose3dDescription, + DeformConv2dBackwardDescription, DeformConv2dDescription, InterpolateBackwardDescription, + InterpolateDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, + 
MaxPool1dWithIndicesDescription, MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription, + MaxPool2dWithIndicesDescription, ModuleOperationDescription, OperationDescription, +}; +use burn_tensor::Element; + +use crate::{BackendRouter, RunnerChannel, RunnerClient}; + +impl ModuleOps for BackendRouter { + fn conv1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<1>, + ) -> FloatTensor { + let size = calculate_conv_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.dilation[0], + x.shape[2], + ); + + let shape = vec![x.shape[0], weight.shape[0], size]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = Conv1dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options: options.into(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::Conv1d(desc), + )); + + out + } + + fn conv2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<2>, + ) -> FloatTensor { + let size_0 = calculate_conv_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.dilation[0], + x.shape[2], + ); + let size_1 = calculate_conv_output_size( + weight.shape[3], + options.stride[1], + options.padding[1], + options.dilation[1], + x.shape[3], + ); + + let shape = vec![x.shape[0], weight.shape[0], size_0, size_1]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = Conv2dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options: options.into(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::Conv2d(desc), + )); + + out + } + + fn conv3d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvOptions<3>, + ) -> FloatTensor { + let size_0 = calculate_conv_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.dilation[0], + x.shape[2], + ); + let size_1 = calculate_conv_output_size( + weight.shape[3], + options.stride[1], + options.padding[1], + options.dilation[1], + x.shape[3], + ); + let size_2 = calculate_conv_output_size( + weight.shape[4], + options.stride[2], + options.padding[2], + options.dilation[2], + x.shape[4], + ); + + let shape = vec![x.shape[0], weight.shape[0], size_0, size_1, size_2]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = Conv3dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options: options.into(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::Conv3d(desc), + )); + + out + } + + fn conv_transpose1d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<1>, + ) -> FloatTensor { + let size = calculate_conv_transpose_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.padding_out[0], + options.dilation[0], + x.shape[2], + ); + + let shape = vec![x.shape[0], weight.shape[1] * options.groups, size]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = ConvTranspose1dDescription { + x: 
x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options: options.into(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::ConvTranspose1d(desc), + )); + + out + } + + fn conv_transpose2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor { + let size_0 = calculate_conv_transpose_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.padding_out[0], + options.dilation[0], + x.shape[2], + ); + let size_1 = calculate_conv_transpose_output_size( + weight.shape[3], + options.stride[1], + options.padding[1], + options.padding_out[1], + options.dilation[1], + x.shape[3], + ); + + let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = ConvTranspose2dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options: options.into(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::ConvTranspose2d(desc), + )); + + out + } + + fn conv_transpose3d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<3>, + ) -> FloatTensor { + let size_0 = calculate_conv_transpose_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.padding_out[0], + options.dilation[0], + x.shape[2], + ); + let size_1 = calculate_conv_transpose_output_size( + weight.shape[3], + options.stride[1], + options.padding[1], + options.padding_out[1], + options.dilation[1], + x.shape[3], + ); + let size_2 = calculate_conv_transpose_output_size( + weight.shape[4], + options.stride[2], + options.padding[2], + options.padding_out[2], + options.dilation[2], + x.shape[4], + ); + + let shape = vec![ + x.shape[0], + weight.shape[1] * options.groups, + size_0, + size_1, + size_2, + ]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = ConvTranspose3dDescription { + x: x.into_description(), + weight: weight.into_description(), + bias: bias.map(|bias| bias.into_description()), + options: options.into(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::ConvTranspose3d(desc), + )); + + out + } + + fn avg_pool1d( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> FloatTensor { + let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]); + + let shape = vec![x.shape[0], x.shape[1], size]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = AvgPool1dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::AvgPool1d(desc), + )); + + out + } + + fn avg_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + let size_0 = + calculate_pool_output_size(kernel_size[0], stride[0], padding[0], 1, x.shape[2]); + let size_1 = + calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]); + 
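+        // Average pooling has no dilation, so both output sizes above are computed with a dilation of 1.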
+ let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = AvgPool2dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::AvgPool2d(desc), + )); + + out + } + + fn avg_pool1d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + count_include_pad: bool, + ) -> FloatTensor { + let client = x.client.clone(); + let out = client.register_empty_tensor(x.shape.clone(), x.dtype); + + let desc = AvgPool1dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::AvgPool1dBackward(desc), + )); + + out + } + + fn avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ) -> FloatTensor { + let client = x.client.clone(); + let out = client.register_empty_tensor(x.shape.clone(), x.dtype); + + let desc = AvgPool2dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + kernel_size, + stride, + padding, + count_include_pad, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::AvgPool2dBackward(desc), + )); + + out + } + + fn max_pool1d( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> FloatTensor { + let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); + + let shape = vec![x.shape[0], x.shape[1], size]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = MaxPool1dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::MaxPool1d(desc), + )); + + out + } + + fn max_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> FloatTensor { + let size_0 = calculate_pool_output_size( + kernel_size[0], + stride[0], + padding[0], + dilation[0], + x.shape[2], + ); + let size_1 = calculate_pool_output_size( + kernel_size[1], + stride[1], + padding[1], + dilation[1], + x.shape[3], + ); + + let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = MaxPool2dDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::MaxPool2d(desc), + )); + + out + } + + fn max_pool1d_with_indices( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + ) -> MaxPool1dWithIndices { + let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]); + + let shape = vec![x.shape[0], x.shape[1], size]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape.clone(), x.dtype); + let out_indices = client.register_empty_tensor(shape, 
IntElem::::dtype()); + + let desc = MaxPool1dWithIndicesDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::MaxPool1dWithIndices(desc), + )); + + MaxPool1dWithIndices::new(out, out_indices) + } + + fn max_pool2d_with_indices( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> MaxPool2dWithIndices { + let size_0 = calculate_pool_output_size( + kernel_size[0], + stride[0], + padding[0], + dilation[0], + x.shape[2], + ); + let size_1 = calculate_pool_output_size( + kernel_size[1], + stride[1], + padding[1], + dilation[1], + x.shape[3], + ); + + let shape = vec![x.shape[0], x.shape[1], size_0, size_1]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape.clone(), x.dtype); + let out_indices = client.register_empty_tensor(shape, IntElem::::dtype()); + + let desc = MaxPool2dWithIndicesDescription { + x: x.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + out_indices: out_indices.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::MaxPool2dWithIndices(desc), + )); + + MaxPool2dWithIndices::new(out, out_indices) + } + + fn max_pool1d_with_indices_backward( + x: FloatTensor, + kernel_size: usize, + stride: usize, + padding: usize, + dilation: usize, + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool1dBackward { + let client = x.client.clone(); + let out = client.register_empty_tensor(x.shape.clone(), x.dtype); + + let desc = MaxPool1dWithIndicesBackwardDescription { + x: x.into_description(), + grad: output_grad.into_description(), + indices: indices.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::MaxPool1dWithIndicesBackward(desc), + )); + + MaxPool1dBackward::new(out) + } + + fn max_pool2d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool2dBackward { + let client = x.client.clone(); + let out = client.register_empty_tensor(x.shape.clone(), x.dtype); + + let desc = MaxPool2dWithIndicesBackwardDescription { + x: x.into_description(), + grad: output_grad.into_description(), + indices: indices.into_description(), + kernel_size, + stride, + padding, + dilation, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::MaxPool2dWithIndicesBackward(desc), + )); + + MaxPool2dBackward::new(out) + } + + fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { + let shape = vec![x.shape[0], x.shape[1], output_size]; + + let client = x.client.clone(); + let out = client.register_empty_tensor(shape.clone(), x.dtype); + + let desc = AdaptiveAvgPool1dDescription { + x: x.into_description(), + output_size, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::AdaptiveAvgPool1d(desc), + )); + + out + } + + fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { + let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; + + let 
client = x.client.clone(); + let out = client.register_empty_tensor(shape.clone(), x.dtype); + + let desc = AdaptiveAvgPool2dDescription { + x: x.into_description(), + output_size, + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::AdaptiveAvgPool2d(desc), + )); + + out + } + + fn adaptive_avg_pool1d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + let client = x.client.clone(); + let out = client.register_empty_tensor(x.shape.clone(), x.dtype); + + let desc = AdaptiveAvgPool1dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc), + )); + + out + } + + fn adaptive_avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + let client = x.client.clone(); + let out = client.register_empty_tensor(x.shape.clone(), x.dtype); + + let desc = AdaptiveAvgPool2dBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc), + )); + + out + } + + fn interpolate( + x: FloatTensor, + output_size: [usize; 2], + options: InterpolateOptions, + ) -> FloatTensor { + let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]]; + + let client = x.client.clone(); + let out = client.register_empty_tensor(shape.clone(), x.dtype); + + let desc = InterpolateDescription { + x: x.into_description(), + output_size, + options: options.into(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::Interpolate(desc), + )); + + out + } + + fn interpolate_backward( + x: FloatTensor, + grad: FloatTensor, + output_size: [usize; 2], + options: InterpolateOptions, + ) -> FloatTensor { + let client = x.client.clone(); + let out = client.register_empty_tensor(x.shape.clone(), x.dtype); + + let desc = InterpolateBackwardDescription { + x: x.into_description(), + grad: grad.into_description(), + output_size, + options: options.into(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::InterpolateBackward(desc), + )); + + out + } + + fn deform_conv2d( + x: FloatTensor, + offset: FloatTensor, + weight: FloatTensor, + mask: Option>, + bias: Option>, + options: DeformConvOptions<2>, + ) -> FloatTensor { + let size_0 = calculate_conv_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.dilation[0], + x.shape[2], + ); + let size_1 = calculate_conv_output_size( + weight.shape[3], + options.stride[1], + options.padding[1], + options.dilation[1], + x.shape[3], + ); + + let shape = vec![x.shape[0], weight.shape[0], size_0, size_1]; + let client = x.client.clone(); + let out = client.register_empty_tensor(shape, x.dtype); + + let desc = DeformConv2dDescription { + x: x.into_description(), + offset: offset.into_description(), + weight: weight.into_description(), + mask: mask.map(|mask| mask.into_description()), + bias: bias.map(|bias| bias.into_description()), + options: options.into(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::DeformableConv2d(Box::new(desc)), + )); + + out + } + + fn deform_conv2d_backward( + x: FloatTensor, + offset: 
FloatTensor, + weight: FloatTensor, + mask: Option>, + bias: Option>, + output_grad: FloatTensor, + options: DeformConvOptions<2>, + ) -> DeformConv2dBackward { + let client = x.client.clone(); + + let input_grad = client.register_empty_tensor(x.shape.clone(), x.dtype); + let offset_grad = client.register_empty_tensor(offset.shape.clone(), offset.dtype); + let weight_grad = client.register_empty_tensor(weight.shape.clone(), weight.dtype); + let mask_grad = mask + .as_ref() + .map(|mask| client.register_empty_tensor(mask.shape.clone(), mask.dtype)); + let bias_grad = bias + .as_ref() + .map(|bias| client.register_empty_tensor(bias.shape.clone(), bias.dtype)); + + let desc = DeformConv2dBackwardDescription { + x: x.into_description(), + offset: offset.into_description(), + weight: weight.into_description(), + mask: mask.map(|mask| mask.into_description()), + bias: bias.map(|bias| bias.into_description()), + options: options.into(), + out_grad: output_grad.into_description(), + input_grad: input_grad.to_description_out(), + offset_grad: offset_grad.to_description_out(), + weight_grad: weight_grad.to_description_out(), + mask_grad: mask_grad + .as_ref() + .map(|mask_grad| mask_grad.to_description_out()), + bias_grad: bias_grad + .as_ref() + .map(|bias_grad| bias_grad.to_description_out()), + }; + + client.register(OperationDescription::Module( + ModuleOperationDescription::DeformableConv2dBackward(Box::new(desc)), + )); + + DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad) + } +} diff --git a/crates/burn-router/src/ops/op_qfloat.rs b/crates/burn-router/src/ops/op_qfloat.rs new file mode 100644 index 000000000..1f4784ace --- /dev/null +++ b/crates/burn-router/src/ops/op_qfloat.rs @@ -0,0 +1,97 @@ +use core::ops::Range; + +use burn_tensor::{ + ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, + quantization::{QuantizationParametersPrimitive, QuantizationScheme}, + Device, Shape, TensorData, +}; + +use crate::{BackendRouter, RunnerChannel}; + +impl QTensorOps for BackendRouter { + fn q_from_data(_data: TensorData, _device: &Device) -> QuantizedTensor { + unimplemented!() + } + + fn quantize( + _tensor: FloatTensor, + _scheme: &QuantizationScheme, + _qparams: QuantizationParametersPrimitive, + ) -> QuantizedTensor { + unimplemented!() + } + + fn quantize_dynamic( + _tensor: FloatTensor, + _scheme: &QuantizationScheme, + ) -> QuantizedTensor { + unimplemented!() + } + + fn dequantize(_tensor: QuantizedTensor) -> FloatTensor { + unimplemented!() + } + + fn q_shape(_tensor: &QuantizedTensor) -> Shape { + unimplemented!() + } + + fn q_device(_tensor: &QuantizedTensor) -> Device { + unimplemented!() + } + + fn q_to_device( + _tensor: QuantizedTensor, + _device: &Device, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_reshape(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { + unimplemented!() + } + + async fn q_into_data(_tensor: QuantizedTensor) -> TensorData { + unimplemented!() + } + + fn q_swap_dims( + _tensor: QuantizedTensor, + _dim1: usize, + _dim2: usize, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_permute(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { + unimplemented!() + } + + fn q_flip(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { + unimplemented!() + } + + fn q_gather( + _dim: usize, + _tensor: QuantizedTensor, + _indices: IntTensor, + ) -> QuantizedTensor { + unimplemented!() + } + + fn q_select( + _tensor: QuantizedTensor, + _dim: usize, + _indices: IntTensor, + ) -> 
QuantizedTensor { + unimplemented!() + } + + fn q_slice(_tensor: QuantizedTensor, _ranges: &[Range]) -> QuantizedTensor { + unimplemented!() + } + + fn q_expand(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { + unimplemented!() + } +} diff --git a/crates/burn-router/src/ops/unary.rs b/crates/burn-router/src/ops/unary.rs new file mode 100644 index 000000000..11cd06443 --- /dev/null +++ b/crates/burn-router/src/ops/unary.rs @@ -0,0 +1,129 @@ +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! scalar_float_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_float_tensor::(&$desc.lhs); + let output = $ops(lhs, ElementConversion::elem($desc.rhs)); + + $handles.register_float_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! scalar_float_dim_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_float_tensor::(&$desc.lhs); + let output = $ops(lhs, $desc.rhs); + + $handles.register_float_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! scalar_float2int_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_float_tensor::(&$desc.lhs); + let output = $ops(lhs, $desc.rhs); + + $handles.register_int_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! scalar_float_cmp_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_float_tensor::(&$desc.lhs); + let output = $ops(lhs, ElementConversion::elem($desc.rhs)); + + $handles.register_bool_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! unary_float_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_float_tensor::(&$desc.input); + let output = $ops(lhs); + + $handles.register_float_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! scalar_int_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_int_tensor::(&$desc.lhs); + let output = $ops(lhs, ElementConversion::elem($desc.rhs)); + + $handles.register_int_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! int_float_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_int_tensor::(&$desc.lhs); + let output = $ops(lhs, ElementConversion::elem($desc.rhs)); + + $handles.register_int_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! scalar_int_dim_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_int_tensor::(&$desc.lhs); + let output = $ops(lhs, $desc.rhs); + + $handles.register_int_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! scalar_int_cmp_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_int_tensor::(&$desc.lhs); + let output = $ops(lhs, ElementConversion::elem($desc.rhs)); + + $handles.register_bool_tensor::(&$desc.out.id, output); + }}; +} + +#[allow(missing_docs)] +#[macro_export(local_inner_macros)] +macro_rules! 
unary_int_ops { + ( + $handles:expr, $desc:expr, $ops:expr + ) => {{ + let lhs = $handles.get_int_tensor::(&$desc.input); + let output = $ops(lhs); + + $handles.register_int_tensor::(&$desc.out.id, output); + }}; +} diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs new file mode 100644 index 000000000..87aa5b358 --- /dev/null +++ b/crates/burn-router/src/runner.rs @@ -0,0 +1,1222 @@ +use alloc::{sync::Arc, vec::Vec}; +use spin::Mutex; + +use burn_tensor::{ + backend::{Backend, BackendBridge}, + ops::FullPrecisionBackend, + repr::{ + BaseOperationDescription, BoolOperationDescription, FloatOperationDescription, + HandleContainer, IntOperationDescription, ModuleOperationDescription, + NumericOperationDescription, OperationDescription, ReprBackend, TensorDescription, + TensorId, TensorStatus, + }, + DType, Element, ElementConversion, Shape, TensorData, +}; + +use super::{RouterTensor, RunnerClient}; +use crate::{ + binary_float_cmp_ops, binary_float_ops, binary_int_cmp_ops, binary_int_ops, + scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_dim_ops, scalar_float_ops, + scalar_int_cmp_ops, scalar_int_dim_ops, scalar_int_ops, unary_float_ops, unary_int_ops, +}; + +/// A runner's context contains a [handle container](HandleContainer) to manage +/// (i.e., fetch and update) existing tensors. +pub struct RunnerContext { + /// Handle container to retrieve tensors based on their description. + handles: HandleContainer, +} + +impl RunnerContext { + /// Create a new (uninitialized) empty tensor and returns its corresponding [tensor id](TensorId). + fn create_empty_handle(&mut self) -> Arc { + self.handles.create_tensor_uninit() + } + + fn free_orphans(&mut self) { + // Passing an empty "remaining" tensor identifiers will remove the orphan handles from the container + self.handles.free_orphans(&[]) + } + + /// Set a tensor handle to be removed. + fn drop_tensor_handle(&mut self, id: TensorId) { + self.handles.handles_orphan.push(id); + } +} + +/// A runner is responsible for executing tensor operations for a given [intermediate backend](ReprBackend). +#[derive(Clone)] +pub struct Runner { + // Mutex for the mutable handles + context: Arc>>, + device: B::Device, +} + +impl Runner +where + // Restrict full precision backend handle to be the same + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, +{ + pub(crate) fn new(device: B::Device) -> Self { + Self { + context: Arc::new(Mutex::new(RunnerContext { + handles: HandleContainer::new(), + })), + device, + } + } + + /// Get the tensor handle for the given [tensor description](TensorDescription). + pub(crate) fn get_tensor_handle(&self, tensor: &TensorDescription) -> B::Handle { + let handles = &mut self.context.lock().handles; + handles.get_tensor_handle(tensor).handle + } + + /// Create a tensor with the given handle and shape. 
+ pub(crate) fn register_tensor( + &self, + handle: B::Handle, + shape: Vec, + dtype: DType, + client: C, + ) -> RouterTensor { + let mut ctx = self.context.lock(); + let id = ctx.create_empty_handle(); + + ctx.handles.register_handle(*id.as_ref(), handle); + core::mem::drop(ctx); + + RouterTensor::new(id, shape, dtype, client) + } + + pub(crate) fn register_tensor_data_desc(&self, data: TensorData) -> TensorDescription { + let mut ctx = self.context.lock(); + let id = ctx.create_empty_handle(); + let shape = data.shape.clone(); + let dtype = data.dtype; + + if dtype.is_float() { + let tensor = B::float_from_data(data, &self.device); + ctx.handles.register_float_tensor::(&id, tensor) + } else if dtype.is_int() { + let tensor = B::int_from_data(data, &self.device); + ctx.handles.register_int_tensor::(&id, tensor) + } else if dtype.is_bool() { + let tensor = B::bool_from_data(data, &self.device); + ctx.handles.register_bool_tensor::(&id, tensor) + } else if let DType::QFloat(_) = dtype { + todo!(); + } + + core::mem::drop(ctx); + + TensorDescription { + id: *id, + shape, + status: TensorStatus::ReadWrite, + dtype, + } + } + + pub(crate) fn register_empty_tensor_desc( + &self, + shape: Vec, + dtype: DType, + ) -> TensorDescription { + let mut ctx = self.context.lock(); + let id = ctx.create_empty_handle(); + core::mem::drop(ctx); + + TensorDescription { + id: *id, + shape, + status: TensorStatus::NotInit, + dtype, + } + } + + pub(crate) fn register_float_tensor_desc( + &self, + shape: Vec, + full_precision: bool, + ) -> TensorDescription { + let dtype = if full_precision { + as Backend>::FloatElem::dtype() + } else { + B::FloatElem::dtype() + }; + self.register_empty_tensor_desc(shape, dtype) + } +} + +impl RunnerClient for Runner +where + // Restrict full precision backend handle to be the same + <::FullPrecisionBridge as BackendBridge>::Target: + ReprBackend, +{ + type Device = B::Device; + + /// Execute a tensor operation. 
+ fn register(&self, op: OperationDescription) { + // Remove unused tensor handles + let mut ctx = self.context.lock(); + ctx.free_orphans(); + + let handles = &mut ctx.handles; + match &op { + // For every op: get the input(s), execute the operation and register the output(s) + OperationDescription::BaseFloat(op) => match op { + BaseOperationDescription::ToDevice(_) => unreachable!(), + BaseOperationDescription::Reshape(desc) => { + let tensor = handles.get_float_tensor::(&desc.input); + + let output = B::float_reshape(tensor, desc.out.shape.clone().into()); + handles.register_float_tensor::(&desc.out.id, output); + } + BaseOperationDescription::SwapDims(desc) => { + let tensor = handles.get_float_tensor::(&desc.input); + + let output = B::float_swap_dims(tensor, desc.dim1, desc.dim2); + handles.register_float_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Permute(desc) => { + let tensor = handles.get_float_tensor::(&desc.input); + + let output = B::float_permute(tensor, &desc.axes); + handles.register_float_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Flip(desc) => { + let tensor = handles.get_float_tensor::(&desc.input); + + let output = B::float_flip(tensor, &desc.axes); + handles.register_float_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Expand(desc) => { + let tensor = handles.get_float_tensor::(&desc.input); + + let output = B::float_expand(tensor, desc.shape.clone().into()); + handles.register_float_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Slice(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + + let output = B::float_slice(tensor, &desc.ranges); + handles.register_float_tensor::(&desc.out.id, output); + } + BaseOperationDescription::SliceAssign(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + let value = handles.get_float_tensor::(&desc.value); + + let output = B::float_slice_assign(tensor, &desc.ranges, value); + handles.register_float_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Equal(desc) => { + binary_float_cmp_ops!(handles, desc, B::float_equal) + } + BaseOperationDescription::RepeatDim(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + + let output = B::float_repeat_dim(tensor, desc.dim, desc.times); + handles.register_float_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Cat(desc) => { + let tensors = desc + .tensors + .iter() + .map(|tensor| handles.get_float_tensor::(tensor)) + .collect(); + + let output = B::float_cat(tensors, desc.dim); + handles.register_float_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Cast(desc) => { + let input_dtype = desc.input.dtype; + let out_dtype = desc.out.dtype; + let float_dtype = B::FloatElem::dtype(); + let full_dtype = as Backend>::FloatElem::dtype(); + + if input_dtype == float_dtype && out_dtype == full_dtype { + let tensor = handles.get_float_tensor::(&desc.input); + let output = B::float_into_full_precision(tensor); + handles + .register_float_tensor::>(&desc.out.id, output); + } else if input_dtype == full_dtype && out_dtype == float_dtype { + let tensor = + handles.get_float_tensor::>(&desc.input); + let output = B::float_from_full_precision(tensor); + handles.register_float_tensor::(&desc.out.id, output); + } else { + unimplemented!() // only cast to and from full precision + } + } + BaseOperationDescription::Empty(desc) => { + let shape = Shape::from(desc.shape.clone()); + let output = B::float_empty(shape, &self.device); + 
handles.register_float_tensor::(&desc.id, output); + } + }, + OperationDescription::BaseInt(op) => match op { + BaseOperationDescription::ToDevice(_) => unreachable!(), + BaseOperationDescription::Reshape(desc) => { + let tensor = handles.get_int_tensor::(&desc.input); + + let output = B::int_reshape(tensor, desc.out.shape.clone().into()); + handles.register_int_tensor::(&desc.out.id, output); + } + BaseOperationDescription::SwapDims(desc) => { + let tensor = handles.get_int_tensor::(&desc.input); + + let output = B::int_swap_dims(tensor, desc.dim1, desc.dim2); + handles.register_int_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Permute(desc) => { + let tensor = handles.get_int_tensor::(&desc.input); + + let output = B::int_permute(tensor, &desc.axes); + handles.register_int_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Flip(desc) => { + let tensor = handles.get_int_tensor::(&desc.input); + + let output = B::int_flip(tensor, &desc.axes); + handles.register_int_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Expand(desc) => { + let tensor = handles.get_int_tensor::(&desc.input); + + let output = B::int_expand(tensor, desc.shape.clone().into()); + handles.register_int_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Slice(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + + let output = B::int_slice(tensor, &desc.ranges); + handles.register_int_tensor::(&desc.out.id, output); + } + BaseOperationDescription::SliceAssign(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + let value = handles.get_int_tensor::(&desc.value); + + let output = B::int_slice_assign(tensor, &desc.ranges, value); + handles.register_int_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Equal(desc) => { + binary_int_cmp_ops!(handles, desc, B::int_equal) + } + BaseOperationDescription::RepeatDim(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + + let output = B::int_repeat_dim(tensor, desc.dim, desc.times); + handles.register_int_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Cat(desc) => { + let tensors = desc + .tensors + .iter() + .map(|tensor| handles.get_int_tensor::(tensor)) + .collect(); + + let output = B::int_cat(tensors, desc.dim); + handles.register_int_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Cast(_) => unreachable!(), + BaseOperationDescription::Empty(desc) => { + let shape = Shape::from(desc.shape.clone()); + let output = B::int_empty(shape, &self.device); + handles.register_int_tensor::(&desc.id, output); + } + }, + OperationDescription::BaseBool(op) => match op { + BaseOperationDescription::ToDevice(_) => unreachable!(), + BaseOperationDescription::Reshape(desc) => { + let tensor = handles.get_bool_tensor::(&desc.input); + + let output = B::bool_reshape(tensor, desc.out.shape.clone().into()); + handles.register_bool_tensor::(&desc.out.id, output); + } + BaseOperationDescription::SwapDims(desc) => { + let tensor = handles.get_bool_tensor::(&desc.input); + + let output = B::bool_swap_dims(tensor, desc.dim1, desc.dim2); + handles.register_bool_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Permute(desc) => { + let tensor = handles.get_bool_tensor::(&desc.input); + + let output = B::bool_permute(tensor, &desc.axes); + handles.register_bool_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Flip(desc) => { + let tensor = handles.get_bool_tensor::(&desc.input); + + let output = B::bool_flip(tensor, &desc.axes); + 
handles.register_bool_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Expand(desc) => { + let tensor = handles.get_bool_tensor::(&desc.input); + + let output = B::bool_expand(tensor, desc.shape.clone().into()); + handles.register_bool_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Slice(desc) => { + let tensor = handles.get_bool_tensor::(&desc.tensor); + + let output = B::bool_slice(tensor, &desc.ranges); + handles.register_bool_tensor::(&desc.out.id, output); + } + BaseOperationDescription::SliceAssign(desc) => { + let tensor = handles.get_bool_tensor::(&desc.tensor); + let value = handles.get_bool_tensor::(&desc.value); + + let output = B::bool_slice_assign(tensor, &desc.ranges, value); + handles.register_bool_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Equal(desc) => { + let lhs = handles.get_bool_tensor::(&desc.lhs); + let rhs = handles.get_bool_tensor::(&desc.rhs); + + let output = B::bool_equal(lhs, rhs); + handles.register_bool_tensor::(&desc.out.id, output); + } + BaseOperationDescription::RepeatDim(desc) => { + let tensor = handles.get_bool_tensor::(&desc.tensor); + + let output = B::bool_repeat_dim(tensor, desc.dim, desc.times); + handles.register_bool_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Cat(desc) => { + let tensors = desc + .tensors + .iter() + .map(|tensor| handles.get_bool_tensor::(tensor)) + .collect(); + + let output = B::bool_cat(tensors, desc.dim); + handles.register_bool_tensor::(&desc.out.id, output); + } + BaseOperationDescription::Cast(_) => unreachable!(), + BaseOperationDescription::Empty(desc) => { + let shape = Shape::from(desc.shape.clone()); + let output = B::bool_empty(shape, &self.device); + handles.register_bool_tensor::(&desc.id, output); + } + }, + OperationDescription::NumericFloat(_dtype, op) => match op { + NumericOperationDescription::Add(desc) => { + binary_float_ops!(handles, desc, B::float_add) + } + NumericOperationDescription::AddScalar(desc) => { + scalar_float_ops!(handles, desc, B::float_add_scalar) + } + NumericOperationDescription::Sub(desc) => { + binary_float_ops!(handles, desc, B::float_sub) + } + NumericOperationDescription::SubScalar(desc) => { + scalar_float_ops!(handles, desc, B::float_sub_scalar) + } + NumericOperationDescription::Div(desc) => { + binary_float_ops!(handles, desc, B::float_div) + } + NumericOperationDescription::DivScalar(desc) => { + scalar_float_ops!(handles, desc, B::float_div_scalar) + } + NumericOperationDescription::RemScalar(desc) => { + scalar_float_ops!(handles, desc, B::float_remainder_scalar) + } + NumericOperationDescription::Mul(desc) => { + binary_float_ops!(handles, desc, B::float_mul) + } + NumericOperationDescription::MulScalar(desc) => { + scalar_float_ops!(handles, desc, B::float_mul_scalar) + } + NumericOperationDescription::Abs(desc) => { + unary_float_ops!(handles, desc, B::float_abs) + } + NumericOperationDescription::Ones(desc) => { + let shape = Shape::from(desc.shape.clone()); + let output = B::float_ones(shape, &self.device); + handles.register_float_tensor::(&desc.id, output); + } + NumericOperationDescription::Zeros(desc) => { + let shape = Shape::from(desc.shape.clone()); + let output = B::float_zeros(shape, &self.device); + handles.register_float_tensor::(&desc.id, output); + } + NumericOperationDescription::Full((desc, elem)) => { + let shape = Shape::from(desc.shape.clone()); + let output = B::float_full(shape, elem.elem(), &self.device); + handles.register_float_tensor::(&desc.id, output); + } + 
NumericOperationDescription::Gather(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + let indices = handles.get_int_tensor::(&desc.indices); + + let output = B::float_gather(desc.dim, tensor, indices); + handles.register_float_tensor::(&desc.out.id, output); + } + NumericOperationDescription::Scatter(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + let indices = handles.get_int_tensor::(&desc.indices); + let value = handles.get_float_tensor::(&desc.value); + + let output = B::float_scatter(desc.dim, tensor, indices, value); + handles.register_float_tensor::(&desc.out.id, output); + } + NumericOperationDescription::Select(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + let indices = handles.get_int_tensor::(&desc.indices); + + let output = B::float_select(tensor, desc.dim, indices); + handles.register_float_tensor::(&desc.out.id, output); + } + NumericOperationDescription::SelectAssign(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + let indices = handles.get_int_tensor::(&desc.indices); + let value = handles.get_float_tensor::(&desc.value); + + let output = B::float_select_assign(tensor, desc.dim, indices, value); + handles.register_float_tensor::(&desc.out.id, output); + } + NumericOperationDescription::MaskWhere(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + let mask = handles.get_bool_tensor::(&desc.mask); + let value = handles.get_float_tensor::(&desc.value); + + let output = B::float_mask_where(tensor, mask, value); + handles.register_float_tensor::(&desc.out.id, output); + } + NumericOperationDescription::MaskFill(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + let mask = handles.get_bool_tensor::(&desc.mask); + + let output = B::float_mask_fill(tensor, mask, desc.value.elem()); + handles.register_float_tensor::(&desc.out.id, output); + } + NumericOperationDescription::MeanDim(desc) => { + scalar_float_dim_ops!(handles, desc, B::float_mean_dim) + } + NumericOperationDescription::Mean(desc) => { + unary_float_ops!(handles, desc, B::float_mean) + } + NumericOperationDescription::Sum(desc) => { + unary_float_ops!(handles, desc, B::float_sum) + } + NumericOperationDescription::SumDim(desc) => { + scalar_float_dim_ops!(handles, desc, B::float_sum_dim) + } + NumericOperationDescription::Prod(desc) => { + unary_float_ops!(handles, desc, B::float_prod) + } + NumericOperationDescription::ProdDim(desc) => { + scalar_float_dim_ops!(handles, desc, B::float_prod_dim) + } + NumericOperationDescription::EqualElem(desc) => { + scalar_float_cmp_ops!(handles, desc, B::float_equal_elem) + } + NumericOperationDescription::Greater(desc) => { + binary_float_cmp_ops!(handles, desc, B::float_greater) + } + NumericOperationDescription::GreaterElem(desc) => { + scalar_float_cmp_ops!(handles, desc, B::float_greater_elem) + } + NumericOperationDescription::GreaterEqual(desc) => { + binary_float_cmp_ops!(handles, desc, B::float_greater_equal) + } + NumericOperationDescription::GreaterEqualElem(desc) => { + scalar_float_cmp_ops!(handles, desc, B::float_greater_equal_elem) + } + NumericOperationDescription::Lower(desc) => { + binary_float_cmp_ops!(handles, desc, B::float_lower) + } + NumericOperationDescription::LowerElem(desc) => { + scalar_float_cmp_ops!(handles, desc, B::float_lower_elem) + } + NumericOperationDescription::LowerEqual(desc) => { + binary_float_cmp_ops!(handles, desc, B::float_lower_equal) + } + NumericOperationDescription::LowerEqualElem(desc) => { + 
scalar_float_cmp_ops!(handles, desc, B::float_lower_equal_elem) + } + NumericOperationDescription::ArgMax(desc) => { + scalar_float2int_ops!(handles, desc, B::float_argmax) + } + NumericOperationDescription::ArgMin(desc) => { + scalar_float2int_ops!(handles, desc, B::float_argmin) + } + NumericOperationDescription::Max(desc) => { + unary_float_ops!(handles, desc, B::float_max) + } + NumericOperationDescription::MaxDimWithIndices(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + + let (output, output_idx) = B::float_max_dim_with_indices(tensor, desc.dim); + handles.register_float_tensor::(&desc.out.id, output); + handles.register_int_tensor::(&desc.out_indices.id, output_idx); + } + NumericOperationDescription::MinDimWithIndices(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + + let (output, output_idx) = B::float_min_dim_with_indices(tensor, desc.dim); + handles.register_float_tensor::(&desc.out.id, output); + handles.register_int_tensor::(&desc.out_indices.id, output_idx); + } + NumericOperationDescription::Min(desc) => { + unary_float_ops!(handles, desc, B::float_min) + } + NumericOperationDescription::MaxDim(desc) => { + scalar_float_dim_ops!(handles, desc, B::float_max_dim) + } + NumericOperationDescription::MinDim(desc) => { + scalar_float_dim_ops!(handles, desc, B::float_min_dim) + } + NumericOperationDescription::Clamp(desc) => { + let tensor = handles.get_float_tensor::(&desc.tensor); + + let output = B::float_clamp(tensor, desc.min.elem(), desc.max.elem()); + handles.register_float_tensor::(&desc.out.id, output); + } + NumericOperationDescription::IntRandom(_) => unreachable!(), + NumericOperationDescription::Powf(desc) => { + binary_float_ops!(handles, desc, B::float_powf) + } + }, + OperationDescription::NumericInt(_dtype, op) => match op { + NumericOperationDescription::Add(desc) => { + binary_int_ops!(handles, desc, B::int_add) + } + NumericOperationDescription::AddScalar(desc) => { + scalar_int_ops!(handles, desc, B::int_add_scalar) + } + NumericOperationDescription::Sub(desc) => { + binary_int_ops!(handles, desc, B::int_sub) + } + NumericOperationDescription::SubScalar(desc) => { + scalar_int_ops!(handles, desc, B::int_sub_scalar) + } + NumericOperationDescription::Div(desc) => { + binary_int_ops!(handles, desc, B::int_div) + } + NumericOperationDescription::DivScalar(desc) => { + scalar_int_ops!(handles, desc, B::int_div_scalar) + } + NumericOperationDescription::RemScalar(desc) => { + scalar_int_ops!(handles, desc, B::int_remainder_scalar) + } + NumericOperationDescription::Mul(desc) => { + binary_int_ops!(handles, desc, B::int_mul) + } + NumericOperationDescription::MulScalar(desc) => { + scalar_int_ops!(handles, desc, B::int_mul_scalar) + } + NumericOperationDescription::Abs(desc) => { + unary_int_ops!(handles, desc, B::int_abs) + } + NumericOperationDescription::Ones(desc) => { + let shape = Shape::from(desc.shape.clone()); + let output = B::int_ones(shape, &self.device); + handles.register_int_tensor::(&desc.id, output); + } + NumericOperationDescription::Zeros(desc) => { + let shape = Shape::from(desc.shape.clone()); + let output = B::int_zeros(shape, &self.device); + handles.register_int_tensor::(&desc.id, output); + } + NumericOperationDescription::Full((desc, elem)) => { + let shape = Shape::from(desc.shape.clone()); + let output = B::int_full(shape, elem.elem(), &self.device); + handles.register_int_tensor::(&desc.id, output); + } + NumericOperationDescription::Gather(desc) => { + let tensor = 
handles.get_int_tensor::(&desc.tensor); + let indices = handles.get_int_tensor::(&desc.indices); + + let output = B::int_gather(desc.dim, tensor, indices); + handles.register_int_tensor::(&desc.out.id, output); + } + NumericOperationDescription::Scatter(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + let indices = handles.get_int_tensor::(&desc.indices); + let value = handles.get_int_tensor::(&desc.value); + + let output = B::int_scatter(desc.dim, tensor, indices, value); + handles.register_int_tensor::(&desc.out.id, output); + } + NumericOperationDescription::Select(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + let indices = handles.get_int_tensor::(&desc.indices); + + let output = B::int_select(tensor, desc.dim, indices); + handles.register_int_tensor::(&desc.out.id, output); + } + NumericOperationDescription::SelectAssign(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + let indices = handles.get_int_tensor::(&desc.indices); + let value = handles.get_int_tensor::(&desc.value); + + let output = B::int_select_assign(tensor, desc.dim, indices, value); + handles.register_int_tensor::(&desc.out.id, output); + } + NumericOperationDescription::MaskWhere(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + let mask = handles.get_bool_tensor::(&desc.mask); + let value = handles.get_int_tensor::(&desc.value); + + let output = B::int_mask_where(tensor, mask, value); + handles.register_int_tensor::(&desc.out.id, output); + } + NumericOperationDescription::MaskFill(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + let mask = handles.get_bool_tensor::(&desc.mask); + + let output = B::int_mask_fill(tensor, mask, desc.value.elem()); + handles.register_int_tensor::(&desc.out.id, output); + } + NumericOperationDescription::MeanDim(desc) => { + scalar_int_dim_ops!(handles, desc, B::int_mean_dim) + } + NumericOperationDescription::Mean(desc) => { + unary_int_ops!(handles, desc, B::int_mean) + } + NumericOperationDescription::Sum(desc) => { + unary_int_ops!(handles, desc, B::int_sum) + } + NumericOperationDescription::SumDim(desc) => { + scalar_int_dim_ops!(handles, desc, B::int_sum_dim) + } + NumericOperationDescription::Prod(desc) => { + unary_int_ops!(handles, desc, B::int_prod) + } + NumericOperationDescription::ProdDim(desc) => { + scalar_int_dim_ops!(handles, desc, B::int_prod_dim) + } + NumericOperationDescription::EqualElem(desc) => { + scalar_int_cmp_ops!(handles, desc, B::int_equal_elem) + } + NumericOperationDescription::Greater(desc) => { + binary_int_cmp_ops!(handles, desc, B::int_greater) + } + NumericOperationDescription::GreaterElem(desc) => { + scalar_int_cmp_ops!(handles, desc, B::int_greater_elem) + } + NumericOperationDescription::GreaterEqual(desc) => { + binary_int_cmp_ops!(handles, desc, B::int_greater_equal) + } + NumericOperationDescription::GreaterEqualElem(desc) => { + scalar_int_cmp_ops!(handles, desc, B::int_greater_equal_elem) + } + NumericOperationDescription::Lower(desc) => { + binary_int_cmp_ops!(handles, desc, B::int_lower) + } + NumericOperationDescription::LowerElem(desc) => { + scalar_int_cmp_ops!(handles, desc, B::int_lower_elem) + } + NumericOperationDescription::LowerEqual(desc) => { + binary_int_cmp_ops!(handles, desc, B::int_lower_equal) + } + NumericOperationDescription::LowerEqualElem(desc) => { + scalar_int_cmp_ops!(handles, desc, B::int_lower_equal_elem) + } + NumericOperationDescription::ArgMax(desc) => { + scalar_int_dim_ops!(handles, desc, B::int_argmax) + } + 
NumericOperationDescription::ArgMin(desc) => { + scalar_int_dim_ops!(handles, desc, B::int_argmin) + } + NumericOperationDescription::Max(desc) => { + unary_int_ops!(handles, desc, B::int_max) + } + NumericOperationDescription::MaxDimWithIndices(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + + let (output, output_idx) = B::int_max_dim_with_indices(tensor, desc.dim); + handles.register_int_tensor::(&desc.out.id, output); + handles.register_int_tensor::(&desc.out_indices.id, output_idx); + } + NumericOperationDescription::MinDimWithIndices(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + + let (output, output_idx) = B::int_min_dim_with_indices(tensor, desc.dim); + handles.register_int_tensor::(&desc.out.id, output); + handles.register_int_tensor::(&desc.out_indices.id, output_idx); + } + NumericOperationDescription::Min(desc) => { + unary_int_ops!(handles, desc, B::int_min) + } + NumericOperationDescription::MaxDim(desc) => { + scalar_int_dim_ops!(handles, desc, B::int_max_dim) + } + NumericOperationDescription::MinDim(desc) => { + scalar_int_dim_ops!(handles, desc, B::int_min_dim) + } + NumericOperationDescription::Clamp(desc) => { + let tensor = handles.get_int_tensor::(&desc.tensor); + + let output = B::int_clamp(tensor, desc.min.elem(), desc.max.elem()); + handles.register_int_tensor::(&desc.out.id, output); + } + NumericOperationDescription::IntRandom(desc) => { + let shape = Shape::from(desc.out.shape.clone()); + + let output = B::int_random(shape, desc.distribution, &self.device); + handles.register_int_tensor::(&desc.out.id, output); + } + NumericOperationDescription::Powf(desc) => { + let lhs = handles.get_int_tensor::(&desc.lhs); + let rhs = handles.get_float_tensor::(&desc.rhs); + + let output = B::int_powf(lhs, rhs); + handles.register_int_tensor::(&desc.out.id, output); + } + }, + OperationDescription::Bool(op) => match op { + BoolOperationDescription::IntoFloat(desc) => { + let tensor = handles.get_bool_tensor::(&desc.input); + + let output = B::bool_into_float(tensor); + handles.register_float_tensor::(&desc.out.id, output); + } + BoolOperationDescription::IntoInt(desc) => { + let tensor = handles.get_bool_tensor::(&desc.input); + + let output = B::bool_into_int(tensor); + handles.register_int_tensor::(&desc.out.id, output); + } + BoolOperationDescription::Not(desc) => { + let tensor = handles.get_bool_tensor::(&desc.input); + + let output = B::bool_not(tensor); + handles.register_bool_tensor::(&desc.out.id, output); + } + }, + OperationDescription::Int(op) => match op { + IntOperationDescription::IntoFloat(desc) => { + let tensor = handles.get_int_tensor::(&desc.input); + + let output = B::int_into_float(tensor); + handles.register_float_tensor::(&desc.out.id, output); + } + }, + OperationDescription::Float(_dtype, op) => match op { + FloatOperationDescription::Exp(desc) => { + unary_float_ops!(handles, desc, B::float_exp) + } + FloatOperationDescription::Log(desc) => { + unary_float_ops!(handles, desc, B::float_log) + } + FloatOperationDescription::Log1p(desc) => { + unary_float_ops!(handles, desc, B::float_log1p) + } + FloatOperationDescription::Erf(desc) => { + unary_float_ops!(handles, desc, B::float_erf) + } + FloatOperationDescription::PowfScalar(desc) => { + scalar_float_ops!(handles, desc, B::float_powf_scalar) + } + FloatOperationDescription::Sqrt(desc) => { + unary_float_ops!(handles, desc, B::float_sqrt) + } + FloatOperationDescription::Cos(desc) => { + unary_float_ops!(handles, desc, B::float_cos) + } + 
FloatOperationDescription::Sin(desc) => { + unary_float_ops!(handles, desc, B::float_sin) + } + FloatOperationDescription::Tanh(desc) => { + unary_float_ops!(handles, desc, B::float_tanh) + } + FloatOperationDescription::IntoInt(desc) => { + let tensor = handles.get_float_tensor::(&desc.input); + + let output = B::float_into_int(tensor); + handles.register_int_tensor::(&desc.out.id, output); + } + FloatOperationDescription::Matmul(desc) => { + binary_float_ops!(handles, desc, B::float_matmul) + } + FloatOperationDescription::Random(desc) => { + let shape = Shape::from(desc.out.shape.clone()); + + let output = B::float_random(shape, desc.distribution, &self.device); + handles.register_float_tensor::(&desc.out.id, output); + } + FloatOperationDescription::Recip(desc) => { + unary_float_ops!(handles, desc, B::float_recip) + } + FloatOperationDescription::Quantize(_) => todo!(), + FloatOperationDescription::Dequantize(_) => todo!(), + }, + OperationDescription::Module(op) => match op { + ModuleOperationDescription::Embedding(desc) => { + let weights = handles.get_float_tensor::(&desc.weights); + let indices = handles.get_int_tensor::(&desc.indices); + + let output = B::embedding(weights, indices); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::EmbeddingBackward(desc) => { + let weights = handles.get_float_tensor::(&desc.weights); + let indices = handles.get_int_tensor::(&desc.indices); + let output_grad = handles.get_float_tensor::(&desc.out_grad); + + let output = B::embedding_backward(weights, output_grad, indices); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::Conv1d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let weight = handles.get_float_tensor::(&desc.weight); + let bias = desc + .bias + .as_ref() + .map(|bias| handles.get_float_tensor::(bias)); + + let output = B::conv1d(x, weight, bias, desc.clone().options.into()); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::Conv2d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let weight = handles.get_float_tensor::(&desc.weight); + let bias = desc + .bias + .as_ref() + .map(|bias| handles.get_float_tensor::(bias)); + + let output = B::conv2d(x, weight, bias, desc.clone().options.into()); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::Conv3d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let weight = handles.get_float_tensor::(&desc.weight); + let bias = desc + .bias + .as_ref() + .map(|bias| handles.get_float_tensor::(bias)); + + let output = B::conv3d(x, weight, bias, desc.options.clone().into()); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::DeformableConv2d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let offset = handles.get_float_tensor::(&desc.offset); + let mask = desc + .mask + .as_ref() + .map(|mask| handles.get_float_tensor::(mask)); + let weight = handles.get_float_tensor::(&desc.weight); + let bias = desc + .bias + .as_ref() + .map(|bias| handles.get_float_tensor::(bias)); + + let output = B::deform_conv2d( + x, + offset, + weight, + mask, + bias, + desc.options.clone().into(), + ); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::DeformableConv2dBackward(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let offset = handles.get_float_tensor::(&desc.offset); + let mask = desc + .mask + .as_ref() + .map(|mask| 
handles.get_float_tensor::(mask)); + let weight = handles.get_float_tensor::(&desc.weight); + let bias = desc + .bias + .as_ref() + .map(|bias| handles.get_float_tensor::(bias)); + let output_grad = handles.get_float_tensor::(&desc.out_grad); + + let output = B::deform_conv2d_backward( + x, + offset, + weight, + mask, + bias, + output_grad, + desc.options.clone().into(), + ); + + handles.register_float_tensor::(&desc.input_grad.id, output.x_grad); + handles.register_float_tensor::(&desc.offset_grad.id, output.offset_grad); + handles.register_float_tensor::(&desc.weight_grad.id, output.weight_grad); + if let Some((mask_grad, field)) = output.mask_grad.zip(desc.mask_grad.as_ref()) + { + handles.register_float_tensor::(&field.id, mask_grad); + } + if let Some((bias_grad, field)) = output.bias_grad.zip(desc.bias_grad.as_ref()) + { + handles.register_float_tensor::(&field.id, bias_grad); + } + } + ModuleOperationDescription::ConvTranspose1d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let weight = handles.get_float_tensor::(&desc.weight); + let bias = desc + .bias + .as_ref() + .map(|bias| handles.get_float_tensor::(bias)); + + let output = B::conv_transpose1d(x, weight, bias, desc.options.clone().into()); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::ConvTranspose2d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let weight = handles.get_float_tensor::(&desc.weight); + let bias = desc + .bias + .as_ref() + .map(|bias| handles.get_float_tensor::(bias)); + + let output = B::conv_transpose2d(x, weight, bias, desc.options.clone().into()); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::ConvTranspose3d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let weight = handles.get_float_tensor::(&desc.weight); + let bias = desc + .bias + .as_ref() + .map(|bias| handles.get_float_tensor::(bias)); + + let output = B::conv_transpose3d(x, weight, bias, desc.options.clone().into()); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::AvgPool1d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + + let output = B::avg_pool1d( + x, + desc.kernel_size, + desc.stride, + desc.padding, + desc.count_include_pad, + ); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::AvgPool2d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + + let output = B::avg_pool2d( + x, + desc.kernel_size, + desc.stride, + desc.padding, + desc.count_include_pad, + ); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::AvgPool1dBackward(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let grad = handles.get_float_tensor::(&desc.grad); + + let output = B::avg_pool1d_backward( + x, + grad, + desc.kernel_size, + desc.stride, + desc.padding, + desc.count_include_pad, + ); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::AvgPool2dBackward(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let grad = handles.get_float_tensor::(&desc.grad); + + let output = B::avg_pool2d_backward( + x, + grad, + desc.kernel_size, + desc.stride, + desc.padding, + desc.count_include_pad, + ); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::AdaptiveAvgPool1d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + + let output = B::adaptive_avg_pool1d(x, desc.output_size); + 
handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::AdaptiveAvgPool2d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + + let output = B::adaptive_avg_pool2d(x, desc.output_size); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let grad = handles.get_float_tensor::(&desc.grad); + + let output = B::adaptive_avg_pool1d_backward(x, grad); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let grad = handles.get_float_tensor::(&desc.grad); + + let output = B::adaptive_avg_pool2d_backward(x, grad); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::MaxPool1d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + + let output = B::max_pool1d( + x, + desc.kernel_size, + desc.stride, + desc.padding, + desc.dilation, + ); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::MaxPool1dWithIndices(desc) => { + let x = handles.get_float_tensor::(&desc.x); + + let output = B::max_pool1d_with_indices( + x, + desc.kernel_size, + desc.stride, + desc.padding, + desc.dilation, + ); + handles.register_float_tensor::(&desc.out.id, output.output); + handles.register_int_tensor::(&desc.out_indices.id, output.indices); + } + ModuleOperationDescription::MaxPool1dWithIndicesBackward(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let output_grad = handles.get_float_tensor::(&desc.grad); + let indices = handles.get_int_tensor::(&desc.indices); + + let output = B::max_pool1d_with_indices_backward( + x, + desc.kernel_size, + desc.stride, + desc.padding, + desc.dilation, + output_grad, + indices, + ); + handles.register_float_tensor::(&desc.out.id, output.x_grad); + } + ModuleOperationDescription::MaxPool2d(desc) => { + let x = handles.get_float_tensor::(&desc.x); + + let output = B::max_pool2d( + x, + desc.kernel_size, + desc.stride, + desc.padding, + desc.dilation, + ); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::MaxPool2dWithIndices(desc) => { + let x = handles.get_float_tensor::(&desc.x); + + let output = B::max_pool2d_with_indices( + x, + desc.kernel_size, + desc.stride, + desc.padding, + desc.dilation, + ); + handles.register_float_tensor::(&desc.out.id, output.output); + handles.register_int_tensor::(&desc.out_indices.id, output.indices); + } + ModuleOperationDescription::MaxPool2dWithIndicesBackward(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let output_grad = handles.get_float_tensor::(&desc.grad); + let indices = handles.get_int_tensor::(&desc.indices); + + let output = B::max_pool2d_with_indices_backward( + x, + desc.kernel_size, + desc.stride, + desc.padding, + desc.dilation, + output_grad, + indices, + ); + handles.register_float_tensor::(&desc.out.id, output.x_grad); + } + ModuleOperationDescription::Interpolate(desc) => { + let x = handles.get_float_tensor::(&desc.x); + + let output = B::interpolate(x, desc.output_size, desc.options.clone().into()); + handles.register_float_tensor::(&desc.out.id, output); + } + ModuleOperationDescription::InterpolateBackward(desc) => { + let x = handles.get_float_tensor::(&desc.x); + let grad = handles.get_float_tensor::(&desc.grad); + + let output = B::interpolate_backward( + x, + grad, + desc.output_size, + 
desc.options.clone().into(), + ); + handles.register_float_tensor::(&desc.out.id, output); + } + }, + } + } + + async fn read_tensor(&self, tensor: TensorDescription) -> TensorData { + let mut ctx = self.context.lock(); + + if tensor.dtype.is_float() { + let tensor = ctx.handles.get_float_tensor::(&tensor); + B::float_into_data(tensor).await + } else if tensor.dtype.is_int() { + let tensor = ctx.handles.get_int_tensor::(&tensor); + B::int_into_data(tensor).await + } else if tensor.dtype.is_bool() { + let tensor = ctx.handles.get_bool_tensor::(&tensor); + B::bool_into_data(tensor).await + } else if let DType::QFloat(_) = tensor.dtype { + todo!() + } else { + unimplemented!() + } + } + + fn register_tensor_data(&self, data: TensorData) -> RouterTensor { + let desc = self.register_tensor_data_desc(data); + RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) + } + + fn register_empty_tensor(&self, shape: Vec, dtype: DType) -> RouterTensor { + let desc = self.register_empty_tensor_desc(shape, dtype); + RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) + } + + fn register_float_tensor(&self, shape: Vec, full_precision: bool) -> RouterTensor { + let desc = self.register_float_tensor_desc(shape, full_precision); + RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone()) + } + + fn device(&self) -> Self::Device { + self.device.clone() + } + + fn register_orphan(&self, id: &TensorId) { + self.context.lock().drop_tensor_handle(*id) + } + + fn sync(&self) { + B::sync(&self.device); + } + + fn seed(&self, seed: u64) { + B::seed(seed) + } +} diff --git a/crates/burn-router/src/tensor.rs b/crates/burn-router/src/tensor.rs new file mode 100644 index 000000000..c53f2069b --- /dev/null +++ b/crates/burn-router/src/tensor.rs @@ -0,0 +1,122 @@ +use alloc::{sync::Arc, vec::Vec}; + +use super::RunnerClient; +use burn_tensor::{ + repr::{TensorDescription, TensorId, TensorStatus}, + DType, Shape, TensorData, +}; + +/// Tensor primitive for the [router backend](crate::BackendRouter). +pub struct RouterTensor { + pub(crate) id: Arc, + pub(crate) shape: Vec, + pub(crate) dtype: DType, + pub(crate) client: C, + + // Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`. + // + // When a tensor is dropped and is still an orphan, we need to register it as such to avoid + // memory leak. 
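+ // In practice, the flag starts out `true` and `into_description` clears it when it consumes a
+ // tensor whose status is `ReadWrite`; the `Drop` impl below reports ids that are still orphaned
+ // (and writable) to the client so the underlying handle can be freed.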
+ pub(crate) is_orphan: bool, +} + +impl RouterTensor { + pub(crate) fn new(id: Arc, shape: Vec, dtype: DType, client: C) -> Self { + Self { + id, + shape, + dtype, + client, + is_orphan: true, + } + } + + pub(crate) async fn into_data(self) -> TensorData { + self.client + .clone() + .read_tensor(self.into_description()) + .await + } + + pub(crate) fn into_description(mut self) -> TensorDescription { + let status = self.status(); + let mut shape_out = Vec::new(); + core::mem::swap(&mut self.shape, &mut shape_out); + + if let TensorStatus::ReadWrite = status { + self.is_orphan = false; + } + + TensorDescription { + status, + shape: shape_out, + id: *self.id.as_ref(), + dtype: self.dtype, + } + } + + pub(crate) fn to_description_out(&self) -> TensorDescription { + TensorDescription { + status: TensorStatus::NotInit, + shape: self.shape.clone(), + id: *self.id.as_ref(), + dtype: self.dtype, + } + } + + pub(crate) fn shape(&self) -> Shape { + Shape::from(self.shape.clone()) + } + + pub(crate) fn status(&self) -> TensorStatus { + if Arc::strong_count(&self.id) <= 1 { + TensorStatus::ReadWrite + } else { + TensorStatus::ReadOnly + } + } +} + +impl core::fmt::Debug for RouterTensor { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str( + format!( + "{{ id: {:?}, shape: {:?}, dtype: {:?}, should_drop: {:?}, device: {:?} }}", + self.id, + self.shape, + self.dtype, + self.is_orphan, + self.client.device().clone(), + ) + .as_str(), + ) + } +} + +impl Clone for RouterTensor { + fn clone(&self) -> Self { + Self { + id: self.id.clone(), + shape: self.shape.clone(), + client: self.client.clone(), + dtype: self.dtype, + is_orphan: self.is_orphan, + } + } +} + +impl Drop for RouterTensor { + fn drop(&mut self) { + if !self.is_orphan { + return; + } + + match self.status() { + TensorStatus::ReadWrite => { + self.client.register_orphan(&self.id); + } + TensorStatus::ReadOnly => {} + TensorStatus::NotInit => {} + } + } +} diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 42e53b9e1..ff6d36da0 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -49,6 +49,9 @@ hashbrown = { workspace = true } # no_std compatible serde = { workspace = true } serde_bytes = { workspace = true } +[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] +portable-atomic-util = { workspace = true } + [dev-dependencies] rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std diff --git a/crates/burn-tensor/src/repr/backend.rs b/crates/burn-tensor/src/repr/backend.rs index 9853a2eaf..dad341697 100644 --- a/crates/burn-tensor/src/repr/backend.rs +++ b/crates/burn-tensor/src/repr/backend.rs @@ -4,16 +4,27 @@ use crate::{ quantization::QuantizationScheme, Shape, }; -use alloc::vec::Vec; /// A tensor representation containing a reference to a tensor resource with a given shape. -pub struct TensorHandle { +#[derive(Clone)] +pub struct TensorHandle { /// The type that can be used to point to a tensor of any kind. pub handle: H, /// The shape associated to the tensor. pub shape: Shape, } +/// A simple struct to encapsulate a quantized tensor kind. +#[derive(Clone)] +pub struct QuantizedKind { + /// The quantized tensor. + pub tensor: T, + /// The scaling factor. + pub scale: T, + /// The zero-point offset. + pub offset: Option, +} + /// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation /// for compilation purpose or other... 
pub trait ReprBackend: Backend { @@ -28,7 +39,7 @@ pub trait ReprBackend: Backend { fn bool_tensor(handle: TensorHandle) -> BoolTensor; /// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive). fn quantized_tensor( - handles: Vec>, + handle: QuantizedKind>, scheme: QuantizationScheme, ) -> QuantizedTensor; @@ -40,5 +51,33 @@ pub trait ReprBackend: Backend { fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle; /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle). /// A quantized tensor has multiple handles for the tensor itself and the quantization parameters. - fn quantized_tensor_handle(tensor: QuantizedTensor) -> Vec; + fn quantized_tensor_handle(tensor: QuantizedTensor) -> QuantizedKind; +} + +/// Handle which points to a backend tensor primitive kind. +#[derive(Clone, Debug)] +pub enum HandleKind { + /// Float tensor handle. + Float(B::FloatTensorPrimitive), + /// Int tensor handle. + Int(B::IntTensorPrimitive), + /// Bool tensor handle. + Bool(B::BoolTensorPrimitive), + /// Quantized tensor handle. + Quantized(B::QuantizedTensorPrimitive), + /// Empty handle (used as a dummy representation). + Empty, +} + +impl HandleKind { + /// Returns the handle kind name. + pub fn name(&self) -> &str { + match self { + HandleKind::Float(_) => "float", + HandleKind::Int(_) => "int", + HandleKind::Bool(_) => "bool", + HandleKind::Quantized(_) => "quantized", + HandleKind::Empty => unreachable!(), // should not happen + } + } } diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index 88b4b8934..f0af17a9b 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -5,9 +5,16 @@ use crate::{ }, Shape, }; -use std::{collections::HashMap, sync::Arc}; +use alloc::vec::Vec; +use hashbrown::HashMap; -use super::{QuantizedTensorDescription, TensorHandle}; +#[cfg(target_has_atomic = "ptr")] +use alloc::sync::Arc; + +#[cfg(not(target_has_atomic = "ptr"))] +use portable_atomic_util::Arc; + +use super::{QuantizedKind, QuantizedTensorDescription, TensorHandle}; /// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources /// are used optimally. @@ -19,6 +26,16 @@ pub struct HandleContainer { pub handles_orphan: Vec, } +impl core::fmt::Debug for HandleContainer { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("HandleContainer") + .field("handles", &self.handles.keys()) // only care about the IDs when debugging + .field("counter", &self.counter) + .field("handles_orphan", &self.handles_orphan) + .finish() + } +} + /// Backend [tensor handle](ReprBackend::Handle) wrapper tracking their creation state pub enum Handle { /// No [tensor handle](ReprBackend::Handle) has been created yet @@ -69,7 +86,7 @@ impl HandleContainer { } /// Get the tensor handle for the given [tensor description](TensorDescription). 
- fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle { + pub fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle { TensorHandle { handle: self.get_handle(&tensor.id, &tensor.status), shape: Shape::from(&tensor.shape), @@ -112,12 +129,14 @@ impl HandleContainer { where B: ReprBackend, { - let qtensor = self.get_tensor_handle(&tensor.tensor); - let scale = self.get_tensor_handle(&tensor.qparams.scale); - let handles = if let Some(offset) = &tensor.qparams.offset { - vec![qtensor, scale, self.get_tensor_handle(offset)] - } else { - vec![qtensor, scale] + let handles = QuantizedKind { + tensor: self.get_tensor_handle(&tensor.tensor), + scale: self.get_tensor_handle(&tensor.qparams.scale), + offset: tensor + .qparams + .offset + .as_ref() + .map(|offset| self.get_tensor_handle(offset)), }; B::quantized_tensor(handles, tensor.scheme.clone()) } @@ -134,20 +153,20 @@ impl HandleContainer { /// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId). pub fn register_quantized_tensor( &mut self, - ids: &[&TensorId], + id: &QuantizedKind, tensor: B::QuantizedTensorPrimitive, ) where B: ReprBackend, { let handles = B::quantized_tensor_handle(tensor); - assert_eq!( - ids.len(), - handles.len(), - "Number of tensor ids and handles must match" - ); - for (handle, id) in handles.into_iter().zip(ids) { - self.handles.insert(**id, Handle::Existing(handle)); + self.handles + .insert(id.tensor, Handle::Existing(handles.tensor)); + self.handles + .insert(id.scale, Handle::Existing(handles.scale)); + + if let (Some(id), Some(handle)) = (id.offset, handles.offset) { + self.handles.insert(id, Handle::Existing(handle)); } } diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 4e5148bdd..9102b4ac7 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -1,5 +1,8 @@ +use core::ops::Range; use serde::{Deserialize, Serialize}; -use std::ops::Range; + +use alloc::boxed::Box; +use alloc::{vec, vec::Vec}; use crate::{ ops::{ @@ -214,6 +217,13 @@ pub enum BaseOperationDescription { Cat(CatOperationDescription), /// Cast operation, no direct operation and should be supported by fusion backend. Cast(UnaryOperationDescription), + + /// Operation corresponding to: + /// + /// Float => [equal](crate::ops::FloatTensorOps::float_empty). + /// Int => [equal](crate::ops::IntTensorOps::int_empty). + /// Bool => [equal](crate::ops::BoolTensorOps::bool_empty). + Empty(TensorDescription), } /// Numeric operations on int and float tensors. 
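The handle-container changes above replace the old positional list of quantized handles with the named `QuantizedKind` struct, so the packed tensor, the scale, and the optional zero-point offset are matched by field instead of by index. The following is a minimal sketch of registering a quantized tensor under the new API; it only relies on the `burn_tensor::repr` items shown in this patch, and the helper name `register_quantized` is illustrative rather than part of the change.

    use burn_tensor::repr::{HandleContainer, QuantizedKind, ReprBackend, TensorId};

    // Hypothetical helper: allocate one id per handle of a quantized tensor and
    // store the primitive in the container under those ids.
    fn register_quantized<B: ReprBackend>(
        handles: &mut HandleContainer<B::Handle>,
        tensor: B::QuantizedTensorPrimitive,
    ) -> QuantizedKind<TensorId> {
        let ids = QuantizedKind {
            tensor: *handles.create_tensor_uninit(),
            scale: *handles.create_tensor_uninit(),
            // Per-tensor affine schemes carry a zero-point offset; a symmetric scheme would use `None`.
            offset: Some(*handles.create_tensor_uninit()),
        };
        // The container splits the primitive into per-field handles and inserts each one
        // under its matching id; a `None` offset on either side is simply skipped.
        handles.register_quantized_tensor::<B>(&ids, tensor);
        ids
    }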
@@ -1286,6 +1296,7 @@ impl BaseOperationDescription { } BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(), BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out], + BaseOperationDescription::Empty(desc) => vec![desc], } } } @@ -1605,7 +1616,7 @@ impl ModuleOperationDescription { } impl core::hash::Hash for RandomOperationDescription { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.out.hash(state); match self.distribution { @@ -1618,14 +1629,14 @@ impl core::hash::Hash for RandomOperationDescription { } impl core::hash::Hash for ScalarOperationDescription { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.lhs.hash(state); self.out.hash(state); } } impl core::hash::Hash for MaskFillOperationDescription { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.tensor.hash(state); self.mask.hash(state); self.out.hash(state); @@ -1633,14 +1644,14 @@ impl core::hash::Hash for MaskFillOperationDescription { } impl core::hash::Hash for ClampOperationDescription { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.tensor.hash(state); self.out.hash(state); } } impl core::hash::Hash for NumericOperationDescription { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { match self { NumericOperationDescription::Add(desc) => desc.hash(state), NumericOperationDescription::AddScalar(desc) => desc.hash(state), diff --git a/crates/burn-tensor/src/repr/tensor.rs b/crates/burn-tensor/src/repr/tensor.rs index a68d6b9c2..eda9007f7 100644 --- a/crates/burn-tensor/src/repr/tensor.rs +++ b/crates/burn-tensor/src/repr/tensor.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use alloc::vec::Vec; + use crate::DType; /// The tensor unique identifier. diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index fecfcd68e..5f13bb52c 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -59,8 +59,8 @@ pub trait Backend: + ActivationOps + QTensorOps + Clone - + Sized + Default + + Sized + Send + Sync + core::fmt::Debug diff --git a/crates/burn-tensor/src/tensor/element/base.rs b/crates/burn-tensor/src/tensor/element/base.rs index 17382be91..f7b062d2d 100644 --- a/crates/burn-tensor/src/tensor/element/base.rs +++ b/crates/burn-tensor/src/tensor/element/base.rs @@ -258,3 +258,19 @@ pub enum DType { Bool, QFloat(QuantizationStrategy), } + +impl DType { + /// Returns true if the data type is a floating point type. + pub fn is_float(&self) -> bool { + matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16) + } + /// Returns true if the data type is a signed integer type. + pub fn is_int(&self) -> bool { + matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8) + } + + /// Returns true if the data type is a boolean type + pub fn is_bool(&self) -> bool { + matches!(self, DType::Bool) + } +} diff --git a/crates/burn-tensor/src/tensor/ops/binary.rs b/crates/burn-tensor/src/tensor/ops/binary.rs new file mode 100644 index 000000000..d89766773 --- /dev/null +++ b/crates/burn-tensor/src/tensor/ops/binary.rs @@ -0,0 +1,12 @@ +use alloc::vec::Vec; + +/// Computes the output shape for binary operations with broadcasting support. 
+pub fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
+    let mut shape_out = Vec::with_capacity(lhs.len());
+
+    for (l, r) in lhs.iter().zip(rhs.iter()) {
+        shape_out.push(usize::max(*l, *r));
+    }
+
+    shape_out
+}
diff --git a/crates/burn-tensor/src/tensor/ops/mod.rs b/crates/burn-tensor/src/tensor/ops/mod.rs
index 1cce56258..79a1b4f61 100644
--- a/crates/burn-tensor/src/tensor/ops/mod.rs
+++ b/crates/burn-tensor/src/tensor/ops/mod.rs
@@ -1,5 +1,6 @@
 mod activation;
 mod alias;
+mod binary;
 mod bool_tensor;
 mod int_tensor;
 mod modules;
@@ -8,6 +9,7 @@ mod tensor;
 pub use activation::*;
 pub use alias::*;
+pub use binary::*;
 pub use bool_tensor::*;
 pub use int_tensor::*;
 pub use modules::*;
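To make the broadcasting rule implemented by `binary_ops_shape` above concrete: the function simply takes the element-wise maximum of the two shapes, so callers are expected to pass already-compatible shapes of equal rank (each dimension pair either matches or contains a 1). A small usage sketch:

    use burn_tensor::ops::binary_ops_shape;

    fn main() {
        // [2, 1, 4] broadcast against [2, 3, 1] gives [2, 3, 4].
        let out = binary_ops_shape(&[2, 1, 4], &[2, 3, 1]);
        assert_eq!(out, vec![2, 3, 4]);
    }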
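The `HandleKind` enum added in `repr/backend.rs` gives backends without a dedicated handle type a simple tagged wrapper around their tensor primitives. Below is a hedged sketch of how a `ReprBackend` implementation might unwrap it when turning a handle back into a float tensor; the free function is illustrative and not part of this patch, and it assumes the dimension-erased primitive types used elsewhere in the diff.

    use burn_tensor::{
        backend::Backend,
        repr::{HandleKind, TensorHandle},
    };

    // Extract the float primitive from a handle, panicking with the stored kind
    // name if the handle actually holds a different tensor kind.
    fn float_tensor<B: Backend>(handle: TensorHandle<HandleKind<B>>) -> B::FloatTensorPrimitive {
        match handle.handle {
            HandleKind::Float(tensor) => tensor,
            other => panic!("Expected float handle, got {}", other.name()),
        }
    }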