Add `BackendRouter` to handle multiple backends (#2353)

* WIP

* it compiles

* WIP

* Remove const D (w/ rebase) + WIP BackendRouter

* Add missing types from merge

* Rework traits, types and add MultiBackendBridge & RunnerClientLocator (WIP)

* First draft ByteBridge to_backend(tensor, device)

* Refactor into modules

* Add mutex, fix types

* Remove StreamId and implement ReprBackend for Fusion (WIP)

* float_add op working (w/o fusion)

* Small cleanup

* Remove comment

* Cleanup

* Fix fusion ReprBackend implementation (duhhh)

* Add runner ops

* More ops

* Cleanup

* Add name

* Update Cargo.lock

* Undo fusion stream changes to common

* Clippy + cleanup

* Fix no-std

* Deal with unused tensors

* Clippy baby

* Fix comment

* Fix tensor handle orphans management

* Implement runner read_tensor for other dtypes

* Move backend router to its own crate

* Refactor repr quantized tensor handle

* Fix typo

* Implement repr backend for ndarray

* Add router tests w/ ndarray and wgpu backends (+ fix tests)

* Add empty base operation and autodiff tests

* Add precision bridge

* Remove dep from local changes

* Apply same mask_where broadcast fix

* Add float and int elem associative types (should match for each backend)

* Add simple byte bridge test

* Remove comment

* Remove bridge types

* Screw you windows

* Remove dead code

* Add unreachable message

* Add seed

* Fix clippy

* Set the seed anytime a new client is initialized

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
Guillaume Lagrange, 2024-10-18 14:23:26 -04:00, committed by GitHub
parent 604dbae57d
commit eaf50e617c
60 changed files with 6826 additions and 286 deletions

Cargo.lock (generated)

@ -706,6 +706,18 @@ dependencies = [
"serde",
]
[[package]]
name = "burn-router"
version = "0.15.0"
dependencies = [
"burn-autodiff",
"burn-ndarray",
"burn-tensor",
"burn-wgpu",
"hashbrown 0.14.5",
"spin",
]
[[package]]
name = "burn-tch"
version = "0.15.0"
@ -732,6 +744,7 @@ dependencies = [
"half",
"hashbrown 0.14.5",
"num-traits",
"portable-atomic-util",
"rand",
"rand_distr",
"serde",


@ -4,8 +4,8 @@ use crate::{
};
use burn_tensor::{
backend::{Backend, DeviceOps},
ops::FloatTensor,
repr::{OperationDescription, ReprBackend},
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle},
Device,
};
use serde::{de::DeserializeOwned, Serialize};
@ -166,3 +166,43 @@ pub trait FusionBackend:
/// Pointer to the full precision fusion backend.
type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
}
// Fusion implements `ReprBackend` to enable router backend usage.
impl<B: FusionBackend> ReprBackend for Fusion<B> {
type Handle = FusionTensor<B::FusionRuntime>;
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
handle.handle
}
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
handle.handle
}
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
handle.handle
}
fn quantized_tensor(
_handles: QuantizedKind<TensorHandle<Self::Handle>>,
_scheme: burn_tensor::quantization::QuantizationScheme,
) -> QuantizedTensor<Self> {
todo!() // not as simple
}
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
tensor
}
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
tensor
}
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
tensor
}
fn quantized_tensor_handle(_tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle> {
todo!() // not as simple
}
}
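Fusion forwards its own `FusionTensor` as the untyped handle, so the float/int/bool conversions above are identity round-trips. A minimal sketch of the contract, assuming any backend `B: ReprBackend` (the helper name is illustrative):

use burn_tensor::{ops::FloatTensor, repr::{ReprBackend, TensorHandle}, Shape};

// Wrap a float primitive into an untyped handle, then recover it.
fn float_round_trip<B: ReprBackend>(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B> {
    let handle = B::float_tensor_handle(tensor);
    B::float_tensor(TensorHandle { handle, shape })
}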


@ -73,16 +73,6 @@ macro_rules! binary_int_cmp_ops {
};
}
pub(crate) fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
let mut shape_out = Vec::with_capacity(lhs.len());
for (l, r) in lhs.iter().zip(rhs.iter()) {
shape_out.push(usize::max(*l, *r));
}
shape_out
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_ops {
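`binary_ops_shape` moves out of `burn-fusion` into `burn_tensor::ops` (see the imports below) so the new router crate can reuse it. Each output dimension is the element-wise max of the two input dimensions, matching the broadcast rule; a small sketch, assuming equal-rank shapes:

use burn_tensor::ops::binary_ops_shape;

fn main() {
    // Broadcasting [2, 1, 4] against [1, 3, 4] yields [2, 3, 4].
    assert_eq!(binary_ops_shape(&[2, 1, 4], &[1, 3, 4]), vec![2, 3, 4]);
}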


@ -1,5 +1,5 @@
use burn_tensor::{
ops::{FloatTensor, IntTensor},
ops::{binary_ops_shape, FloatTensor, IntTensor},
DType, Element, TensorData,
};
use std::marker::PhantomData;
@ -7,7 +7,6 @@ use std::marker::PhantomData;
use crate::{
client::FusionClient,
get_client,
ops::binary::binary_ops_shape,
stream::{execution::Operation, StreamId},
Fusion, FusionBackend,
};


@ -1,14 +1,12 @@
use crate::{
binary_float_cmp_ops, binary_float_ops,
client::FusionClient,
get_client,
ops::binary::binary_ops_shape,
scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
get_client, scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
stream::{execution::Operation, StreamId},
unary_float_ops, Fusion, FusionBackend,
};
use burn_tensor::{
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
ops::{binary_ops_shape, BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
repr::*,
DType, Device, Distribution, Element, ElementConversion, Shape, TensorData,
};


@ -1,14 +1,12 @@
use crate::{
binary_int_cmp_ops, binary_int_ops,
client::FusionClient,
get_client,
ops::binary::binary_ops_shape,
scalar_int_cmp_ops, scalar_int_ops,
get_client, scalar_int_cmp_ops, scalar_int_ops,
stream::{execution::Operation, StreamId},
unary_int_ops, Fusion, FusionBackend,
};
use burn_tensor::{
ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
ops::{binary_ops_shape, BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
repr::{self, *},
DType, Device, Distribution, Element, ElementConversion, Shape, TensorData,
};


@ -6,6 +6,7 @@ use burn_tensor::{
repr::{
DequantizeOperationDescription, FloatOperationDescription, HandleContainer,
OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription,
QuantizedKind,
},
DType, Device, Element, Shape, TensorData,
};
@ -25,19 +26,17 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
let tensor = B::q_from_data(data, device);
let shape = B::q_shape(&tensor);
let mut handles = B::quantized_tensor_handle(tensor);
let handles = B::quantized_tensor_handle(tensor);
let qparams = match strategy {
QuantizationStrategy::PerTensorAffineInt8(_) => {
let num_handles = handles.len();
assert_eq!(
num_handles, 3,
"Expected 3 handles for quantized tensor, got {num_handles}"
);
let offset = handles.pop().unwrap();
let scale = handles.pop().unwrap();
let offset = if let Some(offset) = handles.offset {
offset
} else {
panic!("Expected offset for quantized tensor.");
};
FusionQuantizationParameters {
scale: client.register_tensor(
scale,
handles.scale,
vec![1],
StreamId::current(),
B::FloatElem::dtype(),
@ -51,15 +50,13 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
}
}
QuantizationStrategy::PerTensorSymmetricInt8(_) => {
let num_handles = handles.len();
assert_eq!(
num_handles, 2,
"Expected 2 handles for quantized tensor, got {num_handles}"
assert!(
handles.offset.is_none(),
"Offset should not be provided for symmetric quantization."
);
let scale = handles.pop().unwrap();
FusionQuantizationParameters {
scale: client.register_tensor(
scale,
handles.scale,
vec![1],
StreamId::current(),
B::FloatElem::dtype(),
@ -69,7 +66,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
}
};
let qtensor = client.register_tensor(
handles.pop().unwrap(),
handles.tensor,
shape.dims,
StreamId::current(),
B::QuantizedEncoding::dtype(),
@ -111,17 +108,20 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
let qparams = QuantizationParametersPrimitive { scale, offset };
let output = B::quantize(tensor, &self.desc.scheme, qparams);
if let Some(offset) = &self.desc.qparams.offset {
handles.register_quantized_tensor::<B>(
&[&self.desc.out.id, &self.desc.qparams.scale.id, &offset.id],
output,
);
} else {
handles.register_quantized_tensor::<B>(
&[&self.desc.out.id, &self.desc.qparams.scale.id],
output,
);
let q_ids = if let Some(offset) = &self.desc.qparams.offset {
QuantizedKind {
tensor: self.desc.out.id,
scale: self.desc.qparams.scale.id,
offset: Some(offset.id),
}
} else {
QuantizedKind {
tensor: self.desc.out.id,
scale: self.desc.qparams.scale.id,
offset: None,
}
};
handles.register_quantized_tensor::<B>(&q_ids, output);
}
}
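`QuantizedKind` replaces the old positional `Vec` layout `[tensor, scale, <offset>]` with named fields, so the optional offset can no longer be confused with the scale. A hedged sketch of the struct as constructed above (the helper is illustrative):

use burn_tensor::repr::QuantizedKind;

// Group the ids (or handles) of a quantized tensor and its parameters.
fn quantized_ids<H>(tensor: H, scale: H, offset: Option<H>) -> QuantizedKind<H> {
    QuantizedKind { tensor, scale, offset }
}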


@ -3,7 +3,8 @@ use crate::{
FusionBackend, FusionRuntime,
};
use burn_tensor::repr::{
HandleContainer, OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId,
HandleContainer, OperationDescription, QuantizedKind, QuantizedTensorDescription,
TensorDescription, TensorId,
};
use std::sync::Arc;
@ -183,18 +184,28 @@ where
let scale_id = server_device.create_empty_handle();
let offset_id = server_device.create_empty_handle();
let q_ids = QuantizedKind {
tensor: *tensor_id,
scale: *scale_id,
offset: Some(*offset_id),
};
server_device
.handles
.register_quantized_tensor::<B>(&[&tensor_id, &scale_id, &offset_id], tensor);
.register_quantized_tensor::<B>(&q_ids, tensor);
vec![tensor_id, scale_id, offset_id]
} else {
let tensor_id = server_device.create_empty_handle();
let scale_id = server_device.create_empty_handle();
let q_ids = QuantizedKind {
tensor: *tensor_id,
scale: *scale_id,
offset: None,
};
server_device
.handles
.register_quantized_tensor::<B>(&[&tensor_id, &scale_id], tensor);
.register_quantized_tensor::<B>(&q_ids, tensor);
vec![tensor_id, scale_id]
}


@ -968,6 +968,9 @@ impl RelativeOps for BaseOperationDescription {
out: desc.out.to_relative(converter),
})
}
BaseOperationDescription::Empty(desc) => {
BaseOperationDescription::Empty(desc.to_relative(converter))
}
}
}
}


@ -31,7 +31,7 @@ template = []
burn-common = { path = "../burn-common", version = "0.15.0" }
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = [
"cubecl",
"cubecl", "repr",
] }
cubecl = { workspace = true, features = ["linalg"] }


@ -7,6 +7,13 @@ use cubecl::server::ComputeServer;
use rand::{rngs::StdRng, SeedableRng};
use std::{marker::PhantomData, sync::Mutex};
#[cfg(not(feature = "fusion"))]
use burn_tensor::{
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
quantization::QuantizationScheme,
repr::{HandleKind, QuantizedKind, ReprBackend, TensorHandle},
};
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
/// Generic tensor backend that can be compiled just-in-time to any shader runtime
@ -82,3 +89,62 @@ where
type JitDevice = R::Device;
type JitServer = R::Server;
}
#[cfg(not(feature = "fusion"))]
impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R, F, I> {
type Handle = HandleKind<Self>;
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
match handle.handle {
HandleKind::Float(handle) => handle,
_ => panic!("Expected float handle, got {}", handle.handle.name()),
}
}
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
match handle.handle {
HandleKind::Int(handle) => handle,
_ => panic!("Expected int handle, got {}", handle.handle.name()),
}
}
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
match handle.handle {
HandleKind::Bool(handle) => handle,
_ => panic!("Expected bool handle, got {}", handle.handle.name()),
}
}
fn quantized_tensor(
handles: QuantizedKind<TensorHandle<Self::Handle>>,
_scheme: QuantizationScheme,
) -> QuantizedTensor<Self> {
let handle = handles.tensor.handle;
match handle {
HandleKind::Quantized(handle) => handle,
_ => panic!("Expected quantized handle, got {}", handle.name()),
}
}
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
HandleKind::Float(tensor)
}
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
HandleKind::Int(tensor)
}
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
HandleKind::Bool(tensor)
}
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle> {
QuantizedKind {
tensor: HandleKind::Quantized(tensor),
// The quantized tensor primitive already encapsulates the required quantization
// parameters so we set the scale as an empty handle (unused).
scale: HandleKind::Empty,
offset: None,
}
}
}


@ -7,7 +7,7 @@ use crate::{
};
use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
use burn_tensor::quantization::QuantizationScheme;
use burn_tensor::repr::TensorHandle;
use burn_tensor::repr::{QuantizedKind, TensorHandle};
use burn_tensor::{repr::ReprBackend, Shape};
use core::marker::PhantomData;
use cubecl::client::ComputeClient;
@ -79,41 +79,22 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R
}
fn quantized_tensor(
handles: Vec<TensorHandle<Self::Handle>>,
handles: QuantizedKind<TensorHandle<Self::Handle>>,
scheme: QuantizationScheme,
) -> burn_tensor::ops::QuantizedTensor<Self> {
match handles.len() {
// NOTE: the order of the handles is known [qtensor, scale, <offset>]
3 => {
let mut handles = handles;
let offset = handles.pop().unwrap();
let scale = handles.pop().unwrap();
let qtensor = handles.pop().unwrap();
let qtensor = handles.tensor.handle.into_tensor(handles.tensor.shape);
let scale = handles.scale.handle.into_tensor(handles.scale.shape);
let offset = handles.offset;
let qparams = JitQuantizationParameters {
scale,
offset: offset.map(|h| h.handle.into_tensor(h.shape)),
};
QJitTensor {
qtensor: qtensor.handle.into_tensor(qtensor.shape),
qtensor,
scheme,
qparams: JitQuantizationParameters {
scale: scale.handle.into_tensor(scale.shape),
offset: Some(offset.handle.into_tensor(offset.shape)),
},
}
}
2 => {
let mut handles = handles;
let scale = handles.pop().unwrap();
let qtensor = handles.pop().unwrap();
QJitTensor {
qtensor: qtensor.handle.into_tensor(qtensor.shape),
scheme,
qparams: JitQuantizationParameters {
scale: scale.handle.into_tensor(scale.shape),
offset: None,
},
}
}
_ => {
panic!("Expected handles for the quantized tensor and its quantization parameters.")
}
qparams,
}
}
@ -131,14 +112,14 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R
fn quantized_tensor_handle(
tensor: burn_tensor::ops::QuantizedTensor<Self>,
) -> Vec<Self::Handle> {
) -> QuantizedKind<Self::Handle> {
let qtensor: JitFusionHandle<R> = tensor.qtensor.into();
let scale: JitFusionHandle<R> = tensor.qparams.scale.into();
if let Some(offset) = tensor.qparams.offset {
let offset: JitFusionHandle<R> = offset.into();
vec![qtensor, scale, offset]
} else {
vec![qtensor, scale]
QuantizedKind {
tensor: qtensor,
scale,
offset: tensor.qparams.offset.map(|offset| offset.into()),
}
}
}


@ -45,7 +45,7 @@ blas-openblas-system = [
burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", optional = true }
burn-common = { path = "../burn-common", version = "0.15.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false, features = ["repr"] }
atomic_float = { workspace = true }
blas-src = { workspace = true, default-features = false, optional = true } # no-std compatible


@ -1,9 +1,12 @@
use crate::element::{FloatNdArrayElement, QuantElement};
use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
use crate::PrecisionBridge;
use crate::{NdArrayQTensor, NdArrayTensor};
use alloc::string::String;
use burn_common::stub::Mutex;
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use burn_tensor::quantization::QuantizationScheme;
use burn_tensor::repr::{HandleKind, QuantizedKind, ReprBackend, TensorHandle};
use core::marker::PhantomData;
use rand::{rngs::StdRng, SeedableRng};
@ -35,20 +38,21 @@ impl Default for NdArrayDevice {
/// This backend is compatible with CPUs and can be compiled for almost any platform, including
/// `wasm`, `arm`, and `x86`.
#[derive(Clone, Copy, Default, Debug)]
pub struct NdArray<E = f32, Q = i8> {
pub struct NdArray<E = f32, I = i64, Q = i8> {
_e: PhantomData<E>,
_i: PhantomData<I>,
_q: PhantomData<Q>,
}
impl<E: FloatNdArrayElement, Q: QuantElement> Backend for NdArray<E, Q> {
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q> {
type Device = NdArrayDevice;
type FullPrecisionBridge = PrecisionBridge<f32>;
type FloatTensorPrimitive = NdArrayTensor<E>;
type FloatElem = E;
type IntTensorPrimitive = NdArrayTensor<i64>;
type IntElem = i64;
type IntTensorPrimitive = NdArrayTensor<I>;
type IntElem = I;
type BoolTensorPrimitive = NdArrayTensor<bool>;
@ -69,3 +73,63 @@ impl<E: FloatNdArrayElement, Q: QuantElement> Backend for NdArray<E, Q> {
*seed = Some(rng);
}
}
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ReprBackend
for NdArray<E, I, Q>
{
type Handle = HandleKind<Self>;
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
match handle.handle {
HandleKind::Float(handle) => handle,
_ => panic!("Expected float handle, got {}", handle.handle.name()),
}
}
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
match handle.handle {
HandleKind::Int(handle) => handle,
_ => panic!("Expected int handle, got {}", handle.handle.name()),
}
}
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
match handle.handle {
HandleKind::Bool(handle) => handle,
_ => panic!("Expected bool handle, got {}", handle.handle.name()),
}
}
fn quantized_tensor(
handles: QuantizedKind<TensorHandle<Self::Handle>>,
_scheme: QuantizationScheme,
) -> QuantizedTensor<Self> {
let handle = handles.tensor.handle;
match handle {
HandleKind::Quantized(handle) => handle,
_ => panic!("Expected quantized handle, got {}", handle.name()),
}
}
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
HandleKind::Float(tensor)
}
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
HandleKind::Int(tensor)
}
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
HandleKind::Bool(tensor)
}
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle> {
QuantizedKind {
tensor: HandleKind::Quantized(tensor),
// The quantized tensor primitive already encapsulates the required quantization
// parameters so we set the scale as an empty handle (unused).
scale: HandleKind::Empty,
offset: None,
}
}
}


@ -1,4 +1,7 @@
use crate::{element::QuantElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};
use crate::{
element::{IntNdArrayElement, QuantElement},
FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor,
};
use burn_tensor::{backend::BackendBridge, ops::FloatTensor};
use core::marker::PhantomData;
@ -8,13 +11,15 @@ pub struct PrecisionBridge<E: FloatNdArrayElement> {
_e: PhantomData<E>,
}
impl<TElem, OElem, QElem> BackendBridge<NdArray<OElem, QElem>> for PrecisionBridge<TElem>
impl<TElem, OElem, QElem, IntElem> BackendBridge<NdArray<OElem, IntElem, QElem>>
for PrecisionBridge<TElem>
where
TElem: FloatNdArrayElement,
OElem: FloatNdArrayElement,
QElem: QuantElement,
IntElem: IntNdArrayElement,
{
type Target = NdArray<TElem>;
type Target = NdArray<TElem, IntElem, QElem>;
fn into_target(
tensor: FloatTensor<NdArray<OElem>>,


@ -16,6 +16,8 @@ where
{
}
pub trait IntNdArrayElement: NdArrayElement + core::ops::Rem<Output = Self> + Signed {}
/// A general element for ndarray backend.
pub trait NdArrayElement:
Element
@ -49,6 +51,9 @@ impl QuantElement for i8 {}
impl FloatNdArrayElement for f64 {}
impl FloatNdArrayElement for f32 {}
impl IntNdArrayElement for i64 {}
impl IntNdArrayElement for i32 {}
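With `IntNdArrayElement`, the integer element becomes a backend type parameter next to the float and quantized elements. A short sketch of picking `i32` integers (the defaults remain `NdArray<f32, i64, i8>`):

use burn_ndarray::NdArray;

// Float element f32, int element i32, quantized element i8.
type Cpu = NdArray<f32, i32, i8>;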
macro_rules! make_elem {
(
double


@ -1,11 +1,13 @@
use crate::{
element::{FloatNdArrayElement, QuantElement},
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
tensor::NdArrayTensor,
NdArray,
};
use burn_tensor::{ops::ActivationOps, ElementConversion};
impl<E: FloatNdArrayElement, Q: QuantElement> ActivationOps<Self> for NdArray<E, Q> {
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ActivationOps<Self>
for NdArray<E, I, Q>
{
fn relu(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
let zero = 0.elem();
let array = tensor


@ -261,10 +261,10 @@ where
}
}
pub fn gather(
pub fn gather<I: NdArrayElement>(
dim: usize,
mut tensor: NdArrayTensor<E>,
mut indices: NdArrayTensor<i64>,
mut indices: NdArrayTensor<I>,
) -> NdArrayTensor<E> {
let ndims = tensor.shape().num_dims();
if dim != ndims - 1 {
@ -284,7 +284,7 @@ where
let indices = indices.slice(s!(b, ..));
for (i, index) in indices.iter().enumerate() {
output[[b, i]] = tensor[[b, *index as usize]];
output[[b, i]] = tensor[[b, index.elem::<i64>() as usize]];
}
}
@ -300,10 +300,10 @@ where
output
}
pub fn scatter(
pub fn scatter<I: NdArrayElement>(
dim: usize,
mut tensor: NdArrayTensor<E>,
mut indices: NdArrayTensor<i64>,
mut indices: NdArrayTensor<I>,
mut value: NdArrayTensor<E>,
) -> NdArrayTensor<E> {
let ndims = tensor.shape().num_dims();
@ -338,7 +338,7 @@ where
let indices = indices.slice(s!(b, ..));
for (i, index) in indices.iter().enumerate() {
let index = *index as usize;
let index = index.elem::<i64>() as usize;
tensor[[b, index]] += value[[b, i]];
}
}
@ -403,33 +403,33 @@ where
batch_size
}
pub fn select(
pub fn select<I: NdArrayElement>(
tensor: NdArrayTensor<E>,
dim: usize,
indices: NdArrayTensor<i64>,
indices: NdArrayTensor<I>,
) -> NdArrayTensor<E> {
let array = tensor.array.select(
Axis(dim),
&indices
.array
.into_iter()
.map(|i| i as usize)
.map(|i| i.elem::<i64>() as usize)
.collect::<Vec<_>>(),
);
NdArrayTensor::new(array.into_shared())
}
pub fn select_assign(
pub fn select_assign<I: NdArrayElement>(
tensor: NdArrayTensor<E>,
dim: usize,
indices: NdArrayTensor<i64>,
indices: NdArrayTensor<I>,
value: NdArrayTensor<E>,
) -> NdArrayTensor<E> {
let mut output_array = tensor.array.into_owned();
for (index_value, index) in indices.array.into_iter().enumerate() {
let mut view = output_array.index_axis_mut(Axis(dim), index as usize);
let mut view = output_array.index_axis_mut(Axis(dim), index.elem::<i64>() as usize);
let value = value.array.index_axis(Axis(dim), index_value);
view.zip_mut_with(&value, |a, b| *a += *b);
@ -437,11 +437,11 @@ where
NdArrayTensor::new(output_array.into_shared())
}
pub fn argmax(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<i64> {
pub fn argmax<I: NdArrayElement>(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
arg(tensor, dim, CmpType::Max)
}
pub fn argmin(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<i64> {
pub fn argmin<I: NdArrayElement>(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
arg(tensor, dim, CmpType::Min)
}
@ -523,11 +523,11 @@ enum CmpType {
Max,
}
fn arg<E: NdArrayElement>(
fn arg<E: NdArrayElement, I: NdArrayElement>(
tensor: NdArrayTensor<E>,
dim: usize,
cmp: CmpType,
) -> NdArrayTensor<i64> {
) -> NdArrayTensor<I> {
let mut reshape = tensor.array.shape().to_vec();
reshape[dim] = 1;
@ -546,7 +546,7 @@ fn arg<E: NdArrayElement>(
}
});
idx as i64
(idx as i64).elem()
});
let output = output.to_shape(Dim(reshape.as_slice())).unwrap();


@ -7,7 +7,7 @@ use core::ops::Range;
use ndarray::{IntoDimension, Zip};
// Current crate
use crate::element::{FloatNdArrayElement, QuantElement};
use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
use crate::NdArrayDevice;
use crate::{tensor::NdArrayTensor, NdArray};
@ -16,7 +16,9 @@ use burn_tensor::{backend::Backend, Shape, TensorData};
use super::NdArrayOps;
impl<E: FloatNdArrayElement, Q: QuantElement> BoolTensorOps<Self> for NdArray<E, Q> {
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
for NdArray<E, I, Q>
{
fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<bool> {
NdArrayTensor::from_data(data)
}
@ -43,11 +45,11 @@ impl<E: FloatNdArrayElement, Q: QuantElement> BoolTensorOps<Self> for NdArray<E,
NdArrayOps::slice(tensor, ranges)
}
fn bool_into_int(tensor: NdArrayTensor<bool>) -> NdArrayTensor<i64> {
fn bool_into_int(tensor: NdArrayTensor<bool>) -> NdArrayTensor<I> {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
NdArray::<E>::int_from_data(
TensorData::new(values, shape).convert::<i64>(),
NdArray::<E, I>::int_from_data(
TensorData::new(values, shape).convert::<I>(),
&NdArrayDevice::Cpu,
)
}


@ -11,7 +11,7 @@ use ndarray::{
};
use crate::{
element::{FloatNdArrayElement, QuantElement},
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
ops::padding::{apply_padding_4d, apply_padding_5d},
sharing::UnsafeSharedRef,
tensor::NdArrayTensor,
@ -98,7 +98,7 @@ fn conv3d_mad_inner<E: FloatNdArrayElement>(
}
}
pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
pub(crate) fn conv2d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E>,
weight: NdArrayTensor<E>,
bias: Option<NdArrayTensor<E>>,
@ -126,7 +126,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
in_width,
);
let x = apply_padding_4d::<E, Q>(x, options.padding, 0i32.elem()).array;
let x = apply_padding_4d::<E, I, Q>(x, options.padding, 0i32.elem()).array;
// Convert inputs from dynamic indexes to static to improve perf.
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
@ -310,7 +310,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
NdArrayTensor::new(output.into_dyn().into_shared())
}
pub(crate) fn conv3d<E: FloatNdArrayElement, Q: QuantElement>(
pub(crate) fn conv3d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E>,
weight: NdArrayTensor<E>,
bias: Option<NdArrayTensor<E>>,
@ -345,7 +345,7 @@ pub(crate) fn conv3d<E: FloatNdArrayElement, Q: QuantElement>(
in_width,
);
let x = apply_padding_5d::<E, Q>(x, options.padding, 0i32.elem()).array;
let x = apply_padding_5d::<E, I, Q>(x, options.padding, 0i32.elem()).array;
// Convert inputs from dynamic indexes to static to improve perf.
let x = x.into_dimensionality::<ndarray::Ix5>().unwrap();


@ -252,7 +252,7 @@ pub mod backward {
#[cfg(target_has_atomic = "32")]
use core::sync::atomic::Ordering;
use crate::NdArray;
use crate::{element::IntNdArrayElement, NdArray};
use atomic_float::AtomicF32;
use burn_tensor::ops::DeformConv2dBackward;
use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4};
@ -260,7 +260,11 @@ pub mod backward {
use super::*;
/// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.
pub(crate) fn deform_conv2d_backward<F: FloatNdArrayElement, Q: QuantElement>(
pub(crate) fn deform_conv2d_backward<
F: FloatNdArrayElement,
I: IntNdArrayElement,
Q: QuantElement,
>(
input: NdArrayTensor<F>,
offset: NdArrayTensor<F>,
weight: NdArrayTensor<F>,
@ -268,7 +272,7 @@ pub mod backward {
bias: Option<NdArrayTensor<F>>,
out_grad: NdArrayTensor<F>,
args: DeformConvOptions<2>,
) -> DeformConv2dBackward<NdArray<F, Q>> {
) -> DeformConv2dBackward<NdArray<F, I, Q>> {
let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims();
let [_, _, kernel_h, kernel_w] = weight.shape().dims();
let groups = args.weight_groups;


@ -11,8 +11,8 @@ use ndarray::IntoDimension;
use ndarray::Zip;
// Current crate
use crate::element::ExpElement;
use crate::element::FloatNdArrayElement;
use crate::element::IntNdArrayElement;
use crate::element::QuantElement;
use crate::{tensor::NdArrayTensor, NdArray};
use crate::{NdArrayDevice, SEED};
@ -22,71 +22,73 @@ use burn_tensor::{backend::Backend, Shape, TensorData};
use super::{NdArrayMathOps, NdArrayOps};
impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E, Q> {
fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<i64> {
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
for NdArray<E, I, Q>
{
fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<I> {
NdArrayTensor::from_data(data)
}
fn int_shape(tensor: &NdArrayTensor<i64>) -> Shape {
fn int_shape(tensor: &NdArrayTensor<I>) -> Shape {
tensor.shape()
}
async fn int_into_data(tensor: NdArrayTensor<i64>) -> TensorData {
async fn int_into_data(tensor: NdArrayTensor<I>) -> TensorData {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
TensorData::new(values, shape)
}
fn int_to_device(tensor: NdArrayTensor<i64>, _device: &NdArrayDevice) -> NdArrayTensor<i64> {
fn int_to_device(tensor: NdArrayTensor<I>, _device: &NdArrayDevice) -> NdArrayTensor<I> {
tensor
}
fn int_reshape(tensor: NdArrayTensor<i64>, shape: Shape) -> NdArrayTensor<i64> {
fn int_reshape(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
NdArrayOps::reshape(tensor, shape)
}
fn int_slice(tensor: NdArrayTensor<i64>, ranges: &[Range<usize>]) -> NdArrayTensor<i64> {
fn int_slice(tensor: NdArrayTensor<I>, ranges: &[Range<usize>]) -> NdArrayTensor<I> {
NdArrayOps::slice(tensor, ranges)
}
fn int_device(_tensor: &NdArrayTensor<i64>) -> <NdArray<E> as Backend>::Device {
fn int_device(_tensor: &NdArrayTensor<I>) -> <NdArray<E> as Backend>::Device {
NdArrayDevice::Cpu
}
fn int_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<i64> {
fn int_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
let values = vec![0; shape.num_elements()];
NdArrayTensor::from_data(TensorData::new(values, shape))
}
fn int_mask_where(
tensor: NdArrayTensor<i64>,
tensor: NdArrayTensor<I>,
mask: NdArrayTensor<bool>,
source: NdArrayTensor<i64>,
) -> NdArrayTensor<i64> {
source: NdArrayTensor<I>,
) -> NdArrayTensor<I> {
NdArrayMathOps::mask_where(tensor, mask, source)
}
fn int_mask_fill(
tensor: NdArrayTensor<i64>,
tensor: NdArrayTensor<I>,
mask: NdArrayTensor<bool>,
value: i64,
) -> NdArrayTensor<i64> {
value: I,
) -> NdArrayTensor<I> {
NdArrayMathOps::mask_fill(tensor, mask, value)
}
fn int_slice_assign(
tensor: NdArrayTensor<i64>,
tensor: NdArrayTensor<I>,
ranges: &[Range<usize>],
value: NdArrayTensor<i64>,
) -> NdArrayTensor<i64> {
value: NdArrayTensor<I>,
) -> NdArrayTensor<I> {
NdArrayOps::slice_assign(tensor, ranges, value)
}
fn int_cat(tensors: Vec<NdArrayTensor<i64>>, dim: usize) -> NdArrayTensor<i64> {
fn int_cat(tensors: Vec<NdArrayTensor<I>>, dim: usize) -> NdArrayTensor<I> {
NdArrayOps::cat(tensors, dim)
}
fn int_equal(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<bool> {
fn int_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
let output = Zip::from(&lhs.array)
.and(&rhs.array)
.map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
@ -94,196 +96,196 @@ impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E,
NdArrayTensor::new(output)
}
fn int_equal_elem(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<bool> {
fn int_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
let array = lhs.array.mapv(|a| a == rhs).into_shared();
NdArrayTensor { array }
}
fn int_greater(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<bool> {
fn int_greater(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
let tensor = Self::int_sub(lhs, rhs);
Self::int_greater_elem(tensor, 0)
Self::int_greater_elem(tensor, 0.elem())
}
fn int_greater_elem(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<bool> {
fn int_greater_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
let array = lhs.array.mapv(|a| a > rhs).into_shared();
NdArrayTensor::new(array)
}
fn int_greater_equal(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<bool> {
fn int_greater_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
let tensor = Self::int_sub(lhs, rhs);
Self::int_greater_equal_elem(tensor, 0)
Self::int_greater_equal_elem(tensor, 0.elem())
}
fn int_greater_equal_elem(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<bool> {
fn int_greater_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
let array = lhs.array.mapv(|a| a >= rhs).into_shared();
NdArrayTensor::new(array)
}
fn int_lower(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<bool> {
fn int_lower(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
let tensor = Self::int_sub(lhs, rhs);
Self::int_lower_elem(tensor, 0)
Self::int_lower_elem(tensor, 0.elem())
}
fn int_lower_elem(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<bool> {
fn int_lower_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
let array = lhs.array.mapv(|a| a < rhs).into_shared();
NdArrayTensor::new(array)
}
fn int_lower_equal(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<bool> {
fn int_lower_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
let tensor = Self::int_sub(lhs, rhs);
Self::int_lower_equal_elem(tensor, 0)
Self::int_lower_equal_elem(tensor, 0.elem())
}
fn int_lower_equal_elem(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<bool> {
fn int_lower_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
let array = lhs.array.mapv(|a| a <= rhs).into_shared();
NdArrayTensor::new(array)
}
fn int_add(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
fn int_add(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
NdArrayMathOps::add(lhs, rhs)
}
fn int_add_scalar(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<i64> {
fn int_add_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
NdArrayMathOps::add_scalar(lhs, rhs)
}
fn int_sub(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
fn int_sub(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
NdArrayMathOps::sub(lhs, rhs)
}
fn int_sub_scalar(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<i64> {
fn int_sub_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
NdArrayMathOps::sub_scalar(lhs, rhs)
}
fn int_mul(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
fn int_mul(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
NdArrayMathOps::mul(lhs, rhs)
}
fn int_mul_scalar(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<i64> {
fn int_mul_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
NdArrayMathOps::mul_scalar(lhs, rhs)
}
fn int_div(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
fn int_div(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
NdArrayMathOps::div(lhs, rhs)
}
fn int_div_scalar(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<i64> {
fn int_div_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
NdArrayMathOps::div_scalar(lhs, rhs)
}
fn int_remainder_scalar(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<i64> {
fn int_remainder_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
NdArrayMathOps::remainder_scalar(lhs, rhs)
}
fn int_neg(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
Self::int_mul_scalar(tensor, -1)
fn int_neg(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
Self::int_mul_scalar(tensor, (-1).elem())
}
fn int_zeros(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<i64> {
fn int_zeros(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
Self::int_from_data(TensorData::zeros::<i64, _>(shape), device)
}
fn int_ones(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<i64> {
fn int_ones(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
Self::int_from_data(TensorData::ones::<i64, _>(shape), device)
}
fn int_full(
shape: Shape,
fill_value: i64,
fill_value: I,
device: &<NdArray<E> as Backend>::Device,
) -> NdArrayTensor<i64> {
) -> NdArrayTensor<I> {
Self::int_from_data(TensorData::full(shape, fill_value), device)
}
fn int_sum(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
fn int_sum(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
NdArrayMathOps::sum(tensor)
}
fn int_sum_dim(tensor: NdArrayTensor<i64>, dim: usize) -> NdArrayTensor<i64> {
fn int_sum_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
NdArrayMathOps::sum_dim(tensor, dim)
}
fn int_prod(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
fn int_prod(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
NdArrayMathOps::prod(tensor)
}
fn int_prod_dim(tensor: NdArrayTensor<i64>, dim: usize) -> NdArrayTensor<i64> {
fn int_prod_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
NdArrayMathOps::prod_dim(tensor, dim)
}
fn int_mean(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
fn int_mean(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
NdArrayMathOps::mean(tensor)
}
fn int_mean_dim(tensor: NdArrayTensor<i64>, dim: usize) -> NdArrayTensor<i64> {
fn int_mean_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
NdArrayMathOps::mean_dim(tensor, dim)
}
fn int_gather(
dim: usize,
tensor: NdArrayTensor<i64>,
indices: NdArrayTensor<i64>,
) -> NdArrayTensor<i64> {
tensor: NdArrayTensor<I>,
indices: NdArrayTensor<I>,
) -> NdArrayTensor<I> {
NdArrayMathOps::gather(dim, tensor, indices)
}
fn int_scatter(
dim: usize,
tensor: NdArrayTensor<i64>,
indices: NdArrayTensor<i64>,
value: NdArrayTensor<i64>,
) -> NdArrayTensor<i64> {
tensor: NdArrayTensor<I>,
indices: NdArrayTensor<I>,
value: NdArrayTensor<I>,
) -> NdArrayTensor<I> {
NdArrayMathOps::scatter(dim, tensor, indices, value)
}
fn int_select(
tensor: NdArrayTensor<i64>,
tensor: NdArrayTensor<I>,
dim: usize,
indices: NdArrayTensor<i64>,
) -> NdArrayTensor<i64> {
indices: NdArrayTensor<I>,
) -> NdArrayTensor<I> {
NdArrayMathOps::select(tensor, dim, indices)
}
fn int_select_assign(
tensor: NdArrayTensor<i64>,
tensor: NdArrayTensor<I>,
dim: usize,
indices: NdArrayTensor<i64>,
value: NdArrayTensor<i64>,
) -> NdArrayTensor<i64> {
indices: NdArrayTensor<I>,
value: NdArrayTensor<I>,
) -> NdArrayTensor<I> {
NdArrayMathOps::select_assign(tensor, dim, indices, value)
}
fn int_argmax(tensor: NdArrayTensor<i64>, dim: usize) -> NdArrayTensor<i64> {
fn int_argmax(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
NdArrayMathOps::argmax(tensor, dim)
}
fn int_argmin(tensor: NdArrayTensor<i64>, dim: usize) -> NdArrayTensor<i64> {
fn int_argmin(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
NdArrayMathOps::argmin(tensor, dim)
}
fn int_clamp_min(tensor: NdArrayTensor<i64>, min: i64) -> NdArrayTensor<i64> {
fn int_clamp_min(tensor: NdArrayTensor<I>, min: I) -> NdArrayTensor<I> {
NdArrayMathOps::clamp_min(tensor, min)
}
fn int_clamp_max(tensor: NdArrayTensor<i64>, max: i64) -> NdArrayTensor<i64> {
fn int_clamp_max(tensor: NdArrayTensor<I>, max: I) -> NdArrayTensor<I> {
NdArrayMathOps::clamp_max(tensor, max)
}
fn int_clamp(tensor: NdArrayTensor<i64>, min: i64, max: i64) -> NdArrayTensor<i64> {
fn int_clamp(tensor: NdArrayTensor<I>, min: I, max: I) -> NdArrayTensor<I> {
NdArrayMathOps::clamp(tensor, min, max)
}
fn int_abs(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
fn int_abs(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
let array = tensor.array.mapv_into(|a| a.int_abs_elem()).into_shared();
NdArrayTensor::new(array)
}
fn int_into_float(tensor: NdArrayTensor<i64>) -> <NdArray<E> as Backend>::FloatTensorPrimitive {
fn int_into_float(tensor: NdArrayTensor<I>) -> <NdArray<E> as Backend>::FloatTensorPrimitive {
let array = tensor.array.mapv(|a| a.elem()).into_shared();
NdArrayTensor { array }
}
fn int_swap_dims(tensor: NdArrayTensor<i64>, dim1: usize, dim2: usize) -> NdArrayTensor<i64> {
fn int_swap_dims(tensor: NdArrayTensor<I>, dim1: usize, dim2: usize) -> NdArrayTensor<I> {
NdArrayOps::swap_dims(tensor, dim1, dim2)
}
@ -291,7 +293,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E,
shape: Shape,
distribution: Distribution,
device: &NdArrayDevice,
) -> NdArrayTensor<i64> {
) -> NdArrayTensor<I> {
let mut seed = SEED.lock().unwrap();
let mut rng = if let Some(rng_seeded) = seed.as_ref() {
rng_seeded.clone()
@ -313,32 +315,36 @@ impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E,
tensor
}
fn int_powi(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &i64, b: &i64| a.pow(*b as u32))
fn int_powi(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
(a.elem::<i64>().pow(b.elem::<u32>())).elem()
})
}
fn int_powf(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<E>) -> NdArrayTensor<i64> {
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &i64, b: &E| a.pow(b.elem::<u32>()))
fn int_powf(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<E>) -> NdArrayTensor<I> {
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &E| {
(a.elem::<i64>().pow(b.elem::<u32>())).elem()
})
}
fn int_powf_scalar(lhs: NdArrayTensor<i64>, rhs: f32) -> NdArrayTensor<i64> {
NdArrayMathOps::elementwise_op_scalar(lhs, |a: i64| a.pow(rhs as u32))
fn int_powf_scalar(lhs: NdArrayTensor<I>, rhs: f32) -> NdArrayTensor<I> {
NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| (a.elem::<i64>().pow(rhs as u32)).elem())
}
fn int_permute(tensor: NdArrayTensor<i64>, axes: &[usize]) -> NdArrayTensor<i64> {
fn int_permute(tensor: NdArrayTensor<I>, axes: &[usize]) -> NdArrayTensor<I> {
let array = tensor.array.permuted_axes(axes.into_dimension());
NdArrayTensor { array }
}
fn int_flip(tensor: NdArrayTensor<i64>, axes: &[usize]) -> NdArrayTensor<i64> {
fn int_flip(tensor: NdArrayTensor<I>, axes: &[usize]) -> NdArrayTensor<I> {
NdArrayOps::flip(tensor, axes)
}
fn int_sign(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
fn int_sign(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
NdArrayMathOps::sign_op(tensor)
}
fn int_expand(tensor: NdArrayTensor<i64>, shape: Shape) -> NdArrayTensor<i64> {
fn int_expand(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
NdArrayOps::expand(tensor, shape)
}
}


@ -1,5 +1,5 @@
use crate::{
element::{FloatNdArrayElement, QuantElement},
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
ops::padding::apply_padding_4d,
sharing::UnsafeSharedRef,
tensor::NdArrayTensor,
@ -9,7 +9,7 @@ use burn_common::{iter_range_par, run_par};
use burn_tensor::ElementConversion;
use ndarray::Array4;
pub(crate) fn max_pool2d<E: FloatNdArrayElement, Q: QuantElement>(
pub(crate) fn max_pool2d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E>,
kernel_size: [usize; 2],
stride: [usize; 2],
@ -30,7 +30,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement, Q: QuantElement>(
/ stride_width)
+ 1;
let x = apply_padding_4d::<E, Q>(x, padding, inf).array;
let x = apply_padding_4d::<E, I, Q>(x, padding, inf).array;
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
@ -69,13 +69,17 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement, Q: QuantElement>(
NdArrayTensor::new(output.into_dyn().into_shared())
}
pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, Q: QuantElement>(
pub(crate) fn max_pool2d_with_indices<
E: FloatNdArrayElement,
I: IntNdArrayElement,
Q: QuantElement,
>(
x: NdArrayTensor<E>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> (NdArrayTensor<E>, NdArrayTensor<i64>) {
) -> (NdArrayTensor<E>, NdArrayTensor<I>) {
let [kernel_height, kernel_width] = kernel_size;
let [padding_height, padding_width] = padding;
let [stride_height, stride_width] = stride;
@ -90,10 +94,10 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, Q: QuantElement>(
/ stride_width)
+ 1;
let x = apply_padding_4d::<E, Q>(x, padding, inf).array;
let x = apply_padding_4d::<E, I, Q>(x, padding, inf).array;
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
let mut indices = Array4::<i64>::zeros((batch_size, channels, out_height, out_width));
let mut indices = Array4::<I>::zeros((batch_size, channels, out_height, out_width));
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
@ -130,7 +134,7 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, Q: QuantElement>(
}
output[[b, c, oh, ow]] = max_val;
indices[[b, c, oh, ow]] = index;
indices[[b, c, oh, ow]] = index.elem();
}
}
})
@ -142,14 +146,14 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, Q: QuantElement>(
(output, indices)
}
pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement, I: IntNdArrayElement>(
x: NdArrayTensor<E>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
_dilation: [usize; 2],
output_grad: NdArrayTensor<E>,
indices: NdArrayTensor<i64>,
indices: NdArrayTensor<I>,
) -> NdArrayTensor<E> {
let [_batch_size, _channels, height, width] = output_grad.shape().dims();
let [batch_size, channels, height_x, width_x] = x.shape().dims();
@ -170,7 +174,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
for h in 0..height {
for w in 0..width {
let index = indices[[b, c, h, w]];
let index = indices[[b, c, h, w]].elem::<i64>();
let grad = output_grad[[b, c, h, w]];
let index_h = index as usize / width_x;


@ -7,17 +7,22 @@ use super::{
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
use crate::{element::QuantElement, ops::interpolate::nearest_interpolate_backward};
use crate::{
element::{IntNdArrayElement, QuantElement},
ops::interpolate::nearest_interpolate_backward,
};
use burn_tensor::ops::*;
impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q> {
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
for NdArray<E, I, Q>
{
fn conv2d(
x: NdArrayTensor<E>,
weight: NdArrayTensor<E>,
bias: Option<NdArrayTensor<E>>,
options: ConvOptions<2>,
) -> NdArrayTensor<E> {
conv2d::<E, Q>(x, weight, bias, options)
conv2d::<E, I, Q>(x, weight, bias, options)
}
fn deform_conv2d(
@ -80,7 +85,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q>
padding: [usize; 2],
dilation: [usize; 2],
) -> NdArrayTensor<E> {
max_pool2d::<E, Q>(x, kernel_size, stride, padding, dilation)
max_pool2d::<E, I, Q>(x, kernel_size, stride, padding, dilation)
}
fn max_pool2d_with_indices(
@ -89,9 +94,9 @@ impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q>
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<NdArray<E, Q>> {
) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
let (output, indices) =
max_pool2d_with_indices::<E, Q>(x, kernel_size, stride, padding, dilation);
max_pool2d_with_indices::<E, I, Q>(x, kernel_size, stride, padding, dilation);
MaxPool2dWithIndices::new(output, indices)
}
@ -103,8 +108,8 @@ impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q>
padding: [usize; 2],
dilation: [usize; 2],
output_grad: NdArrayTensor<E>,
indices: NdArrayTensor<i64>,
) -> MaxPool2dBackward<NdArray<E, Q>> {
indices: NdArrayTensor<I>,
) -> MaxPool2dBackward<NdArray<E, I, Q>> {
MaxPool2dBackward::new(max_pool2d_backward(
x,
kernel_size,
@ -162,7 +167,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q>
bias: Option<NdArrayTensor<E>>,
options: ConvOptions<3>,
) -> NdArrayTensor<E> {
conv3d::<E, Q>(x, weight, bias, options)
conv3d::<E, I, Q>(x, weight, bias, options)
}
fn conv_transpose3d(


@ -1,12 +1,12 @@
use crate::{
element::{FloatNdArrayElement, QuantElement},
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
tensor::NdArrayTensor,
NdArray,
};
use burn_tensor::ops::FloatTensorOps;
use ndarray::{Array4, Array5};
pub(crate) fn apply_padding_4d<E: FloatNdArrayElement, Q: QuantElement>(
pub(crate) fn apply_padding_4d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E>,
padding: [usize; 2],
elem: E,
@ -22,7 +22,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement, Q: QuantElement>(
);
let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());
x_new = NdArray::<E, Q>::float_slice_assign(
x_new = NdArray::<E, I, Q>::float_slice_assign(
x_new,
&[
0..batch_size,
@ -36,7 +36,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement, Q: QuantElement>(
x_new
}
pub(crate) fn apply_padding_5d<E: FloatNdArrayElement, Q: QuantElement>(
pub(crate) fn apply_padding_5d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
x: NdArrayTensor<E>,
padding: [usize; 3],
elem: E,
@ -59,7 +59,7 @@ pub(crate) fn apply_padding_5d<E: FloatNdArrayElement, Q: QuantElement>(
);
let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());
x_new = NdArray::<E, Q>::float_slice_assign(
x_new = NdArray::<E, I, Q>::float_slice_assign(
x_new,
&[
0..batch_size,


@ -10,7 +10,7 @@ use burn_tensor::{
};
use crate::{
element::{NdArrayElement, QuantElement},
element::{IntNdArrayElement, NdArrayElement, QuantElement},
FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor,
};
@ -22,7 +22,9 @@ fn into_data<E: NdArrayElement>(tensor: NdArrayTensor<E>) -> TensorData {
TensorData::new(values, shape)
}
impl<E: FloatNdArrayElement, Q: QuantElement> QTensorOps<Self> for NdArray<E, Q> {
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
for NdArray<E, I, Q>
{
fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
match data.dtype {
DType::QFloat(strategy) => match strategy {


@ -5,7 +5,7 @@ use ndarray::Zip;
// Current crate
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
use crate::element::{FloatNdArrayElement, QuantElement};
use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
use crate::{tensor::NdArrayTensor, NdArray};
use crate::{NdArrayDevice, SEED};
@ -20,7 +20,9 @@ use num_traits::Float;
use libm::erf;
impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E, Q> {
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
for NdArray<E, I, Q>
{
fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<E> {
NdArrayTensor::from_data(data)
}
@ -125,7 +127,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
fn float_gather(
dim: usize,
tensor: NdArrayTensor<E>,
indices: NdArrayTensor<i64>,
indices: NdArrayTensor<I>,
) -> NdArrayTensor<E> {
NdArrayMathOps::gather(dim, tensor, indices)
}
@ -133,7 +135,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
fn float_scatter(
dim: usize,
tensor: NdArrayTensor<E>,
indices: NdArrayTensor<i64>,
indices: NdArrayTensor<I>,
value: NdArrayTensor<E>,
) -> NdArrayTensor<E> {
NdArrayMathOps::scatter(dim, tensor, indices, value)
@ -142,7 +144,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
fn float_select(
tensor: NdArrayTensor<E>,
dim: usize,
indices: NdArrayTensor<i64>,
indices: NdArrayTensor<I>,
) -> NdArrayTensor<E> {
NdArrayMathOps::select(tensor, dim, indices)
}
@ -150,7 +152,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
fn float_select_assign(
tensor: NdArrayTensor<E>,
dim: usize,
indices: NdArrayTensor<i64>,
indices: NdArrayTensor<I>,
value: NdArrayTensor<E>,
) -> NdArrayTensor<E> {
NdArrayMathOps::select_assign(tensor, dim, indices, value)
@ -266,11 +268,11 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
NdArrayMathOps::sum_dim(tensor, dim)
}
fn float_argmax(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<i64> {
fn float_argmax(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
NdArrayMathOps::argmax(tensor, dim)
}
fn float_argmin(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<i64> {
fn float_argmin(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
NdArrayMathOps::argmin(tensor, dim)
}
@ -374,7 +376,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
NdArrayMathOps::clamp(tensor, min, max)
}
fn float_into_int(tensor: NdArrayTensor<E>) -> <NdArray<E> as Backend>::IntTensorPrimitive {
fn float_into_int(tensor: NdArrayTensor<E>) -> NdArrayTensor<I> {
let array = tensor.array.mapv(|a| a.elem()).into_shared();
NdArrayTensor { array }
}


@ -212,7 +212,7 @@ mod tests {
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
let qparams = QuantizationParametersPrimitive {
scale: NdArrayTensor::from_data(TensorData::from([0.009_019_608])),
offset: Some(NdArrayTensor::from_data(TensorData::from([72]))),
offset: Some(NdArrayTensor::<i64>::from_data(TensorData::from([72]))),
};
let qtensor: NdArrayQTensor<i8> = NdArray::quantize(tensor, &scheme, qparams);


@ -0,0 +1,39 @@
[package]
authors = ["guillaumelagrange <lagrange.guillaume.1@gmail.com>", "nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "Multi-backend router decorator for the Burn framework"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "data"]
license.workspace = true
name = "burn-router"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router"
documentation = "https://docs.rs/burn-router"
version.workspace = true
[features]
default = ["std"]
std = []
doc = ["default"]
[dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false, features = ["repr"]}
hashbrown = { workspace = true }
spin = { workspace = true }
[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", default-features = false, features = [
"export_tests",
] }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false, features = [
"export_tests",
] }
burn-ndarray = { path = "../burn-ndarray", version = "0.15.0" }
burn-wgpu = { path = "../burn-wgpu", version = "0.15.0" }
[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]


@ -0,0 +1,3 @@
# Burn Router
A multi-backend extension that forwards the tensor operations to the appropriate backend.
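A hedged usage sketch, assuming `BackendRouter` and `RunnerChannel` are re-exported at the crate root and leaving the concrete channel type generic:

```rust
use burn_router::{BackendRouter, RunnerChannel};
use burn_tensor::{Device, Tensor};

// With any runner channel R, the router acts as a regular Burn backend:
// operations are recorded and forwarded to the device's client.
fn double<R: RunnerChannel>(device: &Device<BackendRouter<R>>) -> Tensor<BackendRouter<R>, 1> {
    let x = Tensor::from_floats([1.0, 2.0, 3.0], device);
    x.clone() + x
}
```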


@ -0,0 +1,125 @@
use alloc::{format, string::String};
use core::marker::PhantomData;
use burn_tensor::{
backend::{Backend, BackendBridge},
ops::FloatTensor,
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
repr::{BaseOperationDescription, OperationDescription, UnaryOperationDescription},
Device,
};
use super::{get_client, set_seed, RouterTensor, RunnerChannel, RunnerClient};
/// A backend that forwards the tensor operations to the appropriate backend (given multiple backends).
pub struct BackendRouter<R: RunnerChannel> {
r: PhantomData<R>,
}
impl<R: RunnerChannel> core::fmt::Debug for BackendRouter<R> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_fmt(format_args!("router"))
}
}
impl<R: RunnerChannel> Clone for BackendRouter<R> {
fn clone(&self) -> Self {
Self { r: PhantomData }
}
}
impl<R: RunnerChannel> Default for BackendRouter<R> {
fn default() -> Self {
Self { r: PhantomData }
}
}
// TODO: quantization tensor primitive (w/ qparams)
impl<R: RunnerClient> QTensorPrimitive for RouterTensor<R> {
fn scheme(&self) -> &QuantizationScheme {
todo!()
}
fn strategy(&self) -> QuantizationStrategy {
todo!()
}
}
impl<R: RunnerChannel> Backend for BackendRouter<R> {
type Device = R::Device;
type FullPrecisionBridge = PrecisionBridge;
type FloatTensorPrimitive = RouterTensor<R::Client>;
type FloatElem = R::FloatElem;
type IntTensorPrimitive = RouterTensor<R::Client>;
type IntElem = R::IntElem;
type BoolTensorPrimitive = RouterTensor<R::Client>;
type QuantizedTensorPrimitive = RouterTensor<R::Client>;
type QuantizedEncoding = u32;
fn name() -> String {
format!("router<{}>", R::name())
}
fn seed(seed: u64) {
set_seed(seed)
}
fn sync(device: &Self::Device) {
let client = get_client::<R>(device);
client.sync();
}
}
/// Handle precision conversion.
#[derive(Debug)]
pub struct PrecisionBridge {}
impl<R: RunnerChannel> BackendBridge<BackendRouter<R>> for PrecisionBridge {
type Target = BackendRouter<R>;
fn into_target(
tensor: FloatTensor<BackendRouter<R>>,
_device: Option<Device<Self::Target>>,
) -> FloatTensor<Self::Target> {
let client = tensor.client.clone();
let out = client.register_float_tensor(tensor.shape.clone(), true);
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::BaseFloat(
BaseOperationDescription::Cast(desc),
));
out
}
fn from_target(
tensor: FloatTensor<Self::Target>,
_device: Option<Device<BackendRouter<R>>>,
) -> FloatTensor<BackendRouter<R>> {
let client = tensor.client.clone();
let out = client.register_float_tensor(tensor.shape.clone(), false);
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::BaseFloat(
BaseOperationDescription::Cast(desc),
));
out
}
}

View File

@ -0,0 +1,32 @@
use burn_tensor::{backend::DeviceOps, Shape};
/// Allows tensors to be transferred between multiple backends.
pub trait MultiBackendBridge: Send + Sync + 'static {
/// The type that can be used to point to a tensor of any kind.
type TensorHandle;
/// Device type used by the backends.
type Device: DeviceOps;
/// Change the backend of the given float tensor.
fn change_backend_float(
tensor: Self::TensorHandle,
shape: Shape,
target_device: &Self::Device,
) -> Self::TensorHandle;
/// Change the backend of the given int tensor.
fn change_backend_int(
tensor: Self::TensorHandle,
shape: Shape,
target_device: &Self::Device,
) -> Self::TensorHandle;
/// Change the backend of the given bool tensor.
fn change_backend_bool(
tensor: Self::TensorHandle,
shape: Shape,
target_device: &Self::Device,
) -> Self::TensorHandle;
// TODO: change_backend_quantized
}
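// Illustrative sketch (not part of this diff): the degenerate single-backend
// bridge, where "changing backend" reduces to a plain device transfer. It
// shows the minimal contract a `MultiBackendBridge` implementation fulfills.
use burn_tensor::repr::{ReprBackend, TensorHandle};

#[allow(dead_code)]
struct IdentityBridge<B>(core::marker::PhantomData<B>);

impl<B: ReprBackend> MultiBackendBridge for IdentityBridge<B> {
    type TensorHandle = B::Handle;
    type Device = B::Device;

    fn change_backend_float(
        tensor: Self::TensorHandle,
        shape: Shape,
        target_device: &Self::Device,
    ) -> Self::TensorHandle {
        let tensor = B::float_tensor(TensorHandle { handle: tensor, shape });
        B::float_tensor_handle(B::float_to_device(tensor, target_device))
    }

    fn change_backend_int(
        tensor: Self::TensorHandle,
        shape: Shape,
        target_device: &Self::Device,
    ) -> Self::TensorHandle {
        let tensor = B::int_tensor(TensorHandle { handle: tensor, shape });
        B::int_tensor_handle(B::int_to_device(tensor, target_device))
    }

    fn change_backend_bool(
        tensor: Self::TensorHandle,
        shape: Shape,
        target_device: &Self::Device,
    ) -> Self::TensorHandle {
        let tensor = B::bool_tensor(TensorHandle { handle: tensor, shape });
        B::bool_tensor_handle(B::bool_to_device(tensor, target_device))
    }
}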

View File

@ -0,0 +1,171 @@
use core::marker::PhantomData;
use burn_tensor::{
repr::{ReprBackend, TensorHandle},
try_read_sync, Shape,
};
use super::base::MultiBackendBridge;
use crate::{MultiDevice2, TensorHandle2};
/// Simply transfers tensors between backends via the underlying [tensor data](burn_tensor::TensorData).
pub struct ByteBridge<Backends> {
backends: PhantomData<Backends>,
}
impl<B1: ReprBackend, B2: ReprBackend> MultiBackendBridge for ByteBridge<(B1, B2)> {
type TensorHandle = TensorHandle2<B1, B2>;
type Device = MultiDevice2<B1, B2>;
fn change_backend_float(
tensor: Self::TensorHandle,
shape: Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
let msg = "Failed to read tensor data synchronously.
This can happen on platforms that don't support blocking futures like WASM.";
match tensor {
TensorHandle2::Handle1(handle) => match target_device {
MultiDevice2::Device1(device) => {
// Same backend
let tensor = B1::float_tensor(TensorHandle { handle, shape });
let tensor = B1::float_to_device(tensor, device);
let handle = B1::float_tensor_handle(tensor);
TensorHandle2::Handle1(handle)
}
MultiDevice2::Device2(device) => {
let tensor = B1::float_tensor(TensorHandle { handle, shape });
let data = try_read_sync(B1::float_into_data(tensor)).expect(msg);
let tensor = B2::float_from_data(data, device);
let handle = B2::float_tensor_handle(tensor);
TensorHandle2::Handle2(handle)
}
},
TensorHandle2::Handle2(handle) => match target_device {
MultiDevice2::Device1(device) => {
let tensor = B2::float_tensor(TensorHandle { handle, shape });
let data = try_read_sync(B2::float_into_data(tensor)).expect(msg);
let tensor = B1::float_from_data(data, device);
let handle = B1::float_tensor_handle(tensor);
TensorHandle2::Handle1(handle)
}
MultiDevice2::Device2(device) => {
// Same backend
let tensor = B2::float_tensor(TensorHandle { handle, shape });
let tensor = B2::float_to_device(tensor, device);
let handle = B2::float_tensor_handle(tensor);
TensorHandle2::Handle2(handle)
}
},
}
}
fn change_backend_int(
tensor: Self::TensorHandle,
shape: Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
let msg = "Failed to read tensor data synchronously.
This can happen on platforms that don't support blocking futures like WASM.";
match tensor {
TensorHandle2::Handle1(handle) => match target_device {
MultiDevice2::Device1(device) => {
// Same backend
let tensor = B1::int_tensor(TensorHandle { handle, shape });
let tensor = B1::int_to_device(tensor, device);
let handle = B1::int_tensor_handle(tensor);
TensorHandle2::Handle1(handle)
}
MultiDevice2::Device2(device) => {
let tensor = B1::int_tensor(TensorHandle { handle, shape });
let data = try_read_sync(B1::int_into_data(tensor)).expect(msg);
let tensor = B2::int_from_data(data, device);
let handle = B2::int_tensor_handle(tensor);
TensorHandle2::Handle2(handle)
}
},
TensorHandle2::Handle2(handle) => match target_device {
MultiDevice2::Device1(device) => {
let tensor = B2::int_tensor(TensorHandle { handle, shape });
let data = try_read_sync(B2::int_into_data(tensor)).expect(msg);
let tensor = B1::int_from_data(data, device);
let handle = B1::int_tensor_handle(tensor);
TensorHandle2::Handle1(handle)
}
MultiDevice2::Device2(device) => {
// Same backend
let tensor = B2::int_tensor(TensorHandle { handle, shape });
let tensor = B2::int_to_device(tensor, device);
let handle = B2::int_tensor_handle(tensor);
TensorHandle2::Handle2(handle)
}
},
}
}
fn change_backend_bool(
tensor: Self::TensorHandle,
shape: Shape,
target_device: &Self::Device,
) -> Self::TensorHandle {
let msg = "Failed to read tensor data synchronously.
This can happen on platforms that don't support blocking futures like WASM.";
match tensor {
TensorHandle2::Handle1(handle) => match target_device {
MultiDevice2::Device1(device) => {
// Same backend
let tensor = B1::bool_tensor(TensorHandle { handle, shape });
let tensor = B1::bool_to_device(tensor, device);
let handle = B1::bool_tensor_handle(tensor);
TensorHandle2::Handle1(handle)
}
MultiDevice2::Device2(device) => {
let tensor = B1::bool_tensor(TensorHandle { handle, shape });
let data = try_read_sync(B1::bool_into_data(tensor)).expect(msg);
let tensor = B2::bool_from_data(data, device);
let handle = B2::bool_tensor_handle(tensor);
TensorHandle2::Handle2(handle)
}
},
TensorHandle2::Handle2(handle) => match target_device {
MultiDevice2::Device1(device) => {
let tensor = B2::bool_tensor(TensorHandle { handle, shape });
let data = try_read_sync(B2::bool_into_data(tensor)).expect(msg);
let tensor = B1::bool_from_data(data, device);
let handle = B1::bool_tensor_handle(tensor);
TensorHandle2::Handle1(handle)
}
MultiDevice2::Device2(device) => {
// Same backend
let tensor = B2::bool_tensor(TensorHandle { handle, shape });
let tensor = B2::bool_to_device(tensor, device);
let handle = B2::bool_tensor_handle(tensor);
TensorHandle2::Handle2(handle)
}
},
}
}
}
#[cfg(not(target_os = "windows"))]
#[cfg(test)]
mod tests {
use burn_tensor::{backend::Backend, Tensor};
use super::*;
use crate::tests::{TestBackend, TestBackend1, TestBackend2};
#[test]
fn should_support_dual_byte_bridge() {
let device1 = MultiDevice2::Device1(<TestBackend1 as Backend>::Device::default());
let device2 = MultiDevice2::Device2(<TestBackend2 as Backend>::Device::default());
let tensor1 = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device1);
let tensor2 = Tensor::<TestBackend, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device2);
let tensor1_2 = tensor1.clone().to_device(&device2);
tensor1.into_data().assert_eq(&tensor1_2.into_data(), true);
let tensor2_1 = tensor2.clone().to_device(&device1);
tensor2.into_data().assert_eq(&tensor2_1.into_data(), true);
}
}

View File

@ -0,0 +1,5 @@
mod base;
mod byte;
pub use base::*;
pub use byte::*;

View File

@ -0,0 +1,68 @@
use alloc::{string::String, vec::Vec};
use burn_tensor::{backend::DeviceOps, repr::TensorDescription, DType, Element};
use crate::{get_client, MultiBackendBridge, RouterTensor, RunnerClient};
/// Type alias for `<Br as MultiBackendBridge>::TensorHandle`.
pub type TensorHandle<Br> = <Br as MultiBackendBridge>::TensorHandle;
/// Defines the connection channel and operations for a setup with multiple backend runner clients.
pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized {
/// Device type.
type Device: DeviceOps;
/// A bridge that can transfer tensors between multiple backends.
type Bridge: MultiBackendBridge<Device = Self::Device>;
/// Client type.
type Client: RunnerClient<Device = Self::Device>;
/// Float element type.
type FloatElem: Element;
/// Int element type.
type IntElem: Element;
/// Name of the channel.
fn name() -> String;
/// Initialize a new client for the given device.
fn init_client(device: &Self::Device) -> Self::Client;
/// Get the tensor handle corresponding to the [tensor description](TensorDescription).
fn get_tensor_handle(
tensor: &TensorDescription,
client: &Self::Client,
) -> TensorHandle<Self::Bridge>;
// TODO: get quantized tensor handle from QuantizedTensorDescription
/// Create a tensor with the given handle and shape.
fn register_tensor(
client: &Self::Client,
handle: TensorHandle<Self::Bridge>,
shape: Vec<usize>,
dtype: DType,
) -> RouterTensor<Self::Client>;
/// Change the tensor to a different client backend.
fn change_client_backend(
tensor: RouterTensor<Self::Client>,
device: &Self::Device, // target device
) -> RouterTensor<Self::Client> {
// Get tensor handle from current client
let original_client = tensor.client.clone();
let desc = tensor.into_description();
let mut handle = Self::get_tensor_handle(&desc, &original_client);
if desc.dtype.is_float() {
handle = Self::Bridge::change_backend_float(handle, desc.shape.clone().into(), device);
} else if desc.dtype.is_int() {
handle = Self::Bridge::change_backend_int(handle, desc.shape.clone().into(), device);
} else if desc.dtype.is_bool() {
handle = Self::Bridge::change_backend_bool(handle, desc.shape.clone().into(), device);
} else {
unimplemented!()
}
// Register tensor handle on target client
let target_client = get_client::<Self>(device);
Self::register_tensor(&target_client, handle, desc.shape, desc.dtype)
}
}
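// Note: backend ops (e.g., `to_device`) call `change_client_backend` whenever
// the source and target devices differ. The default implementation above:
// (1) fetches the concrete handle from the current client, (2) lets the
// bridge convert it for the target device, and (3) registers the converted
// handle with the target client.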

View File

@ -0,0 +1,271 @@
use alloc::{format, string::String, sync::Arc, vec::Vec};
use core::marker::PhantomData;
use burn_tensor::{
backend::{Backend, BackendBridge, DeviceId, DeviceOps},
repr::{OperationDescription, ReprBackend, TensorDescription, TensorId},
DType, TensorData,
};
use super::{RunnerChannel, TensorHandle};
use crate::{MultiBackendBridge, RouterTensor, Runner, RunnerClient};
/// A local channel with direct connection to the backend runner clients.
pub struct DirectChannel<Backends, Bridge> {
backends: PhantomData<Backends>,
bridge: PhantomData<Bridge>,
}
impl<Backends, Bridge> Clone for DirectChannel<Backends, Bridge> {
fn clone(&self) -> Self {
Self {
backends: self.backends,
bridge: self.bridge,
}
}
}
impl<B1, B2, Br> RunnerChannel for DirectChannel<(B1, B2), Br>
where
B1: ReprBackend,
B2: ReprBackend<FloatElem = B1::FloatElem, IntElem = B1::IntElem>,
Br: MultiBackendBridge<TensorHandle = TensorHandle2<B1, B2>, Device = MultiDevice2<B1, B2>>,
// Restrict the full-precision backend handle to match the base backend's handle
<<B1 as Backend>::FullPrecisionBridge as BackendBridge<B1>>::Target:
ReprBackend<Handle = B1::Handle>,
<<B2 as Backend>::FullPrecisionBridge as BackendBridge<B2>>::Target:
ReprBackend<Handle = B2::Handle>,
{
type Device = Br::Device;
type Bridge = Br;
type FloatElem = B1::FloatElem;
type IntElem = B1::IntElem;
type Client = MultiRunnerClient2<B1, B2>;
fn init_client(device: &Self::Device) -> Self::Client {
match device {
MultiDevice2::Device1(device) => {
MultiRunnerClient2::RunnerClient1(Runner::new(device.clone()))
}
MultiDevice2::Device2(device) => {
MultiRunnerClient2::RunnerClient2(Runner::new(device.clone()))
}
}
}
fn get_tensor_handle(
tensor: &TensorDescription,
client: &Self::Client,
) -> TensorHandle<Self::Bridge> {
match client {
MultiRunnerClient2::RunnerClient1(runner) => {
TensorHandle2::Handle1(runner.get_tensor_handle(tensor))
}
MultiRunnerClient2::RunnerClient2(runner) => {
TensorHandle2::Handle2(runner.get_tensor_handle(tensor))
}
}
}
fn register_tensor(
client: &Self::Client,
handle: TensorHandle<Self::Bridge>,
shape: Vec<usize>,
dtype: DType,
) -> RouterTensor<Self::Client> {
match client {
MultiRunnerClient2::RunnerClient1(runner) => match handle {
TensorHandle2::Handle1(handle) => {
runner.register_tensor(handle, shape, dtype, client.clone())
}
TensorHandle2::Handle2(_) => {
unreachable!("Can't register tensor handle for another backend.")
}
},
MultiRunnerClient2::RunnerClient2(runner) => match handle {
TensorHandle2::Handle1(_) => {
unreachable!("Can't register tensor handle for another backend.")
}
TensorHandle2::Handle2(handle) => {
runner.register_tensor(handle, shape, dtype, client.clone())
}
},
}
}
fn name() -> String {
format!("direct<({}, {})>", B1::name(), B2::name())
}
}
// TODO: generate this for different number of backends (up to 4?)
/// Handle type to interact with two backends.
pub enum TensorHandle2<B1: ReprBackend, B2: ReprBackend> {
/// Handle for the first backend.
Handle1(B1::Handle),
/// Handle for the second backend.
Handle2(B2::Handle),
}
/// Device type to interact with two backends.
#[derive(Clone, Debug)]
pub enum MultiDevice2<B1: Backend, B2: Backend> {
/// Device for the first backend.
Device1(B1::Device),
/// Device for the second backend.
Device2(B2::Device),
}
impl<B1: Backend, B2: Backend> PartialEq for MultiDevice2<B1, B2> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Device1(lhs), Self::Device1(rhs)) => lhs == rhs,
(Self::Device2(lhs), Self::Device2(rhs)) => lhs == rhs,
_ => false,
}
}
}
impl<B1: Backend, B2: Backend> Eq for MultiDevice2<B1, B2> {}
impl<B1: Backend, B2: Backend> Default for MultiDevice2<B1, B2> {
fn default() -> Self {
Self::Device1(B1::Device::default())
}
}
impl<B1: Backend, B2: Backend> DeviceOps for MultiDevice2<B1, B2> {
fn id(&self) -> DeviceId {
match self {
MultiDevice2::Device1(device) => device.id(),
MultiDevice2::Device2(device) => device.id(),
}
}
}
/// Local [`RunnerClient`] with two backends.
#[derive(Clone)]
pub enum MultiRunnerClient2<B1: ReprBackend, B2: ReprBackend> {
/// Client for the first backend runner.
RunnerClient1(Runner<B1>),
/// Client for the second backend runner.
RunnerClient2(Runner<B2>),
}
impl<B1: ReprBackend, B2: ReprBackend> RunnerClient for MultiRunnerClient2<B1, B2>
where
<<B1 as Backend>::FullPrecisionBridge as BackendBridge<B1>>::Target:
ReprBackend<Handle = B1::Handle>,
<<B2 as Backend>::FullPrecisionBridge as BackendBridge<B2>>::Target:
ReprBackend<Handle = B2::Handle>,
{
type Device = MultiDevice2<B1, B2>;
fn register(&self, op: OperationDescription) {
match self {
MultiRunnerClient2::RunnerClient1(runner) => runner.register(op),
MultiRunnerClient2::RunnerClient2(runner) => runner.register(op),
}
}
async fn read_tensor(&self, tensor: TensorDescription) -> TensorData {
match self {
MultiRunnerClient2::RunnerClient1(runner) => runner.read_tensor(tensor).await,
MultiRunnerClient2::RunnerClient2(runner) => runner.read_tensor(tensor).await,
}
}
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
match self {
MultiRunnerClient2::RunnerClient1(runner) => {
let desc = runner.register_tensor_data_desc(data);
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
}
MultiRunnerClient2::RunnerClient2(runner) => {
let desc = runner.register_tensor_data_desc(data);
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
}
}
}
fn register_empty_tensor(&self, shape: Vec<usize>, dtype: DType) -> RouterTensor<Self> {
match self {
MultiRunnerClient2::RunnerClient1(runner) => {
let desc = runner.register_empty_tensor_desc(shape, dtype);
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
}
MultiRunnerClient2::RunnerClient2(runner) => {
let desc = runner.register_empty_tensor_desc(shape, dtype);
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
}
}
}
fn register_float_tensor(&self, shape: Vec<usize>, full_precision: bool) -> RouterTensor<Self> {
match self {
MultiRunnerClient2::RunnerClient1(runner) => {
let desc = runner.register_float_tensor_desc(shape, full_precision);
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
}
MultiRunnerClient2::RunnerClient2(runner) => {
let desc = runner.register_float_tensor_desc(shape, full_precision);
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
}
}
}
fn device(&self) -> Self::Device {
match self {
MultiRunnerClient2::RunnerClient1(runner) => MultiDevice2::Device1(runner.device()),
MultiRunnerClient2::RunnerClient2(runner) => MultiDevice2::Device2(runner.device()),
}
}
fn register_orphan(&self, id: &TensorId) {
match self {
MultiRunnerClient2::RunnerClient1(runner) => runner.register_orphan(id),
MultiRunnerClient2::RunnerClient2(runner) => runner.register_orphan(id),
}
}
fn sync(&self) {
match self {
MultiRunnerClient2::RunnerClient1(runner) => runner.sync(),
MultiRunnerClient2::RunnerClient2(runner) => runner.sync(),
}
}
fn seed(&self, seed: u64) {
match self {
MultiRunnerClient2::RunnerClient1(runner) => runner.seed(seed),
MultiRunnerClient2::RunnerClient2(runner) => runner.seed(seed),
}
}
}
// NOTE: conflicting implementations because B1 and B2 cannot be differentiated (could be the same type)
// impl<B1: ReprBackend, B2: ReprBackend> From<RouterTensor<Runner<B1>>>
// for RouterTensor<MultiRunnerClient2<B1, B2>>
// {
// fn from(value: RouterTensor<Runner<B1>>) -> Self {
// RouterTensor {
// desc: value.desc,
// client: MultiRunnerClient2::RunnerClient1(value.client),
// }
// }
// }
// impl<B1: ReprBackend, B2: ReprBackend> From<RouterTensor<Runner<B2>>>
// for RouterTensor<MultiRunnerClient2<B1, B2>>
// {
// fn from(value: RouterTensor<Runner<B2>>) -> Self {
// RouterTensor {
// desc: value.desc,
// client: MultiRunnerClient2::RunnerClient2(value.client),
// }
// }
// }

View File

@ -0,0 +1,5 @@
mod base;
mod direct;
pub use base::*;
pub use direct::*;

View File

@ -0,0 +1,138 @@
use alloc::{boxed::Box, vec::Vec};
use core::{
future::Future,
ops::DerefMut,
sync::atomic::{AtomicBool, AtomicU64, Ordering},
};
use hashbrown::HashMap;
use spin::Mutex;
use burn_tensor::{
backend::{DeviceId, DeviceOps},
repr::{OperationDescription, TensorDescription, TensorId},
DType, TensorData,
};
use crate::{RouterTensor, RunnerChannel};
/// Type alias for `<R as RunnerChannel>::Client`.
pub type Client<R> = <R as RunnerChannel>::Client;
pub(crate) static CLIENTS: RunnerClientLocator = RunnerClientLocator::new();
static SEED_SET: AtomicBool = AtomicBool::new(false);
static SEED: AtomicU64 = AtomicU64::new(0);
type Key = (core::any::TypeId, DeviceId);
/// Define how to interact with the runner.
pub trait RunnerClient: Clone + Send + Sync + Sized {
/// Device type.
type Device: DeviceOps;
/// Register a new tensor operation to be executed by the (runner) server.
fn register(&self, op: OperationDescription);
/// Read the values contained by a tensor.
fn read_tensor(&self, tensor: TensorDescription) -> impl Future<Output = TensorData> + Send;
/// Create a new [RouterTensor] from the tensor data.
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self>;
/// Create a new [RouterTensor] with no resources associated.
fn register_empty_tensor(&self, shape: Vec<usize>, dtype: DType) -> RouterTensor<Self>;
/// Create a new float [RouterTensor] with no resources associated.
fn register_float_tensor(&self, shape: Vec<usize>, full_precision: bool) -> RouterTensor<Self>;
/// Get the current device used by all operations handled by this client.
fn device(&self) -> Self::Device;
/// Drop the tensor with the given [tensor id](TensorId).
fn register_orphan(&self, id: &TensorId);
/// Sync the runner, ensure that all computations are finished.
fn sync(&self);
/// Seed the runner.
fn seed(&self, seed: u64);
}
pub(crate) struct RunnerClientLocator {
clients: Mutex<Option<HashMap<Key, Box<dyn core::any::Any + Send>>>>,
}
pub(crate) fn get_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
CLIENTS.client::<R>(device)
}
pub(crate) fn set_seed(seed: u64) {
SEED_SET.store(true, Ordering::Relaxed);
SEED.store(seed, Ordering::Relaxed);
}
fn get_seed() -> Option<u64> {
if SEED_SET.load(Ordering::Relaxed) {
Some(SEED.load(Ordering::Relaxed))
} else {
None
}
}
/// Initialize a new client for the given device.
///
/// If a (global) seed was previously set, it is propagated to the new client.
fn new_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
let client = R::init_client(device);
if let Some(seed) = get_seed() {
client.seed(seed)
}
client
}
impl RunnerClientLocator {
/// Create a new client locator.
pub const fn new() -> Self {
Self {
clients: Mutex::new(None),
}
}
/// Get the runner client for the given device.
///
/// If a client isn't already initialized, it is created.
pub fn client<R: RunnerChannel + 'static>(&self, device: &R::Device) -> Client<R> {
let device_id = device.id();
let client_id = (core::any::TypeId::of::<R>(), device_id);
let mut clients = self.clients.lock();
if clients.is_none() {
let client = new_client::<R>(device);
Self::register_inner::<R>(client_id, client, &mut clients);
}
match clients.deref_mut() {
Some(clients) => match clients.get(&client_id) {
Some(client) => {
let client: &Client<R> = client.downcast_ref().unwrap();
client.clone()
}
None => {
let client = new_client::<R>(device);
let any = Box::new(client.clone());
clients.insert(client_id, any);
client
}
},
_ => unreachable!(),
}
}
fn register_inner<R: RunnerChannel + 'static>(
key: Key,
client: Client<R>,
clients: &mut Option<HashMap<Key, Box<dyn core::any::Any + Send>>>,
) {
if clients.is_none() {
*clients = Some(HashMap::new());
}
if let Some(clients) = clients {
if clients.contains_key(&key) {
panic!("Client already created for device {:?}", key);
}
clients.insert(key, Box::new(client));
}
}
}
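// Illustrative sketch (not part of this diff): the type-erasure trick behind
// `RunnerClientLocator`. Clients of different concrete types share one map as
// `Box<dyn Any + Send>`, keyed by `TypeId`, and are recovered via `downcast_ref`.
#[cfg(test)]
mod locator_sketch {
    use alloc::{boxed::Box, string::String};
    use core::any::{Any, TypeId};
    use hashbrown::HashMap;

    #[test]
    fn type_erased_registry_roundtrip() {
        let mut clients: HashMap<TypeId, Box<dyn Any + Send>> = HashMap::new();
        clients.insert(TypeId::of::<String>(), Box::new(String::from("client-a")));

        let client = clients
            .get(&TypeId::of::<String>())
            .and_then(|boxed| boxed.downcast_ref::<String>())
            .expect("client registered under this type id");
        assert_eq!(client.as_str(), "client-a");
    }
}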

View File

@ -0,0 +1,3 @@
mod base;
pub use base::*;

View File

@ -0,0 +1,50 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
//! Burn multi-backend router.
mod backend;
mod bridge;
mod channel;
mod client;
mod ops;
mod runner;
mod tensor;
pub use backend::*;
pub use bridge::*;
pub use channel::*;
pub use client::*;
pub use runner::*;
pub use tensor::*;
extern crate alloc;
#[cfg(test)]
mod tests {
use alloc::format;
use alloc::vec;
use crate::BackendRouter;
use crate::ByteBridge;
use crate::DirectChannel;
type DirectByteChannel<Backends> = DirectChannel<Backends, ByteBridge<Backends>>;
pub type TestBackend1 = burn_ndarray::NdArray<f32, i32>;
pub type TestBackend2 = burn_wgpu::Wgpu<f32, i32>;
pub type TestBackend = BackendRouter<DirectByteChannel<(TestBackend1, TestBackend2)>>;
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
pub type TestTensorBool<const D: usize> =
burn_tensor::Tensor<TestBackend, D, burn_tensor::Bool>;
burn_tensor::testgen_all!();
// TODO: add support for quantization
// burn_tensor::testgen_quantization!();
#[cfg(feature = "std")]
burn_autodiff::testgen_all!();
}

View File

@ -0,0 +1,55 @@
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);
let rhs = $handles.get_float_tensor::<B>(&$desc.rhs);
let output = $ops(lhs, rhs);
$handles.register_float_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_cmp_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);
let rhs = $handles.get_float_tensor::<B>(&$desc.rhs);
let output = $ops(lhs, rhs);
$handles.register_bool_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);
let rhs = $handles.get_int_tensor::<B>(&$desc.rhs);
let output = $ops(lhs, rhs);
$handles.register_int_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_cmp_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);
let rhs = $handles.get_int_tensor::<B>(&$desc.rhs);
let output = $ops(lhs, rhs);
$handles.register_bool_tensor::<B>(&$desc.out.id, output);
}};
}
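// Illustrative sketch (not part of this diff): these macros assume an
// in-scope backend type parameter `B: ReprBackend`, a handle container and an
// operation description. A hypothetical call site (the real dispatch lives in
// the runner, whose diff is suppressed in this commit) could look like:
//
// fn run_float_add<B: burn_tensor::repr::ReprBackend>(
//     handles: &mut burn_tensor::repr::HandleContainer<B::Handle>,
//     desc: &burn_tensor::repr::BinaryOperationDescription,
// ) {
//     binary_float_ops!(handles, desc, B::float_add);
// }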

View File

@ -0,0 +1,8 @@
mod binary;
mod op_activation;
mod op_bool;
mod op_float;
mod op_int;
mod op_module;
mod op_qfloat;
mod unary;

View File

@ -0,0 +1,4 @@
use crate::{BackendRouter, RunnerChannel};
use burn_tensor::ops::ActivationOps;
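// `ActivationOps` provides default implementations in terms of the float
// tensor ops, so an empty impl is sufficient here.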
impl<R: RunnerChannel> ActivationOps<Self> for BackendRouter<R> {}

View File

@ -0,0 +1,302 @@
use alloc::vec::Vec;
use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntElem, IntTensor};
use burn_tensor::repr::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
OperationDescription, PermuteOperationDescription, RepeatDimOperationDescription,
ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription,
SwapDimsDescription, UnaryOperationDescription,
};
use burn_tensor::{DType, Device, Element, Shape, TensorData};
use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient};
impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
// Get the runtime client on which to register the operation for execution.
let client = get_client::<R>(device);
let out = client.register_empty_tensor(shape.into(), DType::Bool);
client.register(OperationDescription::BaseBool(
BaseOperationDescription::Empty(out.to_description_out()),
));
out
}
fn bool_shape(tensor: &BoolTensor<Self>) -> Shape {
Shape::from(tensor.shape.clone())
}
async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
tensor.into_data().await
}
fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
let client = get_client::<R>(device);
client.register_tensor_data(data.convert::<bool>())
}
fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
let client = tensor.client.clone();
let out = client.register_empty_tensor(tensor.shape.clone(), IntElem::<Self>::dtype());
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::Bool(
BoolOperationDescription::IntoInt(desc),
));
out
}
fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
let client = tensor.client.clone();
let out = client.register_empty_tensor(tensor.shape.clone(), FloatElem::<Self>::dtype());
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::Bool(
BoolOperationDescription::IntoFloat(desc),
));
out
}
fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
tensor.client.device()
}
fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
if &tensor.client.device() == device {
return tensor;
}
R::change_client_backend(tensor, device)
}
fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
let client = tensor.client.clone();
let out = client.register_empty_tensor(shape.into(), tensor.dtype);
let desc = ReshapeDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::BaseBool(
BaseOperationDescription::Reshape(desc),
));
out
}
fn bool_slice(
tensor: BoolTensor<Self>,
ranges: &[core::ops::Range<usize>],
) -> BoolTensor<Self> {
let client = tensor.client.clone();
let ndims = tensor.shape().num_dims();
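// Output shape: sliced dims take their range length; trailing dims not
// covered by `ranges` keep their full extent.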
let mut shape: Vec<usize> = ranges.iter().map(|range| range.end - range.start).collect();
for i in shape.len()..ndims {
shape.push(tensor.shape[i]);
}
let out = client.register_empty_tensor(shape, tensor.dtype);
let desc = SliceOperationDescription {
tensor: tensor.into_description(),
ranges: ranges.to_vec(),
out: out.to_description_out(),
};
client.register(OperationDescription::BaseBool(
BaseOperationDescription::Slice(desc),
));
out
}
fn bool_slice_assign(
tensor: BoolTensor<Self>,
ranges: &[core::ops::Range<usize>],
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
let client = tensor.client.clone();
let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype);
let desc = SliceAssignOperationDescription {
tensor: tensor.into_description(),
ranges: ranges.to_vec(),
value: value.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::BaseBool(
BaseOperationDescription::SliceAssign(desc),
));
out
}
fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
let client = lhs.client.clone();
let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool);
let desc = BinaryOperationDescription {
lhs: lhs.into_description(),
rhs: rhs.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::BaseBool(
BaseOperationDescription::Equal(desc),
));
out
}
fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
let client = tensor.client.clone();
let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype);
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::Bool(BoolOperationDescription::Not(
desc,
)));
out
}
fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
let client = tensor.client.clone();
let mut shape = tensor.shape.clone();
shape[dim1] = tensor.shape[dim2];
shape[dim2] = tensor.shape[dim1];
let out = client.register_empty_tensor(shape, tensor.dtype);
let desc = SwapDimsDescription {
input: tensor.into_description(),
out: out.to_description_out(),
dim1,
dim2,
};
client.register(OperationDescription::BaseBool(
BaseOperationDescription::SwapDims(desc),
));
out
}
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
let client = tensor.client.clone();
// Change the shape of the tensor to match the new axes
let shape = axes.iter().map(|x| tensor.shape[*x]).collect();
let out = client.register_empty_tensor(shape, tensor.dtype);
let desc = PermuteOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
axes: axes.to_vec(),
};
client.register(OperationDescription::BaseBool(
BaseOperationDescription::Permute(desc),
));
out
}
fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
let client = tensor.client.clone();
let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype);
let desc = FlipOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
axes: axes.to_vec(),
};
client.register(OperationDescription::BaseBool(
BaseOperationDescription::Flip(desc),
));
out
}
fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
let client = tensor.client.clone();
let shape: Vec<_> = shape.into();
let out = client.register_empty_tensor(shape.clone(), tensor.dtype);
let desc = ExpandOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
shape,
};
client.register(OperationDescription::BaseBool(
BaseOperationDescription::Expand(desc),
));
out
}
fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
let tensor_first = tensors.first().unwrap();
let client = tensor_first.client.clone();
let dtype = tensor_first.dtype;
// Calculate the output shape
let mut shape = tensor_first.shape.clone();
shape[dim] = 0;
for tensor in tensors.iter() {
shape[dim] += tensor.shape[dim];
}
let out = client.register_empty_tensor(shape, dtype);
let desc = CatOperationDescription {
tensors: tensors.into_iter().map(|t| t.into_description()).collect(),
dim,
out: out.to_description_out(),
};
client.register(OperationDescription::BaseBool(
BaseOperationDescription::Cat(desc),
));
out
}
fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
let client = tensor.client.clone();
let mut shape = tensor.shape.clone();
shape[dim] *= times;
let out = client.register_empty_tensor(shape, tensor.dtype);
let desc = RepeatDimOperationDescription {
tensor: tensor.into_description(),
dim,
times,
out: out.to_description_out(),
};
client.register(OperationDescription::BaseBool(
BaseOperationDescription::RepeatDim(desc),
));
out
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,821 @@
use alloc::{boxed::Box, vec};
use burn_tensor::ops::conv::{
calculate_conv_output_size, calculate_conv_transpose_output_size, calculate_pool_output_size,
};
use burn_tensor::ops::{
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
IntElem, ModuleOps,
};
use burn_tensor::ops::{
IntTensor, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
MaxPool2dWithIndices,
};
use burn_tensor::repr::{
AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription,
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription,
AvgPool2dDescription, Conv1dDescription, Conv2dDescription, Conv3dDescription,
ConvTranspose1dDescription, ConvTranspose2dDescription, ConvTranspose3dDescription,
DeformConv2dBackwardDescription, DeformConv2dDescription, InterpolateBackwardDescription,
InterpolateDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
MaxPool1dWithIndicesDescription, MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
MaxPool2dWithIndicesDescription, ModuleOperationDescription, OperationDescription,
};
use burn_tensor::Element;
use crate::{BackendRouter, RunnerChannel, RunnerClient};
impl<R: RunnerChannel> ModuleOps<Self> for BackendRouter<R> {
fn conv1d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: ConvOptions<1>,
) -> FloatTensor<Self> {
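// Output length follows the standard convolution formula:
// floor((l_in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1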
let size = calculate_conv_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.dilation[0],
x.shape[2],
);
let shape = vec![x.shape[0], weight.shape[0], size];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = Conv1dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::Conv1d(desc),
));
out
}
fn conv2d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: ConvOptions<2>,
) -> FloatTensor<Self> {
let size_0 = calculate_conv_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.dilation[0],
x.shape[2],
);
let size_1 = calculate_conv_output_size(
weight.shape[3],
options.stride[1],
options.padding[1],
options.dilation[1],
x.shape[3],
);
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = Conv2dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::Conv2d(desc),
));
out
}
fn conv3d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: ConvOptions<3>,
) -> FloatTensor<Self> {
let size_0 = calculate_conv_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.dilation[0],
x.shape[2],
);
let size_1 = calculate_conv_output_size(
weight.shape[3],
options.stride[1],
options.padding[1],
options.dilation[1],
x.shape[3],
);
let size_2 = calculate_conv_output_size(
weight.shape[4],
options.stride[2],
options.padding[2],
options.dilation[2],
x.shape[4],
);
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1, size_2];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = Conv3dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::Conv3d(desc),
));
out
}
fn conv_transpose1d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<Self> {
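// Transposed convolution output length:
// (l_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + padding_out + 1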
let size = calculate_conv_transpose_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.padding_out[0],
options.dilation[0],
x.shape[2],
);
let shape = vec![x.shape[0], weight.shape[1] * options.groups, size];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = ConvTranspose1dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::ConvTranspose1d(desc),
));
out
}
fn conv_transpose2d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<Self> {
let size_0 = calculate_conv_transpose_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.padding_out[0],
options.dilation[0],
x.shape[2],
);
let size_1 = calculate_conv_transpose_output_size(
weight.shape[3],
options.stride[1],
options.padding[1],
options.padding_out[1],
options.dilation[1],
x.shape[3],
);
let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = ConvTranspose2dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::ConvTranspose2d(desc),
));
out
}
fn conv_transpose3d(
x: FloatTensor<Self>,
weight: FloatTensor<Self>,
bias: Option<FloatTensor<Self>>,
options: ConvTransposeOptions<3>,
) -> FloatTensor<Self> {
let size_0 = calculate_conv_transpose_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.padding_out[0],
options.dilation[0],
x.shape[2],
);
let size_1 = calculate_conv_transpose_output_size(
weight.shape[3],
options.stride[1],
options.padding[1],
options.padding_out[1],
options.dilation[1],
x.shape[3],
);
let size_2 = calculate_conv_transpose_output_size(
weight.shape[4],
options.stride[2],
options.padding[2],
options.padding_out[2],
options.dilation[2],
x.shape[4],
);
let shape = vec![
x.shape[0],
weight.shape[1] * options.groups,
size_0,
size_1,
size_2,
];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = ConvTranspose3dDescription {
x: x.into_description(),
weight: weight.into_description(),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::ConvTranspose3d(desc),
));
out
}
fn avg_pool1d(
x: FloatTensor<Self>,
kernel_size: usize,
stride: usize,
padding: usize,
count_include_pad: bool,
) -> FloatTensor<Self> {
let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]);
let shape = vec![x.shape[0], x.shape[1], size];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = AvgPool1dDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
count_include_pad,
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::AvgPool1d(desc),
));
out
}
fn avg_pool2d(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
) -> FloatTensor<Self> {
let size_0 =
calculate_pool_output_size(kernel_size[0], stride[0], padding[0], 1, x.shape[2]);
let size_1 =
calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]);
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = AvgPool2dDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
count_include_pad,
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::AvgPool2d(desc),
));
out
}
fn avg_pool1d_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
kernel_size: usize,
stride: usize,
padding: usize,
count_include_pad: bool,
) -> FloatTensor<Self> {
let client = x.client.clone();
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
let desc = AvgPool1dBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
kernel_size,
stride,
padding,
count_include_pad,
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::AvgPool1dBackward(desc),
));
out
}
fn avg_pool2d_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
count_include_pad: bool,
) -> FloatTensor<Self> {
let client = x.client.clone();
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
let desc = AvgPool2dBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
kernel_size,
stride,
padding,
count_include_pad,
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::AvgPool2dBackward(desc),
));
out
}
fn max_pool1d(
x: FloatTensor<Self>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> FloatTensor<Self> {
let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
let shape = vec![x.shape[0], x.shape[1], size];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = MaxPool1dDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::MaxPool1d(desc),
));
out
}
fn max_pool2d(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> FloatTensor<Self> {
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation[0],
x.shape[2],
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation[1],
x.shape[3],
);
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = MaxPool2dDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::MaxPool2d(desc),
));
out
}
fn max_pool1d_with_indices(
x: FloatTensor<Self>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> MaxPool1dWithIndices<Self> {
let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
let shape = vec![x.shape[0], x.shape[1], size];
let client = x.client.clone();
let out = client.register_empty_tensor(shape.clone(), x.dtype);
let out_indices = client.register_empty_tensor(shape, IntElem::<Self>::dtype());
let desc = MaxPool1dWithIndicesDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
out_indices: out_indices.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::MaxPool1dWithIndices(desc),
));
MaxPool1dWithIndices::new(out, out_indices)
}
fn max_pool2d_with_indices(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> MaxPool2dWithIndices<Self> {
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation[0],
x.shape[2],
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation[1],
x.shape[3],
);
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
let client = x.client.clone();
let out = client.register_empty_tensor(shape.clone(), x.dtype);
let out_indices = client.register_empty_tensor(shape, IntElem::<Self>::dtype());
let desc = MaxPool2dWithIndicesDescription {
x: x.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
out_indices: out_indices.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::MaxPool2dWithIndices(desc),
));
MaxPool2dWithIndices::new(out, out_indices)
}
fn max_pool1d_with_indices_backward(
x: FloatTensor<Self>,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
output_grad: FloatTensor<Self>,
indices: IntTensor<Self>,
) -> MaxPool1dBackward<Self> {
let client = x.client.clone();
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
let desc = MaxPool1dWithIndicesBackwardDescription {
x: x.into_description(),
grad: output_grad.into_description(),
indices: indices.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::MaxPool1dWithIndicesBackward(desc),
));
MaxPool1dBackward::new(out)
}
fn max_pool2d_with_indices_backward(
x: FloatTensor<Self>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: FloatTensor<Self>,
indices: IntTensor<Self>,
) -> MaxPool2dBackward<Self> {
let client = x.client.clone();
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
let desc = MaxPool2dWithIndicesBackwardDescription {
x: x.into_description(),
grad: output_grad.into_description(),
indices: indices.into_description(),
kernel_size,
stride,
padding,
dilation,
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::MaxPool2dWithIndicesBackward(desc),
));
MaxPool2dBackward::new(out)
}
fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
let shape = vec![x.shape[0], x.shape[1], output_size];
let client = x.client.clone();
let out = client.register_empty_tensor(shape.clone(), x.dtype);
let desc = AdaptiveAvgPool1dDescription {
x: x.into_description(),
output_size,
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::AdaptiveAvgPool1d(desc),
));
out
}
fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
let client = x.client.clone();
let out = client.register_empty_tensor(shape.clone(), x.dtype);
let desc = AdaptiveAvgPool2dDescription {
x: x.into_description(),
output_size,
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::AdaptiveAvgPool2d(desc),
));
out
}
fn adaptive_avg_pool1d_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
let client = x.client.clone();
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
let desc = AdaptiveAvgPool1dBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc),
));
out
}
fn adaptive_avg_pool2d_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
) -> FloatTensor<Self> {
let client = x.client.clone();
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
let desc = AdaptiveAvgPool2dBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc),
));
out
}
fn interpolate(
x: FloatTensor<Self>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self> {
let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
let client = x.client.clone();
let out = client.register_empty_tensor(shape.clone(), x.dtype);
let desc = InterpolateDescription {
x: x.into_description(),
output_size,
options: options.into(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::Interpolate(desc),
));
out
}
fn interpolate_backward(
x: FloatTensor<Self>,
grad: FloatTensor<Self>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self> {
let client = x.client.clone();
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
let desc = InterpolateBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
output_size,
options: options.into(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::InterpolateBackward(desc),
));
out
}
fn deform_conv2d(
x: FloatTensor<Self>,
offset: FloatTensor<Self>,
weight: FloatTensor<Self>,
mask: Option<FloatTensor<Self>>,
bias: Option<FloatTensor<Self>>,
options: DeformConvOptions<2>,
) -> FloatTensor<Self> {
let size_0 = calculate_conv_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.dilation[0],
x.shape[2],
);
let size_1 = calculate_conv_output_size(
weight.shape[3],
options.stride[1],
options.padding[1],
options.dilation[1],
x.shape[3],
);
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
let client = x.client.clone();
let out = client.register_empty_tensor(shape, x.dtype);
let desc = DeformConv2dDescription {
x: x.into_description(),
offset: offset.into_description(),
weight: weight.into_description(),
mask: mask.map(|mask| mask.into_description()),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::DeformableConv2d(Box::new(desc)),
));
out
}
fn deform_conv2d_backward(
x: FloatTensor<Self>,
offset: FloatTensor<Self>,
weight: FloatTensor<Self>,
mask: Option<FloatTensor<Self>>,
bias: Option<FloatTensor<Self>>,
output_grad: FloatTensor<Self>,
options: DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
let client = x.client.clone();
let input_grad = client.register_empty_tensor(x.shape.clone(), x.dtype);
let offset_grad = client.register_empty_tensor(offset.shape.clone(), offset.dtype);
let weight_grad = client.register_empty_tensor(weight.shape.clone(), weight.dtype);
let mask_grad = mask
.as_ref()
.map(|mask| client.register_empty_tensor(mask.shape.clone(), mask.dtype));
let bias_grad = bias
.as_ref()
.map(|bias| client.register_empty_tensor(bias.shape.clone(), bias.dtype));
let desc = DeformConv2dBackwardDescription {
x: x.into_description(),
offset: offset.into_description(),
weight: weight.into_description(),
mask: mask.map(|mask| mask.into_description()),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out_grad: output_grad.into_description(),
input_grad: input_grad.to_description_out(),
offset_grad: offset_grad.to_description_out(),
weight_grad: weight_grad.to_description_out(),
mask_grad: mask_grad
.as_ref()
.map(|mask_grad| mask_grad.to_description_out()),
bias_grad: bias_grad
.as_ref()
.map(|bias_grad| bias_grad.to_description_out()),
};
client.register(OperationDescription::Module(
ModuleOperationDescription::DeformableConv2dBackward(Box::new(desc)),
));
DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
}
}

View File

@ -0,0 +1,97 @@
use core::ops::Range;
use burn_tensor::{
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
Device, Shape, TensorData,
};
use crate::{BackendRouter, RunnerChannel};
impl<R: RunnerChannel> QTensorOps<Self> for BackendRouter<R> {
fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
unimplemented!()
}
fn quantize(
_tensor: FloatTensor<Self>,
_scheme: &QuantizationScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn quantize_dynamic(
_tensor: FloatTensor<Self>,
_scheme: &QuantizationScheme,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
unimplemented!()
}
fn q_shape(_tensor: &QuantizedTensor<Self>) -> Shape {
unimplemented!()
}
fn q_device(_tensor: &QuantizedTensor<Self>) -> Device<Self> {
unimplemented!()
}
fn q_to_device(
_tensor: QuantizedTensor<Self>,
_device: &Device<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
unimplemented!()
}
async fn q_into_data(_tensor: QuantizedTensor<Self>) -> TensorData {
unimplemented!()
}
fn q_swap_dims(
_tensor: QuantizedTensor<Self>,
_dim1: usize,
_dim2: usize,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_gather(
_dim: usize,
_tensor: QuantizedTensor<Self>,
_indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_select(
_tensor: QuantizedTensor<Self>,
_dim: usize,
_indices: IntTensor<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_slice(_tensor: QuantizedTensor<Self>, _ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
unimplemented!()
}
fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
unimplemented!()
}
}

View File

@ -0,0 +1,129 @@
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);
let output = $ops(lhs, ElementConversion::elem($desc.rhs));
$handles.register_float_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_dim_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);
let output = $ops(lhs, $desc.rhs);
$handles.register_float_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float2int_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);
let output = $ops(lhs, $desc.rhs);
$handles.register_int_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_cmp_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);
let output = $ops(lhs, ElementConversion::elem($desc.rhs));
$handles.register_bool_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! unary_float_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_float_tensor::<B>(&$desc.input);
let output = $ops(lhs);
$handles.register_float_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);
let output = $ops(lhs, ElementConversion::elem($desc.rhs));
$handles.register_int_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! int_float_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);
let output = $ops(lhs, ElementConversion::elem($desc.rhs));
$handles.register_int_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_dim_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);
let output = $ops(lhs, $desc.rhs);
$handles.register_int_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_cmp_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);
let output = $ops(lhs, ElementConversion::elem($desc.rhs));
$handles.register_bool_tensor::<B>(&$desc.out.id, output);
}};
}
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! unary_int_ops {
(
$handles:expr, $desc:expr, $ops:expr
) => {{
let lhs = $handles.get_int_tensor::<B>(&$desc.input);
let output = $ops(lhs);
$handles.register_int_tensor::<B>(&$desc.out.id, output);
}};
}

File diff suppressed because it is too large

View File

@ -0,0 +1,122 @@
use alloc::{sync::Arc, vec::Vec};
use super::RunnerClient;
use burn_tensor::{
repr::{TensorDescription, TensorId, TensorStatus},
DType, Shape, TensorData,
};
/// Tensor primitive for the [router backend](crate::BackendRouter).
pub struct RouterTensor<C: RunnerClient> {
pub(crate) id: Arc<TensorId>,
pub(crate) shape: Vec<usize>,
pub(crate) dtype: DType,
pub(crate) client: C,
// Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`.
//
// When a tensor is dropped and is still an orphan, we need to register it as such to avoid
// a memory leak.
pub(crate) is_orphan: bool,
}
impl<C: RunnerClient> RouterTensor<C> {
pub(crate) fn new(id: Arc<TensorId>, shape: Vec<usize>, dtype: DType, client: C) -> Self {
Self {
id,
shape,
dtype,
client,
is_orphan: true,
}
}
pub(crate) async fn into_data(self) -> TensorData {
self.client
.clone()
.read_tensor(self.into_description())
.await
}
pub(crate) fn into_description(mut self) -> TensorDescription {
let status = self.status();
let mut shape_out = Vec::new();
core::mem::swap(&mut self.shape, &mut shape_out);
if let TensorStatus::ReadWrite = status {
self.is_orphan = false;
}
TensorDescription {
status,
shape: shape_out,
id: *self.id.as_ref(),
dtype: self.dtype,
}
}
pub(crate) fn to_description_out(&self) -> TensorDescription {
TensorDescription {
status: TensorStatus::NotInit,
shape: self.shape.clone(),
id: *self.id.as_ref(),
dtype: self.dtype,
}
}
pub(crate) fn shape(&self) -> Shape {
Shape::from(self.shape.clone())
}
pub(crate) fn status(&self) -> TensorStatus {
if Arc::strong_count(&self.id) <= 1 {
TensorStatus::ReadWrite
} else {
TensorStatus::ReadOnly
}
}
}
impl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(
format!(
"{{ id: {:?}, shape: {:?}, dtype: {:?}, should_drop: {:?}, device: {:?} }}",
self.id,
self.shape,
self.dtype,
self.is_orphan,
self.client.device().clone(),
)
.as_str(),
)
}
}
impl<C: RunnerClient> Clone for RouterTensor<C> {
fn clone(&self) -> Self {
Self {
id: self.id.clone(),
shape: self.shape.clone(),
client: self.client.clone(),
dtype: self.dtype,
is_orphan: self.is_orphan,
}
}
}
impl<C: RunnerClient> Drop for RouterTensor<C> {
fn drop(&mut self) {
if !self.is_orphan {
return;
}
match self.status() {
TensorStatus::ReadWrite => {
self.client.register_orphan(&self.id);
}
TensorStatus::ReadOnly => {}
TensorStatus::NotInit => {}
}
}
}
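The `status` logic above leans entirely on the `Arc` strong count of the tensor id: a sole owner means the tensor can be consumed in place (`ReadWrite`), while any live clone forces `ReadOnly`. A standalone sketch of the same idea, with simplified types rather than the actual RouterTensor:

use std::sync::Arc;

fn status(id: &Arc<u64>) -> &'static str {
    // Mirrors RouterTensor::status: one owner => safe to consume in place.
    if Arc::strong_count(id) <= 1 {
        "ReadWrite"
    } else {
        "ReadOnly"
    }
}

fn main() {
    let id = Arc::new(42u64);
    assert_eq!(status(&id), "ReadWrite"); // single owner
    let shared = id.clone();
    assert_eq!(status(&id), "ReadOnly"); // a clone is alive elsewhere
    drop(shared);
    assert_eq!(status(&id), "ReadWrite"); // sole owner again
}

This is also why `into_description` flips `is_orphan` to `false` only on the `ReadWrite` path: that is the one case where ownership is actually handed over.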

View File

@ -49,6 +49,9 @@ hashbrown = { workspace = true } # no_std compatible
serde = { workspace = true }
serde_bytes = { workspace = true }
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }
[dev-dependencies]
rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std

View File

@ -4,16 +4,27 @@ use crate::{
quantization::QuantizationScheme,
Shape,
};
use alloc::vec::Vec;
/// A tensor representation containing a reference to a tensor resource with a given shape.
pub struct TensorHandle<H> {
#[derive(Clone)]
pub struct TensorHandle<H: Clone> {
/// The type that can be used to point to a tensor of any kind.
pub handle: H,
/// The shape associated to the tensor.
pub shape: Shape,
}
/// A simple struct to encapsulate a quantized tensor kind.
#[derive(Clone)]
pub struct QuantizedKind<T: Clone> {
/// The quantized tensor.
pub tensor: T,
/// The scaling factor.
pub scale: T,
/// The zero-point offset.
pub offset: Option<T>,
}
/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation
/// for compilation purposes or other use cases.
pub trait ReprBackend: Backend {
@ -28,7 +39,7 @@ pub trait ReprBackend: Backend {
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
/// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive).
fn quantized_tensor(
handles: Vec<TensorHandle<Self::Handle>>,
handle: QuantizedKind<TensorHandle<Self::Handle>>,
scheme: QuantizationScheme,
) -> QuantizedTensor<Self>;
@ -40,5 +51,33 @@ pub trait ReprBackend: Backend {
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;
/// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle).
/// A quantized tensor has multiple handles for the tensor itself and the quantization parameters.
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Vec<Self::Handle>;
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle>;
}
/// Handle which points to a backend tensor primitive kind.
#[derive(Clone, Debug)]
pub enum HandleKind<B: Backend> {
/// Float tensor handle.
Float(B::FloatTensorPrimitive),
/// Int tensor handle.
Int(B::IntTensorPrimitive),
/// Bool tensor handle.
Bool(B::BoolTensorPrimitive),
/// Quantized tensor handle.
Quantized(B::QuantizedTensorPrimitive),
/// Empty handle (used as a dummy representation).
Empty,
}
impl<B: Backend> HandleKind<B> {
/// Returns the handle kind name.
pub fn name(&self) -> &str {
match self {
HandleKind::Float(_) => "float",
HandleKind::Int(_) => "int",
HandleKind::Bool(_) => "bool",
HandleKind::Quantized(_) => "quantized",
HandleKind::Empty => unreachable!(), // should not happen
}
}
}
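Replacing the `Vec<Handle>` with `QuantizedKind` makes the tensor/scale/offset roles explicit instead of positional. A hypothetical helper, not part of this commit, shows how uniformly the three parts can now be transformed:

// Hypothetical free function, assuming the `QuantizedKind<T>` defined above.
fn map_quantized<T: Clone, U: Clone>(
    kind: QuantizedKind<T>,
    f: impl Fn(T) -> U,
) -> QuantizedKind<U> {
    QuantizedKind {
        tensor: f(kind.tensor),
        scale: f(kind.scale),
        // The zero-point offset is absent for symmetric quantization schemes.
        offset: kind.offset.map(f),
    }
}

With the old `Vec` representation, the same transformation had to track which index meant what and whether the offset slot existed at all.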

View File

@ -5,9 +5,16 @@ use crate::{
},
Shape,
};
use std::{collections::HashMap, sync::Arc};
use alloc::vec::Vec;
use hashbrown::HashMap;
use super::{QuantizedTensorDescription, TensorHandle};
#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;
use super::{QuantizedKind, QuantizedTensorDescription, TensorHandle};
/// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources
/// are used optimally.
@ -19,6 +26,16 @@ pub struct HandleContainer<H> {
pub handles_orphan: Vec<TensorId>,
}
impl<H> core::fmt::Debug for HandleContainer<H> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("HandleContainer")
.field("handles", &self.handles.keys()) // only care about the IDs when debugging
.field("counter", &self.counter)
.field("handles_orphan", &self.handles_orphan)
.finish()
}
}
/// Backend [tensor handle](ReprBackend::Handle) wrapper, tracking its creation state.
pub enum Handle<H> {
/// No [tensor handle](ReprBackend::Handle) has been created yet
@ -69,7 +86,7 @@ impl<H: Clone> HandleContainer<H> {
}
/// Get the tensor handle for the given [tensor description](TensorDescription).
fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle<H> {
pub fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle<H> {
TensorHandle {
handle: self.get_handle(&tensor.id, &tensor.status),
shape: Shape::from(&tensor.shape),
@ -112,12 +129,14 @@ impl<H: Clone> HandleContainer<H> {
where
B: ReprBackend<Handle = H>,
{
let qtensor = self.get_tensor_handle(&tensor.tensor);
let scale = self.get_tensor_handle(&tensor.qparams.scale);
let handles = if let Some(offset) = &tensor.qparams.offset {
vec![qtensor, scale, self.get_tensor_handle(offset)]
} else {
vec![qtensor, scale]
let handles = QuantizedKind {
tensor: self.get_tensor_handle(&tensor.tensor),
scale: self.get_tensor_handle(&tensor.qparams.scale),
offset: tensor
.qparams
.offset
.as_ref()
.map(|offset| self.get_tensor_handle(offset)),
};
B::quantized_tensor(handles, tensor.scheme.clone())
}
@ -134,20 +153,20 @@ impl<H: Clone> HandleContainer<H> {
/// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).
pub fn register_quantized_tensor<B>(
&mut self,
ids: &[&TensorId],
id: &QuantizedKind<TensorId>,
tensor: B::QuantizedTensorPrimitive,
) where
B: ReprBackend<Handle = H>,
{
let handles = B::quantized_tensor_handle(tensor);
assert_eq!(
ids.len(),
handles.len(),
"Number of tensor ids and handles must match"
);
for (handle, id) in handles.into_iter().zip(ids) {
self.handles.insert(**id, Handle::Existing(handle));
self.handles
.insert(id.tensor, Handle::Existing(handles.tensor));
self.handles
.insert(id.scale, Handle::Existing(handles.scale));
if let (Some(id), Some(handle)) = (id.offset, handles.offset) {
self.handles.insert(id, Handle::Existing(handle));
}
}
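The new signature takes the ids grouped the same way as the handles, so the pairing can no longer drift out of sync the way the old `&[&TensorId]` slice could. A hedged sketch of the call shape, where the ids and the primitive are illustrative placeholders:

// Illustrative usage: the ids come from the output tensor descriptions.
let ids = QuantizedKind {
    tensor: out_id,
    scale: scale_id,
    offset: Some(offset_id), // `None` for symmetric quantization schemes
};
handles.register_quantized_tensor::<B>(&ids, quantized_primitive);

Note that the `if let (Some(id), Some(handle))` arm silently skips a mismatched pair; an alternative design would assert that both sides agree on whether an offset exists.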

View File

@ -1,5 +1,8 @@
use core::ops::Range;
use serde::{Deserialize, Serialize};
use std::ops::Range;
use alloc::boxed::Box;
use alloc::{vec, vec::Vec};
use crate::{
ops::{
@ -214,6 +217,13 @@ pub enum BaseOperationDescription {
Cat(CatOperationDescription),
/// Cast operation, no direct operation and should be supported by fusion backend.
Cast(UnaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [empty](crate::ops::FloatTensorOps::float_empty).
/// Int => [empty](crate::ops::IntTensorOps::int_empty).
/// Bool => [empty](crate::ops::BoolTensorOps::bool_empty).
Empty(TensorDescription),
}
/// Numeric operations on int and float tensors.
@ -1286,6 +1296,7 @@ impl BaseOperationDescription {
}
BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(),
BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out],
BaseOperationDescription::Empty(desc) => vec![desc],
}
}
}
@ -1605,7 +1616,7 @@ impl ModuleOperationDescription {
}
impl core::hash::Hash for RandomOperationDescription {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.out.hash(state);
match self.distribution {
@ -1618,14 +1629,14 @@ impl core::hash::Hash for RandomOperationDescription {
}
impl<E> core::hash::Hash for ScalarOperationDescription<E> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.lhs.hash(state);
self.out.hash(state);
}
}
impl<E> core::hash::Hash for MaskFillOperationDescription<E> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.tensor.hash(state);
self.mask.hash(state);
self.out.hash(state);
@ -1633,14 +1644,14 @@ impl<E> core::hash::Hash for MaskFillOperationDescription<E> {
}
impl<E> core::hash::Hash for ClampOperationDescription<E> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.tensor.hash(state);
self.out.hash(state);
}
}
impl<E> core::hash::Hash for NumericOperationDescription<E> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
match self {
NumericOperationDescription::Add(desc) => desc.hash(state),
NumericOperationDescription::AddScalar(desc) => desc.hash(state),
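Swapping `std::hash::Hasher` for `core::hash::Hasher` is what keeps these impls no_std-compatible; it is the same trait, re-exported by `std` from `core`. A minimal standalone illustration of the pattern, using a hypothetical `Desc` type:

use core::hash::{Hash, Hasher}; // no `std` needed

struct Desc {
    lhs: u64,
    out: u64,
}

impl Hash for Desc {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Only the identity fields participate, mirroring the impls above,
        // which hash the tensor descriptions but skip the scalar payload.
        self.lhs.hash(state);
        self.out.hash(state);
    }
}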

View File

@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
use alloc::vec::Vec;
use crate::DType;
/// The tensor unique identifier.

View File

@ -59,8 +59,8 @@ pub trait Backend:
+ ActivationOps<Self>
+ QTensorOps<Self>
+ Clone
+ Sized
+ Default
+ Sized
+ Send
+ Sync
+ core::fmt::Debug

View File

@ -258,3 +258,19 @@ pub enum DType {
Bool,
QFloat(QuantizationStrategy),
}
impl DType {
/// Returns true if the data type is a floating point type.
pub fn is_float(&self) -> bool {
matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)
}
/// Returns true if the data type is a signed integer type.
pub fn is_int(&self) -> bool {
matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
}
/// Returns true if the data type is a boolean type.
pub fn is_bool(&self) -> bool {
matches!(self, DType::Bool)
}
}
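A quick usage sketch of the new helpers, assuming `burn_tensor::DType` is in scope; they are handy when the router has to pick a kind-specific code path from a `TensorDescription`'s dtype:

use burn_tensor::DType;

fn kind_name(dtype: DType) -> &'static str {
    // Dispatch on the tensor kind without enumerating every variant.
    if dtype.is_float() {
        "float"
    } else if dtype.is_int() {
        "int"
    } else if dtype.is_bool() {
        "bool"
    } else {
        // Remaining cases, e.g. unsigned ints and DType::QFloat.
        "other"
    }
}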

View File

@ -0,0 +1,12 @@
use alloc::vec::Vec;
/// Computes the output shape for binary operations with broadcasting support.
pub fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
let mut shape_out = Vec::with_capacity(lhs.len());
for (l, r) in lhs.iter().zip(rhs.iter()) {
shape_out.push(usize::max(*l, *r));
}
shape_out
}
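A usage sketch of the helper above, with the function repeated so the snippet stands alone. Note that it assumes equal-rank shapes and does not validate broadcast compatibility: for dims like 2 vs 3 it silently returns 3, so callers are expected to pass pre-validated shapes.

fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    let mut shape_out = Vec::with_capacity(lhs.len());
    for (l, r) in lhs.iter().zip(rhs.iter()) {
        // Element-wise max implements the usual 1-vs-n broadcasting rule.
        shape_out.push(usize::max(*l, *r));
    }
    shape_out
}

fn main() {
    // Broadcasting [2, 1, 4] against [2, 3, 1] yields [2, 3, 4].
    assert_eq!(binary_ops_shape(&[2, 1, 4], &[2, 3, 1]), vec![2, 3, 4]);
}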

View File

@ -1,5 +1,6 @@
mod activation;
mod alias;
mod binary;
mod bool_tensor;
mod int_tensor;
mod modules;
@ -8,6 +9,7 @@ mod tensor;
pub use activation::*;
pub use alias::*;
pub use binary::*;
pub use bool_tensor::*;
pub use int_tensor::*;
pub use modules::*;