mirror of https://github.com/tracel-ai/burn.git
Add `BackendRouter` to handle multiple backends (#2353)
* WIP * it compiles * WIP * Remove const D (w/ rebase) + WIP BackendRouter * Add missing types from merge * Rework traits, types and add MultiBackendBridge & RunnerClientLocator (WIP) * First draft ByteBridge to_backend(tensor, device) * Refactor into modules * Add mutex, fix types * Remove StreamId and implement ReprBackend for Fusion (WIP) * float_add op working (w/o fusion) * Small cleanup * Remove comment * Cleanup * Fix fusion ReprBackend implementation (duhhh) * Add runner ops * More ops * Cleanup * Add name * Update Cargo.lock * Undo fusion stream changes to common * Clippy + cleanup * Fix no-std * Deal with unused tensors * Clippy baby * Fix comment * Fix tensor handle orphans management * Implement runner read_tensor for other dtypes * Move backend router to its own crate * Refactor repr quantized tensor handle * Fix typo * Implement repr backend for ndarray * Add router tests w/ ndarray and wgpu backends (+ fix tests) * Add empty base operation and autodiff tests * Add precision bridge * Remove dep from local changes * Apply same mask_where broadcast fix * Add float and int elem associative types (should match for each backend) * Add simple byte bridge test * Remove comment * Remove bridge types * Screw you windows * Remove dead code * Add unreachable message * Add seed * Fix clippy * Set the seed anytime a new client is initialized --------- Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
This commit is contained in:
parent
604dbae57d
commit
eaf50e617c
|
@ -706,6 +706,18 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-router"
|
||||
version = "0.15.0"
|
||||
dependencies = [
|
||||
"burn-autodiff",
|
||||
"burn-ndarray",
|
||||
"burn-tensor",
|
||||
"burn-wgpu",
|
||||
"hashbrown 0.14.5",
|
||||
"spin",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-tch"
|
||||
version = "0.15.0"
|
||||
|
@ -732,6 +744,7 @@ dependencies = [
|
|||
"half",
|
||||
"hashbrown 0.14.5",
|
||||
"num-traits",
|
||||
"portable-atomic-util",
|
||||
"rand",
|
||||
"rand_distr",
|
||||
"serde",
|
||||
|
|
|
@ -4,8 +4,8 @@ use crate::{
|
|||
};
|
||||
use burn_tensor::{
|
||||
backend::{Backend, DeviceOps},
|
||||
ops::FloatTensor,
|
||||
repr::{OperationDescription, ReprBackend},
|
||||
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
|
||||
repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle},
|
||||
Device,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
@ -166,3 +166,43 @@ pub trait FusionBackend:
|
|||
/// Pointer to the full precision fusion backend.
|
||||
type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;
|
||||
}
|
||||
|
||||
// Fusion implements `ReprBackend` to enable router backend usage.
|
||||
impl<B: FusionBackend> ReprBackend for Fusion<B> {
|
||||
type Handle = FusionTensor<B::FusionRuntime>;
|
||||
|
||||
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
|
||||
handle.handle
|
||||
}
|
||||
|
||||
fn quantized_tensor(
|
||||
_handles: QuantizedKind<TensorHandle<Self::Handle>>,
|
||||
_scheme: burn_tensor::quantization::QuantizationScheme,
|
||||
) -> QuantizedTensor<Self> {
|
||||
todo!() // not as simple
|
||||
}
|
||||
|
||||
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn quantized_tensor_handle(_tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle> {
|
||||
todo!() // not as simple
|
||||
}
|
||||
}
|
||||
|
|
|
@ -73,16 +73,6 @@ macro_rules! binary_int_cmp_ops {
|
|||
};
|
||||
}
|
||||
|
||||
pub(crate) fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
|
||||
let mut shape_out = Vec::with_capacity(lhs.len());
|
||||
|
||||
for (l, r) in lhs.iter().zip(rhs.iter()) {
|
||||
shape_out.push(usize::max(*l, *r));
|
||||
}
|
||||
|
||||
shape_out
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! binary_int_ops {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use burn_tensor::{
|
||||
ops::{FloatTensor, IntTensor},
|
||||
ops::{binary_ops_shape, FloatTensor, IntTensor},
|
||||
DType, Element, TensorData,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
@ -7,7 +7,6 @@ use std::marker::PhantomData;
|
|||
use crate::{
|
||||
client::FusionClient,
|
||||
get_client,
|
||||
ops::binary::binary_ops_shape,
|
||||
stream::{execution::Operation, StreamId},
|
||||
Fusion, FusionBackend,
|
||||
};
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
use crate::{
|
||||
binary_float_cmp_ops, binary_float_ops,
|
||||
client::FusionClient,
|
||||
get_client,
|
||||
ops::binary::binary_ops_shape,
|
||||
scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
|
||||
get_client, scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
|
||||
stream::{execution::Operation, StreamId},
|
||||
unary_float_ops, Fusion, FusionBackend,
|
||||
};
|
||||
use burn_tensor::{
|
||||
ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
|
||||
ops::{binary_ops_shape, BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
|
||||
repr::*,
|
||||
DType, Device, Distribution, Element, ElementConversion, Shape, TensorData,
|
||||
};
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
use crate::{
|
||||
binary_int_cmp_ops, binary_int_ops,
|
||||
client::FusionClient,
|
||||
get_client,
|
||||
ops::binary::binary_ops_shape,
|
||||
scalar_int_cmp_ops, scalar_int_ops,
|
||||
get_client, scalar_int_cmp_ops, scalar_int_ops,
|
||||
stream::{execution::Operation, StreamId},
|
||||
unary_int_ops, Fusion, FusionBackend,
|
||||
};
|
||||
use burn_tensor::{
|
||||
ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
|
||||
ops::{binary_ops_shape, BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
|
||||
repr::{self, *},
|
||||
DType, Device, Distribution, Element, ElementConversion, Shape, TensorData,
|
||||
};
|
||||
|
|
|
@ -6,6 +6,7 @@ use burn_tensor::{
|
|||
repr::{
|
||||
DequantizeOperationDescription, FloatOperationDescription, HandleContainer,
|
||||
OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription,
|
||||
QuantizedKind,
|
||||
},
|
||||
DType, Device, Element, Shape, TensorData,
|
||||
};
|
||||
|
@ -25,19 +26,17 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
|
|||
let tensor = B::q_from_data(data, device);
|
||||
let shape = B::q_shape(&tensor);
|
||||
|
||||
let mut handles = B::quantized_tensor_handle(tensor);
|
||||
let handles = B::quantized_tensor_handle(tensor);
|
||||
let qparams = match strategy {
|
||||
QuantizationStrategy::PerTensorAffineInt8(_) => {
|
||||
let num_handles = handles.len();
|
||||
assert_eq!(
|
||||
num_handles, 3,
|
||||
"Expected 3 handles for quantized tensor, got {num_handles}"
|
||||
);
|
||||
let offset = handles.pop().unwrap();
|
||||
let scale = handles.pop().unwrap();
|
||||
let offset = if let Some(offset) = handles.offset {
|
||||
offset
|
||||
} else {
|
||||
panic!("Expected offset for quantized tensor.");
|
||||
};
|
||||
FusionQuantizationParameters {
|
||||
scale: client.register_tensor(
|
||||
scale,
|
||||
handles.scale,
|
||||
vec![1],
|
||||
StreamId::current(),
|
||||
B::FloatElem::dtype(),
|
||||
|
@ -51,15 +50,13 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
}
|
||||
QuantizationStrategy::PerTensorSymmetricInt8(_) => {
|
||||
let num_handles = handles.len();
|
||||
assert_eq!(
|
||||
num_handles, 2,
|
||||
"Expected 2 handles for quantized tensor, got {num_handles}"
|
||||
assert!(
|
||||
handles.offset.is_none(),
|
||||
"Offset should not be provided for symmetric quantization."
|
||||
);
|
||||
let scale = handles.pop().unwrap();
|
||||
FusionQuantizationParameters {
|
||||
scale: client.register_tensor(
|
||||
scale,
|
||||
handles.scale,
|
||||
vec![1],
|
||||
StreamId::current(),
|
||||
B::FloatElem::dtype(),
|
||||
|
@ -69,7 +66,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
|
|||
}
|
||||
};
|
||||
let qtensor = client.register_tensor(
|
||||
handles.pop().unwrap(),
|
||||
handles.tensor,
|
||||
shape.dims,
|
||||
StreamId::current(),
|
||||
B::QuantizedEncoding::dtype(),
|
||||
|
@ -111,17 +108,20 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let qparams = QuantizationParametersPrimitive { scale, offset };
|
||||
let output = B::quantize(tensor, &self.desc.scheme, qparams);
|
||||
if let Some(offset) = &self.desc.qparams.offset {
|
||||
handles.register_quantized_tensor::<B>(
|
||||
&[&self.desc.out.id, &self.desc.qparams.scale.id, &offset.id],
|
||||
output,
|
||||
);
|
||||
let q_ids = if let Some(offset) = &self.desc.qparams.offset {
|
||||
QuantizedKind {
|
||||
tensor: self.desc.out.id,
|
||||
scale: self.desc.qparams.scale.id,
|
||||
offset: Some(offset.id),
|
||||
}
|
||||
} else {
|
||||
handles.register_quantized_tensor::<B>(
|
||||
&[&self.desc.out.id, &self.desc.qparams.scale.id],
|
||||
output,
|
||||
);
|
||||
}
|
||||
QuantizedKind {
|
||||
tensor: self.desc.out.id,
|
||||
scale: self.desc.qparams.scale.id,
|
||||
offset: None,
|
||||
}
|
||||
};
|
||||
handles.register_quantized_tensor::<B>(&q_ids, output);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,8 @@ use crate::{
|
|||
FusionBackend, FusionRuntime,
|
||||
};
|
||||
use burn_tensor::repr::{
|
||||
HandleContainer, OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId,
|
||||
HandleContainer, OperationDescription, QuantizedKind, QuantizedTensorDescription,
|
||||
TensorDescription, TensorId,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
|
@ -183,18 +184,28 @@ where
|
|||
let scale_id = server_device.create_empty_handle();
|
||||
let offset_id = server_device.create_empty_handle();
|
||||
|
||||
let q_ids = QuantizedKind {
|
||||
tensor: *tensor_id,
|
||||
scale: *scale_id,
|
||||
offset: Some(*offset_id),
|
||||
};
|
||||
server_device
|
||||
.handles
|
||||
.register_quantized_tensor::<B>(&[&tensor_id, &scale_id, &offset_id], tensor);
|
||||
.register_quantized_tensor::<B>(&q_ids, tensor);
|
||||
|
||||
vec![tensor_id, scale_id, offset_id]
|
||||
} else {
|
||||
let tensor_id = server_device.create_empty_handle();
|
||||
let scale_id = server_device.create_empty_handle();
|
||||
|
||||
let q_ids = QuantizedKind {
|
||||
tensor: *tensor_id,
|
||||
scale: *scale_id,
|
||||
offset: None,
|
||||
};
|
||||
server_device
|
||||
.handles
|
||||
.register_quantized_tensor::<B>(&[&tensor_id, &scale_id], tensor);
|
||||
.register_quantized_tensor::<B>(&q_ids, tensor);
|
||||
|
||||
vec![tensor_id, scale_id]
|
||||
}
|
||||
|
|
|
@ -968,6 +968,9 @@ impl RelativeOps for BaseOperationDescription {
|
|||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
BaseOperationDescription::Empty(desc) => {
|
||||
BaseOperationDescription::Empty(desc.to_relative(converter))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ template = []
|
|||
burn-common = { path = "../burn-common", version = "0.15.0" }
|
||||
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = [
|
||||
"cubecl",
|
||||
"cubecl", "repr",
|
||||
] }
|
||||
cubecl = { workspace = true, features = ["linalg"] }
|
||||
|
||||
|
|
|
@ -7,6 +7,13 @@ use cubecl::server::ComputeServer;
|
|||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use std::{marker::PhantomData, sync::Mutex};
|
||||
|
||||
#[cfg(not(feature = "fusion"))]
|
||||
use burn_tensor::{
|
||||
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
|
||||
quantization::QuantizationScheme,
|
||||
repr::{HandleKind, QuantizedKind, ReprBackend, TensorHandle},
|
||||
};
|
||||
|
||||
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
|
||||
|
||||
/// Generic tensor backend that can be compiled just-in-time to any shader runtime
|
||||
|
@ -82,3 +89,62 @@ where
|
|||
type JitDevice = R::Device;
|
||||
type JitServer = R::Server;
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "fusion"))]
|
||||
impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R, F, I> {
|
||||
type Handle = HandleKind<Self>;
|
||||
|
||||
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
|
||||
match handle.handle {
|
||||
HandleKind::Float(handle) => handle,
|
||||
_ => panic!("Expected float handle, got {}", handle.handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
|
||||
match handle.handle {
|
||||
HandleKind::Int(handle) => handle,
|
||||
_ => panic!("Expected int handle, got {}", handle.handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
|
||||
match handle.handle {
|
||||
HandleKind::Bool(handle) => handle,
|
||||
_ => panic!("Expected bool handle, got {}", handle.handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn quantized_tensor(
|
||||
handles: QuantizedKind<TensorHandle<Self::Handle>>,
|
||||
_scheme: QuantizationScheme,
|
||||
) -> QuantizedTensor<Self> {
|
||||
let handle = handles.tensor.handle;
|
||||
match handle {
|
||||
HandleKind::Quantized(handle) => handle,
|
||||
_ => panic!("Expected quantized handle, got {}", handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
|
||||
HandleKind::Float(tensor)
|
||||
}
|
||||
|
||||
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
|
||||
HandleKind::Int(tensor)
|
||||
}
|
||||
|
||||
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
|
||||
HandleKind::Bool(tensor)
|
||||
}
|
||||
|
||||
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle> {
|
||||
QuantizedKind {
|
||||
tensor: HandleKind::Quantized(tensor),
|
||||
// The quantized tensor primitive already encapsulates the required quantization
|
||||
// parameters so we set the scale as an empty handle (unused).
|
||||
scale: HandleKind::Empty,
|
||||
offset: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::{
|
|||
};
|
||||
use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
|
||||
use burn_tensor::quantization::QuantizationScheme;
|
||||
use burn_tensor::repr::TensorHandle;
|
||||
use burn_tensor::repr::{QuantizedKind, TensorHandle};
|
||||
use burn_tensor::{repr::ReprBackend, Shape};
|
||||
use core::marker::PhantomData;
|
||||
use cubecl::client::ComputeClient;
|
||||
|
@ -79,41 +79,22 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R
|
|||
}
|
||||
|
||||
fn quantized_tensor(
|
||||
handles: Vec<TensorHandle<Self::Handle>>,
|
||||
handles: QuantizedKind<TensorHandle<Self::Handle>>,
|
||||
scheme: QuantizationScheme,
|
||||
) -> burn_tensor::ops::QuantizedTensor<Self> {
|
||||
match handles.len() {
|
||||
// NOTE: the order of the handles is known [qtensor, scale, <offset>]
|
||||
3 => {
|
||||
let mut handles = handles;
|
||||
let offset = handles.pop().unwrap();
|
||||
let scale = handles.pop().unwrap();
|
||||
let qtensor = handles.pop().unwrap();
|
||||
QJitTensor {
|
||||
qtensor: qtensor.handle.into_tensor(qtensor.shape),
|
||||
scheme,
|
||||
qparams: JitQuantizationParameters {
|
||||
scale: scale.handle.into_tensor(scale.shape),
|
||||
offset: Some(offset.handle.into_tensor(offset.shape)),
|
||||
},
|
||||
}
|
||||
}
|
||||
2 => {
|
||||
let mut handles = handles;
|
||||
let scale = handles.pop().unwrap();
|
||||
let qtensor = handles.pop().unwrap();
|
||||
QJitTensor {
|
||||
qtensor: qtensor.handle.into_tensor(qtensor.shape),
|
||||
scheme,
|
||||
qparams: JitQuantizationParameters {
|
||||
scale: scale.handle.into_tensor(scale.shape),
|
||||
offset: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!("Expected handles for the quantized tensor and its quantization parameters.")
|
||||
}
|
||||
let qtensor = handles.tensor.handle.into_tensor(handles.tensor.shape);
|
||||
let scale = handles.scale.handle.into_tensor(handles.scale.shape);
|
||||
let offset = handles.offset;
|
||||
|
||||
let qparams = JitQuantizationParameters {
|
||||
scale,
|
||||
offset: offset.map(|h| h.handle.into_tensor(h.shape)),
|
||||
};
|
||||
|
||||
QJitTensor {
|
||||
qtensor,
|
||||
scheme,
|
||||
qparams,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -131,14 +112,14 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R
|
|||
|
||||
fn quantized_tensor_handle(
|
||||
tensor: burn_tensor::ops::QuantizedTensor<Self>,
|
||||
) -> Vec<Self::Handle> {
|
||||
) -> QuantizedKind<Self::Handle> {
|
||||
let qtensor: JitFusionHandle<R> = tensor.qtensor.into();
|
||||
let scale: JitFusionHandle<R> = tensor.qparams.scale.into();
|
||||
if let Some(offset) = tensor.qparams.offset {
|
||||
let offset: JitFusionHandle<R> = offset.into();
|
||||
vec![qtensor, scale, offset]
|
||||
} else {
|
||||
vec![qtensor, scale]
|
||||
|
||||
QuantizedKind {
|
||||
tensor: qtensor,
|
||||
scale,
|
||||
offset: tensor.qparams.offset.map(|offset| offset.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ blas-openblas-system = [
|
|||
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", optional = true }
|
||||
burn-common = { path = "../burn-common", version = "0.15.0", default-features = false }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false, features = ["repr"] }
|
||||
|
||||
atomic_float = { workspace = true }
|
||||
blas-src = { workspace = true, default-features = false, optional = true } # no-std compatible
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
use crate::element::{FloatNdArrayElement, QuantElement};
|
||||
use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
|
||||
use crate::PrecisionBridge;
|
||||
use crate::{NdArrayQTensor, NdArrayTensor};
|
||||
use alloc::string::String;
|
||||
use burn_common::stub::Mutex;
|
||||
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
|
||||
use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
|
||||
use burn_tensor::quantization::QuantizationScheme;
|
||||
use burn_tensor::repr::{HandleKind, QuantizedKind, ReprBackend, TensorHandle};
|
||||
use core::marker::PhantomData;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
|
@ -35,20 +38,21 @@ impl Default for NdArrayDevice {
|
|||
/// This backend is compatible with CPUs and can be compiled for almost any platform, including
|
||||
/// `wasm`, `arm`, and `x86`.
|
||||
#[derive(Clone, Copy, Default, Debug)]
|
||||
pub struct NdArray<E = f32, Q = i8> {
|
||||
pub struct NdArray<E = f32, I = i64, Q = i8> {
|
||||
_e: PhantomData<E>,
|
||||
_i: PhantomData<I>,
|
||||
_q: PhantomData<Q>,
|
||||
}
|
||||
|
||||
impl<E: FloatNdArrayElement, Q: QuantElement> Backend for NdArray<E, Q> {
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q> {
|
||||
type Device = NdArrayDevice;
|
||||
type FullPrecisionBridge = PrecisionBridge<f32>;
|
||||
|
||||
type FloatTensorPrimitive = NdArrayTensor<E>;
|
||||
type FloatElem = E;
|
||||
|
||||
type IntTensorPrimitive = NdArrayTensor<i64>;
|
||||
type IntElem = i64;
|
||||
type IntTensorPrimitive = NdArrayTensor<I>;
|
||||
type IntElem = I;
|
||||
|
||||
type BoolTensorPrimitive = NdArrayTensor<bool>;
|
||||
|
||||
|
@ -69,3 +73,63 @@ impl<E: FloatNdArrayElement, Q: QuantElement> Backend for NdArray<E, Q> {
|
|||
*seed = Some(rng);
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ReprBackend
|
||||
for NdArray<E, I, Q>
|
||||
{
|
||||
type Handle = HandleKind<Self>;
|
||||
|
||||
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
|
||||
match handle.handle {
|
||||
HandleKind::Float(handle) => handle,
|
||||
_ => panic!("Expected float handle, got {}", handle.handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
|
||||
match handle.handle {
|
||||
HandleKind::Int(handle) => handle,
|
||||
_ => panic!("Expected int handle, got {}", handle.handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
|
||||
match handle.handle {
|
||||
HandleKind::Bool(handle) => handle,
|
||||
_ => panic!("Expected bool handle, got {}", handle.handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn quantized_tensor(
|
||||
handles: QuantizedKind<TensorHandle<Self::Handle>>,
|
||||
_scheme: QuantizationScheme,
|
||||
) -> QuantizedTensor<Self> {
|
||||
let handle = handles.tensor.handle;
|
||||
match handle {
|
||||
HandleKind::Quantized(handle) => handle,
|
||||
_ => panic!("Expected quantized handle, got {}", handle.name()),
|
||||
}
|
||||
}
|
||||
|
||||
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
|
||||
HandleKind::Float(tensor)
|
||||
}
|
||||
|
||||
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
|
||||
HandleKind::Int(tensor)
|
||||
}
|
||||
|
||||
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
|
||||
HandleKind::Bool(tensor)
|
||||
}
|
||||
|
||||
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle> {
|
||||
QuantizedKind {
|
||||
tensor: HandleKind::Quantized(tensor),
|
||||
// The quantized tensor primitive already encapsulates the required quantization
|
||||
// parameters so we set the scale as an empty handle (unused).
|
||||
scale: HandleKind::Empty,
|
||||
offset: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use crate::{element::QuantElement, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor};
|
||||
use crate::{
|
||||
element::{IntNdArrayElement, QuantElement},
|
||||
FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor,
|
||||
};
|
||||
use burn_tensor::{backend::BackendBridge, ops::FloatTensor};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
|
@ -8,13 +11,15 @@ pub struct PrecisionBridge<E: FloatNdArrayElement> {
|
|||
_e: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<TElem, OElem, QElem> BackendBridge<NdArray<OElem, QElem>> for PrecisionBridge<TElem>
|
||||
impl<TElem, OElem, QElem, IntElem> BackendBridge<NdArray<OElem, IntElem, QElem>>
|
||||
for PrecisionBridge<TElem>
|
||||
where
|
||||
TElem: FloatNdArrayElement,
|
||||
OElem: FloatNdArrayElement,
|
||||
QElem: QuantElement,
|
||||
IntElem: IntNdArrayElement,
|
||||
{
|
||||
type Target = NdArray<TElem>;
|
||||
type Target = NdArray<TElem, IntElem, QElem>;
|
||||
|
||||
fn into_target(
|
||||
tensor: FloatTensor<NdArray<OElem>>,
|
||||
|
|
|
@ -16,6 +16,8 @@ where
|
|||
{
|
||||
}
|
||||
|
||||
pub trait IntNdArrayElement: NdArrayElement + core::ops::Rem<Output = Self> + Signed {}
|
||||
|
||||
/// A general element for ndarray backend.
|
||||
pub trait NdArrayElement:
|
||||
Element
|
||||
|
@ -49,6 +51,9 @@ impl QuantElement for i8 {}
|
|||
impl FloatNdArrayElement for f64 {}
|
||||
impl FloatNdArrayElement for f32 {}
|
||||
|
||||
impl IntNdArrayElement for i64 {}
|
||||
impl IntNdArrayElement for i32 {}
|
||||
|
||||
macro_rules! make_elem {
|
||||
(
|
||||
double
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
use crate::{
|
||||
element::{FloatNdArrayElement, QuantElement},
|
||||
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
|
||||
tensor::NdArrayTensor,
|
||||
NdArray,
|
||||
};
|
||||
use burn_tensor::{ops::ActivationOps, ElementConversion};
|
||||
|
||||
impl<E: FloatNdArrayElement, Q: QuantElement> ActivationOps<Self> for NdArray<E, Q> {
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ActivationOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
{
|
||||
fn relu(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
|
||||
let zero = 0.elem();
|
||||
let array = tensor
|
||||
|
|
|
@ -261,10 +261,10 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub fn gather(
|
||||
pub fn gather<I: NdArrayElement>(
|
||||
dim: usize,
|
||||
mut tensor: NdArrayTensor<E>,
|
||||
mut indices: NdArrayTensor<i64>,
|
||||
mut indices: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<E> {
|
||||
let ndims = tensor.shape().num_dims();
|
||||
if dim != ndims - 1 {
|
||||
|
@ -284,7 +284,7 @@ where
|
|||
let indices = indices.slice(s!(b, ..));
|
||||
|
||||
for (i, index) in indices.iter().enumerate() {
|
||||
output[[b, i]] = tensor[[b, *index as usize]];
|
||||
output[[b, i]] = tensor[[b, index.elem::<i64>() as usize]];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -300,10 +300,10 @@ where
|
|||
output
|
||||
}
|
||||
|
||||
pub fn scatter(
|
||||
pub fn scatter<I: NdArrayElement>(
|
||||
dim: usize,
|
||||
mut tensor: NdArrayTensor<E>,
|
||||
mut indices: NdArrayTensor<i64>,
|
||||
mut indices: NdArrayTensor<I>,
|
||||
mut value: NdArrayTensor<E>,
|
||||
) -> NdArrayTensor<E> {
|
||||
let ndims = tensor.shape().num_dims();
|
||||
|
@ -338,7 +338,7 @@ where
|
|||
let indices = indices.slice(s!(b, ..));
|
||||
|
||||
for (i, index) in indices.iter().enumerate() {
|
||||
let index = *index as usize;
|
||||
let index = index.elem::<i64>() as usize;
|
||||
tensor[[b, index]] += value[[b, i]];
|
||||
}
|
||||
}
|
||||
|
@ -403,33 +403,33 @@ where
|
|||
batch_size
|
||||
}
|
||||
|
||||
pub fn select(
|
||||
pub fn select<I: NdArrayElement>(
|
||||
tensor: NdArrayTensor<E>,
|
||||
dim: usize,
|
||||
indices: NdArrayTensor<i64>,
|
||||
indices: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<E> {
|
||||
let array = tensor.array.select(
|
||||
Axis(dim),
|
||||
&indices
|
||||
.array
|
||||
.into_iter()
|
||||
.map(|i| i as usize)
|
||||
.map(|i| i.elem::<i64>() as usize)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
NdArrayTensor::new(array.into_shared())
|
||||
}
|
||||
|
||||
pub fn select_assign(
|
||||
pub fn select_assign<I: NdArrayElement>(
|
||||
tensor: NdArrayTensor<E>,
|
||||
dim: usize,
|
||||
indices: NdArrayTensor<i64>,
|
||||
indices: NdArrayTensor<I>,
|
||||
value: NdArrayTensor<E>,
|
||||
) -> NdArrayTensor<E> {
|
||||
let mut output_array = tensor.array.into_owned();
|
||||
|
||||
for (index_value, index) in indices.array.into_iter().enumerate() {
|
||||
let mut view = output_array.index_axis_mut(Axis(dim), index as usize);
|
||||
let mut view = output_array.index_axis_mut(Axis(dim), index.elem::<i64>() as usize);
|
||||
let value = value.array.index_axis(Axis(dim), index_value);
|
||||
|
||||
view.zip_mut_with(&value, |a, b| *a += *b);
|
||||
|
@ -437,11 +437,11 @@ where
|
|||
|
||||
NdArrayTensor::new(output_array.into_shared())
|
||||
}
|
||||
pub fn argmax(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<i64> {
|
||||
pub fn argmax<I: NdArrayElement>(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
|
||||
arg(tensor, dim, CmpType::Max)
|
||||
}
|
||||
|
||||
pub fn argmin(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<i64> {
|
||||
pub fn argmin<I: NdArrayElement>(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
|
||||
arg(tensor, dim, CmpType::Min)
|
||||
}
|
||||
|
||||
|
@ -523,11 +523,11 @@ enum CmpType {
|
|||
Max,
|
||||
}
|
||||
|
||||
fn arg<E: NdArrayElement>(
|
||||
fn arg<E: NdArrayElement, I: NdArrayElement>(
|
||||
tensor: NdArrayTensor<E>,
|
||||
dim: usize,
|
||||
cmp: CmpType,
|
||||
) -> NdArrayTensor<i64> {
|
||||
) -> NdArrayTensor<I> {
|
||||
let mut reshape = tensor.array.shape().to_vec();
|
||||
reshape[dim] = 1;
|
||||
|
||||
|
@ -546,7 +546,7 @@ fn arg<E: NdArrayElement>(
|
|||
}
|
||||
});
|
||||
|
||||
idx as i64
|
||||
(idx as i64).elem()
|
||||
});
|
||||
|
||||
let output = output.to_shape(Dim(reshape.as_slice())).unwrap();
|
||||
|
|
|
@ -7,7 +7,7 @@ use core::ops::Range;
|
|||
use ndarray::{IntoDimension, Zip};
|
||||
|
||||
// Current crate
|
||||
use crate::element::{FloatNdArrayElement, QuantElement};
|
||||
use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
|
||||
use crate::NdArrayDevice;
|
||||
use crate::{tensor::NdArrayTensor, NdArray};
|
||||
|
||||
|
@ -16,7 +16,9 @@ use burn_tensor::{backend::Backend, Shape, TensorData};
|
|||
|
||||
use super::NdArrayOps;
|
||||
|
||||
impl<E: FloatNdArrayElement, Q: QuantElement> BoolTensorOps<Self> for NdArray<E, Q> {
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
{
|
||||
fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<bool> {
|
||||
NdArrayTensor::from_data(data)
|
||||
}
|
||||
|
@ -43,11 +45,11 @@ impl<E: FloatNdArrayElement, Q: QuantElement> BoolTensorOps<Self> for NdArray<E,
|
|||
NdArrayOps::slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn bool_into_int(tensor: NdArrayTensor<bool>) -> NdArrayTensor<i64> {
|
||||
fn bool_into_int(tensor: NdArrayTensor<bool>) -> NdArrayTensor<I> {
|
||||
let shape = tensor.shape();
|
||||
let values = tensor.array.into_iter().collect();
|
||||
NdArray::<E>::int_from_data(
|
||||
TensorData::new(values, shape).convert::<i64>(),
|
||||
NdArray::<E, I>::int_from_data(
|
||||
TensorData::new(values, shape).convert::<I>(),
|
||||
&NdArrayDevice::Cpu,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ use ndarray::{
|
|||
};
|
||||
|
||||
use crate::{
|
||||
element::{FloatNdArrayElement, QuantElement},
|
||||
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
|
||||
ops::padding::{apply_padding_4d, apply_padding_5d},
|
||||
sharing::UnsafeSharedRef,
|
||||
tensor::NdArrayTensor,
|
||||
|
@ -98,7 +98,7 @@ fn conv3d_mad_inner<E: FloatNdArrayElement>(
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
|
||||
pub(crate) fn conv2d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
|
||||
x: NdArrayTensor<E>,
|
||||
weight: NdArrayTensor<E>,
|
||||
bias: Option<NdArrayTensor<E>>,
|
||||
|
@ -126,7 +126,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
in_width,
|
||||
);
|
||||
|
||||
let x = apply_padding_4d::<E, Q>(x, options.padding, 0i32.elem()).array;
|
||||
let x = apply_padding_4d::<E, I, Q>(x, options.padding, 0i32.elem()).array;
|
||||
|
||||
// Convert inputs from dynamic indexes to static to improve perf.
|
||||
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||
|
@ -310,7 +310,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
|||
NdArrayTensor::new(output.into_dyn().into_shared())
|
||||
}
|
||||
|
||||
pub(crate) fn conv3d<E: FloatNdArrayElement, Q: QuantElement>(
|
||||
pub(crate) fn conv3d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
|
||||
x: NdArrayTensor<E>,
|
||||
weight: NdArrayTensor<E>,
|
||||
bias: Option<NdArrayTensor<E>>,
|
||||
|
@ -345,7 +345,7 @@ pub(crate) fn conv3d<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
in_width,
|
||||
);
|
||||
|
||||
let x = apply_padding_5d::<E, Q>(x, options.padding, 0i32.elem()).array;
|
||||
let x = apply_padding_5d::<E, I, Q>(x, options.padding, 0i32.elem()).array;
|
||||
|
||||
// Convert inputs from dynamic indexes to static to improve perf.
|
||||
let x = x.into_dimensionality::<ndarray::Ix5>().unwrap();
|
||||
|
|
|
@ -252,7 +252,7 @@ pub mod backward {
|
|||
#[cfg(target_has_atomic = "32")]
|
||||
use core::sync::atomic::Ordering;
|
||||
|
||||
use crate::NdArray;
|
||||
use crate::{element::IntNdArrayElement, NdArray};
|
||||
use atomic_float::AtomicF32;
|
||||
use burn_tensor::ops::DeformConv2dBackward;
|
||||
use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4};
|
||||
|
@ -260,7 +260,11 @@ pub mod backward {
|
|||
use super::*;
|
||||
|
||||
/// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.
|
||||
pub(crate) fn deform_conv2d_backward<F: FloatNdArrayElement, Q: QuantElement>(
|
||||
pub(crate) fn deform_conv2d_backward<
|
||||
F: FloatNdArrayElement,
|
||||
I: IntNdArrayElement,
|
||||
Q: QuantElement,
|
||||
>(
|
||||
input: NdArrayTensor<F>,
|
||||
offset: NdArrayTensor<F>,
|
||||
weight: NdArrayTensor<F>,
|
||||
|
@ -268,7 +272,7 @@ pub mod backward {
|
|||
bias: Option<NdArrayTensor<F>>,
|
||||
out_grad: NdArrayTensor<F>,
|
||||
args: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<NdArray<F, Q>> {
|
||||
) -> DeformConv2dBackward<NdArray<F, I, Q>> {
|
||||
let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims();
|
||||
let [_, _, kernel_h, kernel_w] = weight.shape().dims();
|
||||
let groups = args.weight_groups;
|
||||
|
|
|
@ -11,8 +11,8 @@ use ndarray::IntoDimension;
|
|||
use ndarray::Zip;
|
||||
|
||||
// Current crate
|
||||
use crate::element::ExpElement;
|
||||
use crate::element::FloatNdArrayElement;
|
||||
use crate::element::IntNdArrayElement;
|
||||
use crate::element::QuantElement;
|
||||
use crate::{tensor::NdArrayTensor, NdArray};
|
||||
use crate::{NdArrayDevice, SEED};
|
||||
|
@ -22,71 +22,73 @@ use burn_tensor::{backend::Backend, Shape, TensorData};
|
|||
|
||||
use super::{NdArrayMathOps, NdArrayOps};
|
||||
|
||||
impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E, Q> {
|
||||
fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<i64> {
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
{
|
||||
fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<I> {
|
||||
NdArrayTensor::from_data(data)
|
||||
}
|
||||
|
||||
fn int_shape(tensor: &NdArrayTensor<i64>) -> Shape {
|
||||
fn int_shape(tensor: &NdArrayTensor<I>) -> Shape {
|
||||
tensor.shape()
|
||||
}
|
||||
|
||||
async fn int_into_data(tensor: NdArrayTensor<i64>) -> TensorData {
|
||||
async fn int_into_data(tensor: NdArrayTensor<I>) -> TensorData {
|
||||
let shape = tensor.shape();
|
||||
let values = tensor.array.into_iter().collect();
|
||||
TensorData::new(values, shape)
|
||||
}
|
||||
|
||||
fn int_to_device(tensor: NdArrayTensor<i64>, _device: &NdArrayDevice) -> NdArrayTensor<i64> {
|
||||
fn int_to_device(tensor: NdArrayTensor<I>, _device: &NdArrayDevice) -> NdArrayTensor<I> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn int_reshape(tensor: NdArrayTensor<i64>, shape: Shape) -> NdArrayTensor<i64> {
|
||||
fn int_reshape(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
|
||||
NdArrayOps::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_slice(tensor: NdArrayTensor<i64>, ranges: &[Range<usize>]) -> NdArrayTensor<i64> {
|
||||
fn int_slice(tensor: NdArrayTensor<I>, ranges: &[Range<usize>]) -> NdArrayTensor<I> {
|
||||
NdArrayOps::slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn int_device(_tensor: &NdArrayTensor<i64>) -> <NdArray<E> as Backend>::Device {
|
||||
fn int_device(_tensor: &NdArrayTensor<I>) -> <NdArray<E> as Backend>::Device {
|
||||
NdArrayDevice::Cpu
|
||||
}
|
||||
|
||||
fn int_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<i64> {
|
||||
fn int_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
|
||||
let values = vec![0; shape.num_elements()];
|
||||
NdArrayTensor::from_data(TensorData::new(values, shape))
|
||||
}
|
||||
|
||||
fn int_mask_where(
|
||||
tensor: NdArrayTensor<i64>,
|
||||
tensor: NdArrayTensor<I>,
|
||||
mask: NdArrayTensor<bool>,
|
||||
source: NdArrayTensor<i64>,
|
||||
) -> NdArrayTensor<i64> {
|
||||
source: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::mask_where(tensor, mask, source)
|
||||
}
|
||||
|
||||
fn int_mask_fill(
|
||||
tensor: NdArrayTensor<i64>,
|
||||
tensor: NdArrayTensor<I>,
|
||||
mask: NdArrayTensor<bool>,
|
||||
value: i64,
|
||||
) -> NdArrayTensor<i64> {
|
||||
value: I,
|
||||
) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::mask_fill(tensor, mask, value)
|
||||
}
|
||||
|
||||
fn int_slice_assign(
|
||||
tensor: NdArrayTensor<i64>,
|
||||
tensor: NdArrayTensor<I>,
|
||||
ranges: &[Range<usize>],
|
||||
value: NdArrayTensor<i64>,
|
||||
) -> NdArrayTensor<i64> {
|
||||
value: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<I> {
|
||||
NdArrayOps::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn int_cat(tensors: Vec<NdArrayTensor<i64>>, dim: usize) -> NdArrayTensor<i64> {
|
||||
fn int_cat(tensors: Vec<NdArrayTensor<I>>, dim: usize) -> NdArrayTensor<I> {
|
||||
NdArrayOps::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn int_equal(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<bool> {
|
||||
fn int_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
|
||||
let output = Zip::from(&lhs.array)
|
||||
.and(&rhs.array)
|
||||
.map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
|
||||
|
@ -94,196 +96,196 @@ impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E,
|
|||
NdArrayTensor::new(output)
|
||||
}
|
||||
|
||||
fn int_equal_elem(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<bool> {
|
||||
fn int_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
|
||||
let array = lhs.array.mapv(|a| a == rhs).into_shared();
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn int_greater(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<bool> {
|
||||
fn int_greater(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
|
||||
let tensor = Self::int_sub(lhs, rhs);
|
||||
Self::int_greater_elem(tensor, 0)
|
||||
Self::int_greater_elem(tensor, 0.elem())
|
||||
}
|
||||
|
||||
fn int_greater_elem(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<bool> {
|
||||
fn int_greater_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
|
||||
let array = lhs.array.mapv(|a| a > rhs).into_shared();
|
||||
NdArrayTensor::new(array)
|
||||
}
|
||||
|
||||
fn int_greater_equal(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<bool> {
|
||||
fn int_greater_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
|
||||
let tensor = Self::int_sub(lhs, rhs);
|
||||
Self::int_greater_equal_elem(tensor, 0)
|
||||
Self::int_greater_equal_elem(tensor, 0.elem())
|
||||
}
|
||||
|
||||
fn int_greater_equal_elem(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<bool> {
|
||||
fn int_greater_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
|
||||
let array = lhs.array.mapv(|a| a >= rhs).into_shared();
|
||||
NdArrayTensor::new(array)
|
||||
}
|
||||
|
||||
fn int_lower(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<bool> {
|
||||
fn int_lower(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
|
||||
let tensor = Self::int_sub(lhs, rhs);
|
||||
Self::int_lower_elem(tensor, 0)
|
||||
Self::int_lower_elem(tensor, 0.elem())
|
||||
}
|
||||
|
||||
fn int_lower_elem(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<bool> {
|
||||
fn int_lower_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
|
||||
let array = lhs.array.mapv(|a| a < rhs).into_shared();
|
||||
NdArrayTensor::new(array)
|
||||
}
|
||||
|
||||
fn int_lower_equal(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<bool> {
|
||||
fn int_lower_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
|
||||
let tensor = Self::int_sub(lhs, rhs);
|
||||
Self::int_lower_equal_elem(tensor, 0)
|
||||
Self::int_lower_equal_elem(tensor, 0.elem())
|
||||
}
|
||||
|
||||
fn int_lower_equal_elem(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<bool> {
|
||||
fn int_lower_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
|
||||
let array = lhs.array.mapv(|a| a <= rhs).into_shared();
|
||||
NdArrayTensor::new(array)
|
||||
}
|
||||
|
||||
fn int_add(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
fn int_add(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::add(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_add_scalar(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<i64> {
|
||||
fn int_add_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::add_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_sub(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
fn int_sub(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::sub(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_sub_scalar(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<i64> {
|
||||
fn int_sub_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::sub_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_mul(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
fn int_mul(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::mul(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_mul_scalar(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<i64> {
|
||||
fn int_mul_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::mul_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_div(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
fn int_div(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::div(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_div_scalar(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<i64> {
|
||||
fn int_div_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::div_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_remainder_scalar(lhs: NdArrayTensor<i64>, rhs: i64) -> NdArrayTensor<i64> {
|
||||
fn int_remainder_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::remainder_scalar(lhs, rhs)
|
||||
}
|
||||
|
||||
fn int_neg(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
Self::int_mul_scalar(tensor, -1)
|
||||
fn int_neg(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
Self::int_mul_scalar(tensor, (-1).elem())
|
||||
}
|
||||
|
||||
fn int_zeros(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<i64> {
|
||||
fn int_zeros(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
|
||||
Self::int_from_data(TensorData::zeros::<i64, _>(shape), device)
|
||||
}
|
||||
|
||||
fn int_ones(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<i64> {
|
||||
fn int_ones(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
|
||||
Self::int_from_data(TensorData::ones::<i64, _>(shape), device)
|
||||
}
|
||||
|
||||
fn int_full(
|
||||
shape: Shape,
|
||||
fill_value: i64,
|
||||
fill_value: I,
|
||||
device: &<NdArray<E> as Backend>::Device,
|
||||
) -> NdArrayTensor<i64> {
|
||||
) -> NdArrayTensor<I> {
|
||||
Self::int_from_data(TensorData::full(shape, fill_value), device)
|
||||
}
|
||||
|
||||
fn int_sum(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
fn int_sum(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::sum(tensor)
|
||||
}
|
||||
|
||||
fn int_sum_dim(tensor: NdArrayTensor<i64>, dim: usize) -> NdArrayTensor<i64> {
|
||||
fn int_sum_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::sum_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_prod(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
fn int_prod(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::prod(tensor)
|
||||
}
|
||||
|
||||
fn int_prod_dim(tensor: NdArrayTensor<i64>, dim: usize) -> NdArrayTensor<i64> {
|
||||
fn int_prod_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::prod_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_mean(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
fn int_mean(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::mean(tensor)
|
||||
}
|
||||
|
||||
fn int_mean_dim(tensor: NdArrayTensor<i64>, dim: usize) -> NdArrayTensor<i64> {
|
||||
fn int_mean_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::mean_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_gather(
|
||||
dim: usize,
|
||||
tensor: NdArrayTensor<i64>,
|
||||
indices: NdArrayTensor<i64>,
|
||||
) -> NdArrayTensor<i64> {
|
||||
tensor: NdArrayTensor<I>,
|
||||
indices: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn int_scatter(
|
||||
dim: usize,
|
||||
tensor: NdArrayTensor<i64>,
|
||||
indices: NdArrayTensor<i64>,
|
||||
value: NdArrayTensor<i64>,
|
||||
) -> NdArrayTensor<i64> {
|
||||
tensor: NdArrayTensor<I>,
|
||||
indices: NdArrayTensor<I>,
|
||||
value: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn int_select(
|
||||
tensor: NdArrayTensor<i64>,
|
||||
tensor: NdArrayTensor<I>,
|
||||
dim: usize,
|
||||
indices: NdArrayTensor<i64>,
|
||||
) -> NdArrayTensor<i64> {
|
||||
indices: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::select(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn int_select_assign(
|
||||
tensor: NdArrayTensor<i64>,
|
||||
tensor: NdArrayTensor<I>,
|
||||
dim: usize,
|
||||
indices: NdArrayTensor<i64>,
|
||||
value: NdArrayTensor<i64>,
|
||||
) -> NdArrayTensor<i64> {
|
||||
indices: NdArrayTensor<I>,
|
||||
value: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
fn int_argmax(tensor: NdArrayTensor<i64>, dim: usize) -> NdArrayTensor<i64> {
|
||||
fn int_argmax(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::argmax(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_argmin(tensor: NdArrayTensor<i64>, dim: usize) -> NdArrayTensor<i64> {
|
||||
fn int_argmin(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::argmin(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_clamp_min(tensor: NdArrayTensor<i64>, min: i64) -> NdArrayTensor<i64> {
|
||||
fn int_clamp_min(tensor: NdArrayTensor<I>, min: I) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::clamp_min(tensor, min)
|
||||
}
|
||||
|
||||
fn int_clamp_max(tensor: NdArrayTensor<i64>, max: i64) -> NdArrayTensor<i64> {
|
||||
fn int_clamp_max(tensor: NdArrayTensor<I>, max: I) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::clamp_max(tensor, max)
|
||||
}
|
||||
|
||||
fn int_clamp(tensor: NdArrayTensor<i64>, min: i64, max: i64) -> NdArrayTensor<i64> {
|
||||
fn int_clamp(tensor: NdArrayTensor<I>, min: I, max: I) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::clamp(tensor, min, max)
|
||||
}
|
||||
|
||||
fn int_abs(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
fn int_abs(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
let array = tensor.array.mapv_into(|a| a.int_abs_elem()).into_shared();
|
||||
|
||||
NdArrayTensor::new(array)
|
||||
}
|
||||
|
||||
fn int_into_float(tensor: NdArrayTensor<i64>) -> <NdArray<E> as Backend>::FloatTensorPrimitive {
|
||||
fn int_into_float(tensor: NdArrayTensor<I>) -> <NdArray<E> as Backend>::FloatTensorPrimitive {
|
||||
let array = tensor.array.mapv(|a| a.elem()).into_shared();
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn int_swap_dims(tensor: NdArrayTensor<i64>, dim1: usize, dim2: usize) -> NdArrayTensor<i64> {
|
||||
fn int_swap_dims(tensor: NdArrayTensor<I>, dim1: usize, dim2: usize) -> NdArrayTensor<I> {
|
||||
NdArrayOps::swap_dims(tensor, dim1, dim2)
|
||||
}
|
||||
|
||||
|
@ -291,7 +293,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E,
|
|||
shape: Shape,
|
||||
distribution: Distribution,
|
||||
device: &NdArrayDevice,
|
||||
) -> NdArrayTensor<i64> {
|
||||
) -> NdArrayTensor<I> {
|
||||
let mut seed = SEED.lock().unwrap();
|
||||
let mut rng = if let Some(rng_seeded) = seed.as_ref() {
|
||||
rng_seeded.clone()
|
||||
|
@ -313,32 +315,36 @@ impl<E: FloatNdArrayElement, Q: QuantElement> IntTensorOps<Self> for NdArray<E,
|
|||
tensor
|
||||
}
|
||||
|
||||
fn int_powi(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &i64, b: &i64| a.pow(*b as u32))
|
||||
fn int_powi(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
|
||||
(a.elem::<i64>().pow(b.elem::<u32>())).elem()
|
||||
})
|
||||
}
|
||||
|
||||
fn int_powf(lhs: NdArrayTensor<i64>, rhs: NdArrayTensor<E>) -> NdArrayTensor<i64> {
|
||||
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &i64, b: &E| a.pow(b.elem::<u32>()))
|
||||
fn int_powf(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<E>) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &E| {
|
||||
(a.elem::<i64>().pow(b.elem::<u32>())).elem()
|
||||
})
|
||||
}
|
||||
|
||||
fn int_powf_scalar(lhs: NdArrayTensor<i64>, rhs: f32) -> NdArrayTensor<i64> {
|
||||
NdArrayMathOps::elementwise_op_scalar(lhs, |a: i64| a.pow(rhs as u32))
|
||||
fn int_powf_scalar(lhs: NdArrayTensor<I>, rhs: f32) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| (a.elem::<i64>().pow(rhs as u32)).elem())
|
||||
}
|
||||
|
||||
fn int_permute(tensor: NdArrayTensor<i64>, axes: &[usize]) -> NdArrayTensor<i64> {
|
||||
fn int_permute(tensor: NdArrayTensor<I>, axes: &[usize]) -> NdArrayTensor<I> {
|
||||
let array = tensor.array.permuted_axes(axes.into_dimension());
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn int_flip(tensor: NdArrayTensor<i64>, axes: &[usize]) -> NdArrayTensor<i64> {
|
||||
fn int_flip(tensor: NdArrayTensor<I>, axes: &[usize]) -> NdArrayTensor<I> {
|
||||
NdArrayOps::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_sign(tensor: NdArrayTensor<i64>) -> NdArrayTensor<i64> {
|
||||
fn int_sign(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::sign_op(tensor)
|
||||
}
|
||||
|
||||
fn int_expand(tensor: NdArrayTensor<i64>, shape: Shape) -> NdArrayTensor<i64> {
|
||||
fn int_expand(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
|
||||
NdArrayOps::expand(tensor, shape)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
element::{FloatNdArrayElement, QuantElement},
|
||||
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
|
||||
ops::padding::apply_padding_4d,
|
||||
sharing::UnsafeSharedRef,
|
||||
tensor::NdArrayTensor,
|
||||
|
@ -9,7 +9,7 @@ use burn_common::{iter_range_par, run_par};
|
|||
use burn_tensor::ElementConversion;
|
||||
use ndarray::Array4;
|
||||
|
||||
pub(crate) fn max_pool2d<E: FloatNdArrayElement, Q: QuantElement>(
|
||||
pub(crate) fn max_pool2d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
|
||||
x: NdArrayTensor<E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
|
@ -30,7 +30,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
/ stride_width)
|
||||
+ 1;
|
||||
|
||||
let x = apply_padding_4d::<E, Q>(x, padding, inf).array;
|
||||
let x = apply_padding_4d::<E, I, Q>(x, padding, inf).array;
|
||||
|
||||
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
@ -69,13 +69,17 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
NdArrayTensor::new(output.into_dyn().into_shared())
|
||||
}
|
||||
|
||||
pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, Q: QuantElement>(
|
||||
pub(crate) fn max_pool2d_with_indices<
|
||||
E: FloatNdArrayElement,
|
||||
I: IntNdArrayElement,
|
||||
Q: QuantElement,
|
||||
>(
|
||||
x: NdArrayTensor<E>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> (NdArrayTensor<E>, NdArrayTensor<i64>) {
|
||||
) -> (NdArrayTensor<E>, NdArrayTensor<I>) {
|
||||
let [kernel_height, kernel_width] = kernel_size;
|
||||
let [padding_height, padding_width] = padding;
|
||||
let [stride_height, stride_width] = stride;
|
||||
|
@ -90,10 +94,10 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
/ stride_width)
|
||||
+ 1;
|
||||
|
||||
let x = apply_padding_4d::<E, Q>(x, padding, inf).array;
|
||||
let x = apply_padding_4d::<E, I, Q>(x, padding, inf).array;
|
||||
|
||||
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
|
||||
let mut indices = Array4::<i64>::zeros((batch_size, channels, out_height, out_width));
|
||||
let mut indices = Array4::<I>::zeros((batch_size, channels, out_height, out_width));
|
||||
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
|
||||
|
@ -130,7 +134,7 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
}
|
||||
|
||||
output[[b, c, oh, ow]] = max_val;
|
||||
indices[[b, c, oh, ow]] = index;
|
||||
indices[[b, c, oh, ow]] = index.elem();
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -142,14 +146,14 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
(output, indices)
|
||||
}
|
||||
|
||||
pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
|
||||
pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement, I: IntNdArrayElement>(
|
||||
x: NdArrayTensor<E>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
_dilation: [usize; 2],
|
||||
output_grad: NdArrayTensor<E>,
|
||||
indices: NdArrayTensor<i64>,
|
||||
indices: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<E> {
|
||||
let [_batch_size, _channels, height, width] = output_grad.shape().dims();
|
||||
let [batch_size, channels, height_x, width_x] = x.shape().dims();
|
||||
|
@ -170,7 +174,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
|
|||
|
||||
for h in 0..height {
|
||||
for w in 0..width {
|
||||
let index = indices[[b, c, h, w]];
|
||||
let index = indices[[b, c, h, w]].elem::<i64>();
|
||||
let grad = output_grad[[b, c, h, w]];
|
||||
|
||||
let index_h = index as usize / width_x;
|
||||
|
|
|
@ -7,17 +7,22 @@ use super::{
|
|||
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
|
||||
};
|
||||
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
|
||||
use crate::{element::QuantElement, ops::interpolate::nearest_interpolate_backward};
|
||||
use crate::{
|
||||
element::{IntNdArrayElement, QuantElement},
|
||||
ops::interpolate::nearest_interpolate_backward,
|
||||
};
|
||||
use burn_tensor::ops::*;
|
||||
|
||||
impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q> {
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
{
|
||||
fn conv2d(
|
||||
x: NdArrayTensor<E>,
|
||||
weight: NdArrayTensor<E>,
|
||||
bias: Option<NdArrayTensor<E>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> NdArrayTensor<E> {
|
||||
conv2d::<E, Q>(x, weight, bias, options)
|
||||
conv2d::<E, I, Q>(x, weight, bias, options)
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
|
@ -80,7 +85,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q>
|
|||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> NdArrayTensor<E> {
|
||||
max_pool2d::<E, Q>(x, kernel_size, stride, padding, dilation)
|
||||
max_pool2d::<E, I, Q>(x, kernel_size, stride, padding, dilation)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices(
|
||||
|
@ -89,9 +94,9 @@ impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q>
|
|||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> MaxPool2dWithIndices<NdArray<E, Q>> {
|
||||
) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
|
||||
let (output, indices) =
|
||||
max_pool2d_with_indices::<E, Q>(x, kernel_size, stride, padding, dilation);
|
||||
max_pool2d_with_indices::<E, I, Q>(x, kernel_size, stride, padding, dilation);
|
||||
|
||||
MaxPool2dWithIndices::new(output, indices)
|
||||
}
|
||||
|
@ -103,8 +108,8 @@ impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q>
|
|||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: NdArrayTensor<E>,
|
||||
indices: NdArrayTensor<i64>,
|
||||
) -> MaxPool2dBackward<NdArray<E, Q>> {
|
||||
indices: NdArrayTensor<I>,
|
||||
) -> MaxPool2dBackward<NdArray<E, I, Q>> {
|
||||
MaxPool2dBackward::new(max_pool2d_backward(
|
||||
x,
|
||||
kernel_size,
|
||||
|
@ -162,7 +167,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q>
|
|||
bias: Option<NdArrayTensor<E>>,
|
||||
options: ConvOptions<3>,
|
||||
) -> NdArrayTensor<E> {
|
||||
conv3d::<E, Q>(x, weight, bias, options)
|
||||
conv3d::<E, I, Q>(x, weight, bias, options)
|
||||
}
|
||||
|
||||
fn conv_transpose3d(
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use crate::{
|
||||
element::{FloatNdArrayElement, QuantElement},
|
||||
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
|
||||
tensor::NdArrayTensor,
|
||||
NdArray,
|
||||
};
|
||||
use burn_tensor::ops::FloatTensorOps;
|
||||
use ndarray::{Array4, Array5};
|
||||
|
||||
pub(crate) fn apply_padding_4d<E: FloatNdArrayElement, Q: QuantElement>(
|
||||
pub(crate) fn apply_padding_4d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
|
||||
x: NdArrayTensor<E>,
|
||||
padding: [usize; 2],
|
||||
elem: E,
|
||||
|
@ -22,7 +22,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
);
|
||||
let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());
|
||||
|
||||
x_new = NdArray::<E, Q>::float_slice_assign(
|
||||
x_new = NdArray::<E, I, Q>::float_slice_assign(
|
||||
x_new,
|
||||
&[
|
||||
0..batch_size,
|
||||
|
@ -36,7 +36,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
x_new
|
||||
}
|
||||
|
||||
pub(crate) fn apply_padding_5d<E: FloatNdArrayElement, Q: QuantElement>(
|
||||
pub(crate) fn apply_padding_5d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
|
||||
x: NdArrayTensor<E>,
|
||||
padding: [usize; 3],
|
||||
elem: E,
|
||||
|
@ -59,7 +59,7 @@ pub(crate) fn apply_padding_5d<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
);
|
||||
let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());
|
||||
|
||||
x_new = NdArray::<E, Q>::float_slice_assign(
|
||||
x_new = NdArray::<E, I, Q>::float_slice_assign(
|
||||
x_new,
|
||||
&[
|
||||
0..batch_size,
|
||||
|
|
|
@ -10,7 +10,7 @@ use burn_tensor::{
|
|||
};
|
||||
|
||||
use crate::{
|
||||
element::{NdArrayElement, QuantElement},
|
||||
element::{IntNdArrayElement, NdArrayElement, QuantElement},
|
||||
FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor,
|
||||
};
|
||||
|
||||
|
@ -22,7 +22,9 @@ fn into_data<E: NdArrayElement>(tensor: NdArrayTensor<E>) -> TensorData {
|
|||
TensorData::new(values, shape)
|
||||
}
|
||||
|
||||
impl<E: FloatNdArrayElement, Q: QuantElement> QTensorOps<Self> for NdArray<E, Q> {
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
{
|
||||
fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
|
||||
match data.dtype {
|
||||
DType::QFloat(strategy) => match strategy {
|
||||
|
|
|
@ -5,7 +5,7 @@ use ndarray::Zip;
|
|||
|
||||
// Current crate
|
||||
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
|
||||
use crate::element::{FloatNdArrayElement, QuantElement};
|
||||
use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
|
||||
use crate::{tensor::NdArrayTensor, NdArray};
|
||||
use crate::{NdArrayDevice, SEED};
|
||||
|
||||
|
@ -20,7 +20,9 @@ use num_traits::Float;
|
|||
|
||||
use libm::erf;
|
||||
|
||||
impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E, Q> {
|
||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
|
||||
for NdArray<E, I, Q>
|
||||
{
|
||||
fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<E> {
|
||||
NdArrayTensor::from_data(data)
|
||||
}
|
||||
|
@ -125,7 +127,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
|
|||
fn float_gather(
|
||||
dim: usize,
|
||||
tensor: NdArrayTensor<E>,
|
||||
indices: NdArrayTensor<i64>,
|
||||
indices: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<E> {
|
||||
NdArrayMathOps::gather(dim, tensor, indices)
|
||||
}
|
||||
|
@ -133,7 +135,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
|
|||
fn float_scatter(
|
||||
dim: usize,
|
||||
tensor: NdArrayTensor<E>,
|
||||
indices: NdArrayTensor<i64>,
|
||||
indices: NdArrayTensor<I>,
|
||||
value: NdArrayTensor<E>,
|
||||
) -> NdArrayTensor<E> {
|
||||
NdArrayMathOps::scatter(dim, tensor, indices, value)
|
||||
|
@ -142,7 +144,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
|
|||
fn float_select(
|
||||
tensor: NdArrayTensor<E>,
|
||||
dim: usize,
|
||||
indices: NdArrayTensor<i64>,
|
||||
indices: NdArrayTensor<I>,
|
||||
) -> NdArrayTensor<E> {
|
||||
NdArrayMathOps::select(tensor, dim, indices)
|
||||
}
|
||||
|
@ -150,7 +152,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
|
|||
fn float_select_assign(
|
||||
tensor: NdArrayTensor<E>,
|
||||
dim: usize,
|
||||
indices: NdArrayTensor<i64>,
|
||||
indices: NdArrayTensor<I>,
|
||||
value: NdArrayTensor<E>,
|
||||
) -> NdArrayTensor<E> {
|
||||
NdArrayMathOps::select_assign(tensor, dim, indices, value)
|
||||
|
@ -266,11 +268,11 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
|
|||
NdArrayMathOps::sum_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_argmax(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<i64> {
|
||||
fn float_argmax(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::argmax(tensor, dim)
|
||||
}
|
||||
|
||||
fn float_argmin(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<i64> {
|
||||
fn float_argmin(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
|
||||
NdArrayMathOps::argmin(tensor, dim)
|
||||
}
|
||||
|
||||
|
@ -374,7 +376,7 @@ impl<E: FloatNdArrayElement, Q: QuantElement> FloatTensorOps<Self> for NdArray<E
|
|||
NdArrayMathOps::clamp(tensor, min, max)
|
||||
}
|
||||
|
||||
fn float_into_int(tensor: NdArrayTensor<E>) -> <NdArray<E> as Backend>::IntTensorPrimitive {
|
||||
fn float_into_int(tensor: NdArrayTensor<E>) -> NdArrayTensor<I> {
|
||||
let array = tensor.array.mapv(|a| a.elem()).into_shared();
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
|
|
@ -212,7 +212,7 @@ mod tests {
|
|||
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
|
||||
let qparams = QuantizationParametersPrimitive {
|
||||
scale: NdArrayTensor::from_data(TensorData::from([0.009_019_608])),
|
||||
offset: Some(NdArrayTensor::from_data(TensorData::from([72]))),
|
||||
offset: Some(NdArrayTensor::<i64>::from_data(TensorData::from([72]))),
|
||||
};
|
||||
let qtensor: NdArrayQTensor<i8> = NdArray::quantize(tensor, &scheme, qparams);
|
||||
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
[package]
|
||||
authors = ["guillaumelagrange <lagrange.guillaume.1@gmail.com>", "nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "Multi-backend router decorator for the Burn framework"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-router"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router"
|
||||
documentation = "https://docs.rs/burn-router"
|
||||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = []
|
||||
doc = ["default"]
|
||||
|
||||
[dependencies]
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false, features = ["repr"]}
|
||||
hashbrown = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", default-features = false, features = [
|
||||
"export_tests",
|
||||
] }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false, features = [
|
||||
"export_tests",
|
||||
] }
|
||||
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "0.15.0" }
|
||||
burn-wgpu = { path = "../burn-wgpu", version = "0.15.0" }
|
||||
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
|
@ -0,0 +1,3 @@
|
|||
# Burn Router
|
||||
|
||||
A multi-backend extension that forwards the tensor operations to the appropriate backend.
|
|
@ -0,0 +1,125 @@
|
|||
use alloc::{format, string::String};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use burn_tensor::{
|
||||
backend::{Backend, BackendBridge},
|
||||
ops::FloatTensor,
|
||||
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
|
||||
repr::{BaseOperationDescription, OperationDescription, UnaryOperationDescription},
|
||||
Device,
|
||||
};
|
||||
|
||||
use super::{get_client, set_seed, RouterTensor, RunnerChannel, RunnerClient};
|
||||
|
||||
/// A backend that forwards the tensor operations to the appropiate backend (given multiple backends).
|
||||
pub struct BackendRouter<R: RunnerChannel> {
|
||||
r: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<R: RunnerChannel> core::fmt::Debug for BackendRouter<R> {
    /// Writes the fixed label `router`; the channel type is not part of the output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("router")
    }
}
|
||||
|
||||
impl<R: RunnerChannel> Clone for BackendRouter<R> {
|
||||
fn clone(&self) -> Self {
|
||||
Self { r: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: RunnerChannel> Default for BackendRouter<R> {
|
||||
fn default() -> Self {
|
||||
Self { r: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: quantization tensor primitive (w/ qparams)
|
||||
impl<R: RunnerClient> QTensorPrimitive for RouterTensor<R> {
|
||||
fn scheme(&self) -> &QuantizationScheme {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn strategy(&self) -> QuantizationStrategy {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: RunnerChannel> Backend for BackendRouter<R> {
|
||||
type Device = R::Device;
|
||||
|
||||
type FullPrecisionBridge = PrecisionBridge;
|
||||
|
||||
type FloatTensorPrimitive = RouterTensor<R::Client>;
|
||||
|
||||
type FloatElem = R::FloatElem;
|
||||
|
||||
type IntTensorPrimitive = RouterTensor<R::Client>;
|
||||
|
||||
type IntElem = R::IntElem;
|
||||
|
||||
type BoolTensorPrimitive = RouterTensor<R::Client>;
|
||||
|
||||
type QuantizedTensorPrimitive = RouterTensor<R::Client>;
|
||||
|
||||
type QuantizedEncoding = u32;
|
||||
|
||||
fn name() -> String {
|
||||
format!("router<{}>", R::name())
|
||||
}
|
||||
|
||||
fn seed(seed: u64) {
|
||||
set_seed(seed)
|
||||
}
|
||||
|
||||
fn sync(device: &Self::Device) {
|
||||
let client = get_client::<R>(device);
|
||||
client.sync();
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle precision conversion.
|
||||
#[derive(Debug)]
|
||||
pub struct PrecisionBridge {}
|
||||
|
||||
impl<R: RunnerChannel> BackendBridge<BackendRouter<R>> for PrecisionBridge {
|
||||
type Target = BackendRouter<R>;
|
||||
|
||||
fn into_target(
|
||||
tensor: FloatTensor<BackendRouter<R>>,
|
||||
_device: Option<Device<Self::Target>>,
|
||||
) -> FloatTensor<Self::Target> {
|
||||
let client = tensor.client.clone();
|
||||
let out = client.register_float_tensor(tensor.shape.clone(), true);
|
||||
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseFloat(
|
||||
BaseOperationDescription::Cast(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn from_target(
|
||||
tensor: FloatTensor<Self::Target>,
|
||||
_device: Option<Device<BackendRouter<R>>>,
|
||||
) -> FloatTensor<BackendRouter<R>> {
|
||||
let client = tensor.client.clone();
|
||||
let out = client.register_float_tensor(tensor.shape.clone(), false);
|
||||
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseFloat(
|
||||
BaseOperationDescription::Cast(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
use burn_tensor::{backend::DeviceOps, Shape};
|
||||
|
||||
/// Allows tensors to be transferred between multiple backends.
|
||||
pub trait MultiBackendBridge: Send + Sync + 'static {
|
||||
/// The type that can be used to point to a tensor of any kind.
|
||||
type TensorHandle;
|
||||
/// Device type used by the backends.
|
||||
type Device: DeviceOps;
|
||||
|
||||
/// Change the backend of the given float tensor.
|
||||
fn change_backend_float(
|
||||
tensor: Self::TensorHandle,
|
||||
shape: Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle;
|
||||
|
||||
/// Change the backend of the given int tensor.
|
||||
fn change_backend_int(
|
||||
tensor: Self::TensorHandle,
|
||||
shape: Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle;
|
||||
|
||||
/// Change the backend of the given bool tensor.
|
||||
fn change_backend_bool(
|
||||
tensor: Self::TensorHandle,
|
||||
shape: Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle;
|
||||
|
||||
// TODO: change_backend_quantized
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use burn_tensor::{
|
||||
repr::{ReprBackend, TensorHandle},
|
||||
try_read_sync, Shape,
|
||||
};
|
||||
|
||||
use super::base::MultiBackendBridge;
|
||||
use crate::{MultiDevice2, TensorHandle2};
|
||||
|
||||
/// Simply transfers tensors between backends via the underlying [tensor data](burn_tensor::TensorData).
|
||||
pub struct ByteBridge<Backends> {
|
||||
backends: PhantomData<Backends>,
|
||||
}
|
||||
|
||||
impl<B1: ReprBackend, B2: ReprBackend> MultiBackendBridge for ByteBridge<(B1, B2)> {
|
||||
type TensorHandle = TensorHandle2<B1, B2>;
|
||||
type Device = MultiDevice2<B1, B2>;
|
||||
|
||||
fn change_backend_float(
|
||||
tensor: Self::TensorHandle,
|
||||
shape: Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
let msg = "Failed to read tensor data synchronously.
|
||||
This can happen on platforms that don't support blocking futures like WASM.";
|
||||
match tensor {
|
||||
TensorHandle2::Handle1(handle) => match target_device {
|
||||
MultiDevice2::Device1(device) => {
|
||||
// Same backend
|
||||
let tensor = B1::float_tensor(TensorHandle { handle, shape });
|
||||
let tensor = B1::float_to_device(tensor, device);
|
||||
let handle = B1::float_tensor_handle(tensor);
|
||||
TensorHandle2::Handle1(handle)
|
||||
}
|
||||
MultiDevice2::Device2(device) => {
|
||||
let tensor = B1::float_tensor(TensorHandle { handle, shape });
|
||||
let data = try_read_sync(B1::float_into_data(tensor)).expect(msg);
|
||||
let tensor = B2::float_from_data(data, device);
|
||||
let handle = B2::float_tensor_handle(tensor);
|
||||
TensorHandle2::Handle2(handle)
|
||||
}
|
||||
},
|
||||
TensorHandle2::Handle2(handle) => match target_device {
|
||||
MultiDevice2::Device1(device) => {
|
||||
let tensor = B2::float_tensor(TensorHandle { handle, shape });
|
||||
let data = try_read_sync(B2::float_into_data(tensor)).expect(msg);
|
||||
let tensor = B1::float_from_data(data, device);
|
||||
let handle = B1::float_tensor_handle(tensor);
|
||||
TensorHandle2::Handle1(handle)
|
||||
}
|
||||
MultiDevice2::Device2(device) => {
|
||||
// Same backend
|
||||
let tensor = B2::float_tensor(TensorHandle { handle, shape });
|
||||
let tensor = B2::float_to_device(tensor, device);
|
||||
let handle = B2::float_tensor_handle(tensor);
|
||||
TensorHandle2::Handle2(handle)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn change_backend_int(
|
||||
tensor: Self::TensorHandle,
|
||||
shape: Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
let msg = "Failed to read tensor data synchronously.
|
||||
This can happen on platforms that don't support blocking futures like WASM.";
|
||||
match tensor {
|
||||
TensorHandle2::Handle1(handle) => match target_device {
|
||||
MultiDevice2::Device1(device) => {
|
||||
// Same backend
|
||||
let tensor = B1::int_tensor(TensorHandle { handle, shape });
|
||||
let tensor = B1::int_to_device(tensor, device);
|
||||
let handle = B1::int_tensor_handle(tensor);
|
||||
TensorHandle2::Handle1(handle)
|
||||
}
|
||||
MultiDevice2::Device2(device) => {
|
||||
let tensor = B1::int_tensor(TensorHandle { handle, shape });
|
||||
let data = try_read_sync(B1::int_into_data(tensor)).expect(msg);
|
||||
let tensor = B2::int_from_data(data, device);
|
||||
let handle = B2::int_tensor_handle(tensor);
|
||||
TensorHandle2::Handle2(handle)
|
||||
}
|
||||
},
|
||||
TensorHandle2::Handle2(handle) => match target_device {
|
||||
MultiDevice2::Device1(device) => {
|
||||
let tensor = B2::int_tensor(TensorHandle { handle, shape });
|
||||
let data = try_read_sync(B2::int_into_data(tensor)).expect(msg);
|
||||
let tensor = B1::int_from_data(data, device);
|
||||
let handle = B1::int_tensor_handle(tensor);
|
||||
TensorHandle2::Handle1(handle)
|
||||
}
|
||||
MultiDevice2::Device2(device) => {
|
||||
// Same backend
|
||||
let tensor = B2::int_tensor(TensorHandle { handle, shape });
|
||||
let tensor = B2::int_to_device(tensor, device);
|
||||
let handle = B2::int_tensor_handle(tensor);
|
||||
TensorHandle2::Handle2(handle)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn change_backend_bool(
|
||||
tensor: Self::TensorHandle,
|
||||
shape: Shape,
|
||||
target_device: &Self::Device,
|
||||
) -> Self::TensorHandle {
|
||||
let msg = "Failed to read tensor data synchronously.
|
||||
This can happen on platforms that don't support blocking futures like WASM.";
|
||||
match tensor {
|
||||
TensorHandle2::Handle1(handle) => match target_device {
|
||||
MultiDevice2::Device1(device) => {
|
||||
// Same backend
|
||||
let tensor = B1::bool_tensor(TensorHandle { handle, shape });
|
||||
let tensor = B1::bool_to_device(tensor, device);
|
||||
let handle = B1::bool_tensor_handle(tensor);
|
||||
TensorHandle2::Handle1(handle)
|
||||
}
|
||||
MultiDevice2::Device2(device) => {
|
||||
let tensor = B1::bool_tensor(TensorHandle { handle, shape });
|
||||
let data = try_read_sync(B1::bool_into_data(tensor)).expect(msg);
|
||||
let tensor = B2::bool_from_data(data, device);
|
||||
let handle = B2::bool_tensor_handle(tensor);
|
||||
TensorHandle2::Handle2(handle)
|
||||
}
|
||||
},
|
||||
TensorHandle2::Handle2(handle) => match target_device {
|
||||
MultiDevice2::Device1(device) => {
|
||||
let tensor = B2::bool_tensor(TensorHandle { handle, shape });
|
||||
let data = try_read_sync(B2::bool_into_data(tensor)).expect(msg);
|
||||
let tensor = B1::bool_from_data(data, device);
|
||||
let handle = B1::bool_tensor_handle(tensor);
|
||||
TensorHandle2::Handle1(handle)
|
||||
}
|
||||
MultiDevice2::Device2(device) => {
|
||||
// Same backend
|
||||
let tensor = B2::bool_tensor(TensorHandle { handle, shape });
|
||||
let tensor = B2::bool_to_device(tensor, device);
|
||||
let handle = B2::bool_tensor_handle(tensor);
|
||||
TensorHandle2::Handle2(handle)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(target_os = "windows")))]
mod tests {
    use burn_tensor::{backend::Backend, Tensor};

    use super::*;
    use crate::tests::{TestBackend, TestBackend1, TestBackend2};

    /// Moving a tensor to the other backend's device must leave its data unchanged.
    #[test]
    fn should_support_dual_byte_bridge() {
        let device1 = MultiDevice2::Device1(<TestBackend1 as Backend>::Device::default());
        let device2 = MultiDevice2::Device2(<TestBackend2 as Backend>::Device::default());
        let tensor1 = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device1);
        let tensor2 = Tensor::<TestBackend, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device2);

        // First backend -> second backend.
        let moved_to_second = tensor1.clone().to_device(&device2);
        tensor1
            .into_data()
            .assert_eq(&moved_to_second.into_data(), true);

        // Second backend -> first backend.
        let moved_to_first = tensor2.clone().to_device(&device1);
        tensor2
            .into_data()
            .assert_eq(&moved_to_first.into_data(), true);
    }
}
|
|
@ -0,0 +1,5 @@
|
|||
mod base;
|
||||
mod byte;
|
||||
|
||||
pub use base::*;
|
||||
pub use byte::*;
|
|
@ -0,0 +1,68 @@
|
|||
use alloc::{string::String, vec::Vec};
|
||||
use burn_tensor::{backend::DeviceOps, repr::TensorDescription, DType, Element};
|
||||
|
||||
use crate::{get_client, MultiBackendBridge, RouterTensor, RunnerClient};
|
||||
|
||||
/// Type alias for `<Br as MultiBackendBridge>::TensorHandle`.
|
||||
pub type TensorHandle<Br> = <Br as MultiBackendBridge>::TensorHandle;
|
||||
|
||||
/// Defines the connection channel and operations for a setup with multiple backend runner clients.
|
||||
pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized {
|
||||
/// Device type.
|
||||
type Device: DeviceOps;
|
||||
/// A bridge that can transfer tensors between multiple backends.
|
||||
type Bridge: MultiBackendBridge<Device = Self::Device>;
|
||||
/// Client type.
|
||||
type Client: RunnerClient<Device = Self::Device>;
|
||||
/// Float element type.
|
||||
type FloatElem: Element;
|
||||
/// Int element type.
|
||||
type IntElem: Element;
|
||||
|
||||
/// Name of the channel.
|
||||
fn name() -> String;
|
||||
|
||||
/// Initialize a new client for the given device.
|
||||
fn init_client(device: &Self::Device) -> Self::Client;
|
||||
|
||||
/// Get the tensor handle corresponding to the [tensor description](TensorDescription).
|
||||
fn get_tensor_handle(
|
||||
tensor: &TensorDescription,
|
||||
client: &Self::Client,
|
||||
) -> TensorHandle<Self::Bridge>;
|
||||
|
||||
// TODO: get quantized tensor handle from QuantizedTensorDescription
|
||||
|
||||
/// Create a tensor with the given handle and shape.
|
||||
fn register_tensor(
|
||||
client: &Self::Client,
|
||||
handle: TensorHandle<Self::Bridge>,
|
||||
shape: Vec<usize>,
|
||||
dtype: DType,
|
||||
) -> RouterTensor<Self::Client>;
|
||||
|
||||
/// Change the tensor to a different client backend.
|
||||
fn change_client_backend(
|
||||
tensor: RouterTensor<Self::Client>,
|
||||
device: &Self::Device, // target device
|
||||
) -> RouterTensor<Self::Client> {
|
||||
// Get tensor handle from current client
|
||||
let original_client = tensor.client.clone();
|
||||
let desc = tensor.into_description();
|
||||
let mut handle = Self::get_tensor_handle(&desc, &original_client);
|
||||
|
||||
if desc.dtype.is_float() {
|
||||
handle = Self::Bridge::change_backend_float(handle, desc.shape.clone().into(), device);
|
||||
} else if desc.dtype.is_int() {
|
||||
handle = Self::Bridge::change_backend_int(handle, desc.shape.clone().into(), device);
|
||||
} else if desc.dtype.is_bool() {
|
||||
handle = Self::Bridge::change_backend_bool(handle, desc.shape.clone().into(), device);
|
||||
} else {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
// Register tensor handle on target client
|
||||
let target_client = get_client::<Self>(device);
|
||||
Self::register_tensor(&target_client, handle, desc.shape, desc.dtype)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,271 @@
|
|||
use alloc::{format, string::String, sync::Arc, vec::Vec};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use burn_tensor::{
|
||||
backend::{Backend, BackendBridge, DeviceId, DeviceOps},
|
||||
repr::{OperationDescription, ReprBackend, TensorDescription, TensorId},
|
||||
DType, TensorData,
|
||||
};
|
||||
|
||||
use super::{RunnerChannel, TensorHandle};
|
||||
use crate::{MultiBackendBridge, RouterTensor, Runner, RunnerClient};
|
||||
|
||||
/// A local channel with direct connection to the backend runner clients.
pub struct DirectChannel<Backends, Bridge> {
    // Both fields are zero-sized type markers; the channel carries no runtime state.
    backends: PhantomData<Backends>,
    bridge: PhantomData<Bridge>,
}

impl<Backends, Bridge> Clone for DirectChannel<Backends, Bridge> {
    /// Implemented by hand so that no `Clone` bound is required on the type parameters.
    fn clone(&self) -> Self {
        DirectChannel {
            backends: PhantomData,
            bridge: PhantomData,
        }
    }
}
|
||||
|
||||
impl<B1, B2, Br> RunnerChannel for DirectChannel<(B1, B2), Br>
|
||||
where
|
||||
B1: ReprBackend,
|
||||
B2: ReprBackend<FloatElem = B1::FloatElem, IntElem = B1::IntElem>,
|
||||
Br: MultiBackendBridge<TensorHandle = TensorHandle2<B1, B2>, Device = MultiDevice2<B1, B2>>,
|
||||
// Restrict full precision backend handle to be the same
|
||||
<<B1 as Backend>::FullPrecisionBridge as BackendBridge<B1>>::Target:
|
||||
ReprBackend<Handle = B1::Handle>,
|
||||
<<B2 as Backend>::FullPrecisionBridge as BackendBridge<B2>>::Target:
|
||||
ReprBackend<Handle = B2::Handle>,
|
||||
{
|
||||
type Device = Br::Device;
|
||||
|
||||
type Bridge = Br;
|
||||
|
||||
type FloatElem = B1::FloatElem;
|
||||
type IntElem = B1::IntElem;
|
||||
|
||||
type Client = MultiRunnerClient2<B1, B2>;
|
||||
|
||||
fn init_client(device: &Self::Device) -> Self::Client {
|
||||
match device {
|
||||
MultiDevice2::Device1(device) => {
|
||||
MultiRunnerClient2::RunnerClient1(Runner::new(device.clone()))
|
||||
}
|
||||
MultiDevice2::Device2(device) => {
|
||||
MultiRunnerClient2::RunnerClient2(Runner::new(device.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tensor_handle(
|
||||
tensor: &TensorDescription,
|
||||
client: &Self::Client,
|
||||
) -> TensorHandle<Self::Bridge> {
|
||||
match client {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => {
|
||||
TensorHandle2::Handle1(runner.get_tensor_handle(tensor))
|
||||
}
|
||||
MultiRunnerClient2::RunnerClient2(runner) => {
|
||||
TensorHandle2::Handle2(runner.get_tensor_handle(tensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_tensor(
|
||||
client: &Self::Client,
|
||||
handle: TensorHandle<Self::Bridge>,
|
||||
shape: Vec<usize>,
|
||||
dtype: DType,
|
||||
) -> RouterTensor<Self::Client> {
|
||||
match client {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => match handle {
|
||||
TensorHandle2::Handle1(handle) => {
|
||||
runner.register_tensor(handle, shape, dtype, client.clone())
|
||||
}
|
||||
TensorHandle2::Handle2(_) => {
|
||||
unreachable!("Can't register tensor handle for another backend.")
|
||||
}
|
||||
},
|
||||
MultiRunnerClient2::RunnerClient2(runner) => match handle {
|
||||
TensorHandle2::Handle1(_) => {
|
||||
unreachable!("Can't register tensor handle for another backend.")
|
||||
}
|
||||
TensorHandle2::Handle2(handle) => {
|
||||
runner.register_tensor(handle, shape, dtype, client.clone())
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn name() -> String {
|
||||
format!("direct<({}, {})>", B1::name(), B2::name())
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: generate this for different number of backends (up to 4?)
|
||||
|
||||
/// Handle type to interact with two backends.
|
||||
pub enum TensorHandle2<B1: ReprBackend, B2: ReprBackend> {
|
||||
/// Handle for the first backend.
|
||||
Handle1(B1::Handle),
|
||||
/// Handle for the second backend.
|
||||
Handle2(B2::Handle),
|
||||
}
|
||||
|
||||
/// Device type to interact with two backends.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum MultiDevice2<B1: Backend, B2: Backend> {
|
||||
/// Device for the first backend.
|
||||
Device1(B1::Device),
|
||||
/// Device for the second backend.
|
||||
Device2(B2::Device),
|
||||
}
|
||||
|
||||
impl<B1: Backend, B2: Backend> PartialEq for MultiDevice2<B1, B2> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(Self::Device1(lhs), Self::Device1(rhs)) => lhs == rhs,
|
||||
(Self::Device2(lhs), Self::Device2(rhs)) => lhs == rhs,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B1: Backend, B2: Backend> Eq for MultiDevice2<B1, B2> {}
|
||||
|
||||
impl<B1: Backend, B2: Backend> Default for MultiDevice2<B1, B2> {
|
||||
fn default() -> Self {
|
||||
Self::Device1(B1::Device::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B1: Backend, B2: Backend> DeviceOps for MultiDevice2<B1, B2> {
|
||||
fn id(&self) -> DeviceId {
|
||||
match self {
|
||||
MultiDevice2::Device1(device) => device.id(),
|
||||
MultiDevice2::Device2(device) => device.id(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Local [`RunnerClient`] with two backends.
|
||||
#[derive(Clone)]
|
||||
pub enum MultiRunnerClient2<B1: ReprBackend, B2: ReprBackend> {
|
||||
/// Client for the first backend runner.
|
||||
RunnerClient1(Runner<B1>),
|
||||
/// Client for the second backend runner.
|
||||
RunnerClient2(Runner<B2>),
|
||||
}
|
||||
|
||||
impl<B1: ReprBackend, B2: ReprBackend> RunnerClient for MultiRunnerClient2<B1, B2>
|
||||
where
|
||||
<<B1 as Backend>::FullPrecisionBridge as BackendBridge<B1>>::Target:
|
||||
ReprBackend<Handle = B1::Handle>,
|
||||
<<B2 as Backend>::FullPrecisionBridge as BackendBridge<B2>>::Target:
|
||||
ReprBackend<Handle = B2::Handle>,
|
||||
{
|
||||
type Device = MultiDevice2<B1, B2>;
|
||||
|
||||
fn register(&self, op: OperationDescription) {
|
||||
match self {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => runner.register(op),
|
||||
MultiRunnerClient2::RunnerClient2(runner) => runner.register(op),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_tensor(&self, tensor: TensorDescription) -> TensorData {
|
||||
match self {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => runner.read_tensor(tensor).await,
|
||||
MultiRunnerClient2::RunnerClient2(runner) => runner.read_tensor(tensor).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
|
||||
match self {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => {
|
||||
let desc = runner.register_tensor_data_desc(data);
|
||||
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
|
||||
}
|
||||
MultiRunnerClient2::RunnerClient2(runner) => {
|
||||
let desc = runner.register_tensor_data_desc(data);
|
||||
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_empty_tensor(&self, shape: Vec<usize>, dtype: DType) -> RouterTensor<Self> {
|
||||
match self {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => {
|
||||
let desc = runner.register_empty_tensor_desc(shape, dtype);
|
||||
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
|
||||
}
|
||||
MultiRunnerClient2::RunnerClient2(runner) => {
|
||||
let desc = runner.register_empty_tensor_desc(shape, dtype);
|
||||
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_float_tensor(&self, shape: Vec<usize>, full_precision: bool) -> RouterTensor<Self> {
|
||||
match self {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => {
|
||||
let desc = runner.register_float_tensor_desc(shape, full_precision);
|
||||
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
|
||||
}
|
||||
MultiRunnerClient2::RunnerClient2(runner) => {
|
||||
let desc = runner.register_float_tensor_desc(shape, full_precision);
|
||||
RouterTensor::new(Arc::new(desc.id), desc.shape, desc.dtype, self.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn device(&self) -> Self::Device {
|
||||
match self {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => MultiDevice2::Device1(runner.device()),
|
||||
MultiRunnerClient2::RunnerClient2(runner) => MultiDevice2::Device2(runner.device()),
|
||||
}
|
||||
}
|
||||
|
||||
fn register_orphan(&self, id: &TensorId) {
|
||||
match self {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => runner.register_orphan(id),
|
||||
MultiRunnerClient2::RunnerClient2(runner) => runner.register_orphan(id),
|
||||
}
|
||||
}
|
||||
|
||||
fn sync(&self) {
|
||||
match self {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => runner.sync(),
|
||||
MultiRunnerClient2::RunnerClient2(runner) => runner.sync(),
|
||||
}
|
||||
}
|
||||
|
||||
fn seed(&self, seed: u64) {
|
||||
match self {
|
||||
MultiRunnerClient2::RunnerClient1(runner) => runner.seed(seed),
|
||||
MultiRunnerClient2::RunnerClient2(runner) => runner.seed(seed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: conflicting implementations because B1 and B2 cannot be differentiated (could be the same type)
|
||||
// impl<B1: ReprBackend, B2: ReprBackend> From<RouterTensor<Runner<B1>>>
|
||||
// for RouterTensor<MultiRunnerClient2<B1, B2>>
|
||||
// {
|
||||
// fn from(value: RouterTensor<Runner<B1>>) -> Self {
|
||||
// RouterTensor {
|
||||
// desc: value.desc,
|
||||
// client: MultiRunnerClient2::RunnerClient1(value.client),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// impl<B1: ReprBackend, B2: ReprBackend> From<RouterTensor<Runner<B2>>>
|
||||
// for RouterTensor<MultiRunnerClient2<B1, B2>>
|
||||
// {
|
||||
// fn from(value: RouterTensor<Runner<B2>>) -> Self {
|
||||
// RouterTensor {
|
||||
// desc: value.desc,
|
||||
// client: MultiRunnerClient2::RunnerClient2(value.client),
|
||||
// }
|
||||
// }
|
||||
// }
|
|
@ -0,0 +1,5 @@
|
|||
mod base;
|
||||
mod direct;
|
||||
|
||||
pub use base::*;
|
||||
pub use direct::*;
|
|
@ -0,0 +1,138 @@
|
|||
use alloc::{boxed::Box, vec::Vec};
|
||||
use core::{
|
||||
future::Future,
|
||||
ops::DerefMut,
|
||||
sync::atomic::{AtomicBool, AtomicU64, Ordering},
|
||||
};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
use spin::Mutex;
|
||||
|
||||
use burn_tensor::{
|
||||
backend::{DeviceId, DeviceOps},
|
||||
repr::{OperationDescription, TensorDescription, TensorId},
|
||||
DType, TensorData,
|
||||
};
|
||||
|
||||
use crate::{RouterTensor, RunnerChannel};
|
||||
|
||||
/// Type alias for `<R as RunnerChannel>::Client`.
|
||||
pub type Client<R> = <R as RunnerChannel>::Client;
|
||||
pub(crate) static CLIENTS: RunnerClientLocator = RunnerClientLocator::new();
|
||||
static SEED_SET: AtomicBool = AtomicBool::new(false);
|
||||
static SEED: AtomicU64 = AtomicU64::new(0);
|
||||
type Key = (core::any::TypeId, DeviceId);
|
||||
|
||||
/// Define how to interact with the runner.
|
||||
pub trait RunnerClient: Clone + Send + Sync + Sized {
|
||||
/// Device type.
|
||||
type Device: DeviceOps;
|
||||
|
||||
/// Register a new tensor operation to be executed by the (runner) server.
|
||||
fn register(&self, op: OperationDescription);
|
||||
/// Read the values contained by a tensor.
|
||||
fn read_tensor(&self, tensor: TensorDescription) -> impl Future<Output = TensorData> + Send;
|
||||
/// Create a new [RouterTensor] from the tensor data.
|
||||
fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self>;
|
||||
/// Create a new [RouterTensor] with no resources associated.
|
||||
fn register_empty_tensor(&self, shape: Vec<usize>, dtype: DType) -> RouterTensor<Self>;
|
||||
/// Create a new float [RouterTensor] with no resources associated.
|
||||
fn register_float_tensor(&self, shape: Vec<usize>, full_precision: bool) -> RouterTensor<Self>;
|
||||
/// Get the current device used by all operations handled by this client.
|
||||
fn device(&self) -> Self::Device;
|
||||
/// Drop the tensor with the given [tensor id](TensorId).
|
||||
fn register_orphan(&self, id: &TensorId);
|
||||
/// Sync the runner, ensure that all computations are finished.
|
||||
fn sync(&self);
|
||||
/// Seed the runner.
|
||||
fn seed(&self, seed: u64);
|
||||
}
|
||||
|
||||
pub(crate) struct RunnerClientLocator {
|
||||
clients: Mutex<Option<HashMap<Key, Box<dyn core::any::Any + Send>>>>,
|
||||
}
|
||||
|
||||
pub(crate) fn get_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
|
||||
CLIENTS.client::<R>(device)
|
||||
}
|
||||
|
||||
pub(crate) fn set_seed(seed: u64) {
|
||||
SEED_SET.store(true, Ordering::Relaxed);
|
||||
SEED.store(seed, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn get_seed() -> Option<u64> {
|
||||
if SEED_SET.load(Ordering::Relaxed) {
|
||||
Some(SEED.load(Ordering::Relaxed))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize a new client for the given device.
|
||||
///
|
||||
/// If a (global) seed was previously set, the client seed is set.
|
||||
fn new_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
|
||||
let client = R::init_client(device);
|
||||
if let Some(seed) = get_seed() {
|
||||
client.seed(seed)
|
||||
}
|
||||
client
|
||||
}
|
||||
|
||||
impl RunnerClientLocator {
|
||||
/// Create a new client locator.
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
clients: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the runner client for the given device.
|
||||
///
|
||||
/// If a client isn't already initialized, it is created.
|
||||
pub fn client<R: RunnerChannel + 'static>(&self, device: &R::Device) -> Client<R> {
|
||||
let device_id = device.id();
|
||||
let client_id = (core::any::TypeId::of::<R>(), device_id);
|
||||
let mut clients = self.clients.lock();
|
||||
|
||||
if clients.is_none() {
|
||||
let client = new_client::<R>(device);
|
||||
Self::register_inner::<R>(client_id, client, &mut clients);
|
||||
}
|
||||
|
||||
match clients.deref_mut() {
|
||||
Some(clients) => match clients.get(&client_id) {
|
||||
Some(client) => {
|
||||
let client: &Client<R> = client.downcast_ref().unwrap();
|
||||
client.clone()
|
||||
}
|
||||
None => {
|
||||
let client = new_client::<R>(device);
|
||||
let any = Box::new(client.clone());
|
||||
clients.insert(client_id, any);
|
||||
client
|
||||
}
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn register_inner<R: RunnerChannel + 'static>(
|
||||
key: Key,
|
||||
client: Client<R>,
|
||||
clients: &mut Option<HashMap<Key, Box<dyn core::any::Any + Send>>>,
|
||||
) {
|
||||
if clients.is_none() {
|
||||
*clients = Some(HashMap::new());
|
||||
}
|
||||
|
||||
if let Some(clients) = clients {
|
||||
if clients.contains_key(&key) {
|
||||
panic!("Client already created for device {:?}", key);
|
||||
}
|
||||
|
||||
clients.insert(key, Box::new(client));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
mod base;
|
||||
|
||||
pub use base::*;
|
|
@ -0,0 +1,50 @@
|
|||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
|
||||
|
||||
//! Burn multi-backend router.
|
||||
|
||||
mod backend;
|
||||
mod bridge;
|
||||
mod channel;
|
||||
mod client;
|
||||
mod ops;
|
||||
mod runner;
|
||||
mod tensor;
|
||||
|
||||
pub use backend::*;
|
||||
pub use bridge::*;
|
||||
pub use channel::*;
|
||||
pub use client::*;
|
||||
pub use runner::*;
|
||||
pub use tensor::*;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
#[cfg(test)]
mod tests {
    // `format` and `vec` are brought into scope for the generated test suites.
    use alloc::{format, vec};

    use crate::{BackendRouter, ByteBridge, DirectChannel};

    /// Direct channel whose bridge is the byte bridge over the same backends.
    type DirectByteChannel<Backends> = DirectChannel<Backends, ByteBridge<Backends>>;

    pub type TestBackend1 = burn_ndarray::NdArray<f32, i32>;
    pub type TestBackend2 = burn_wgpu::Wgpu<f32, i32>;
    pub type TestBackend = BackendRouter<DirectByteChannel<(TestBackend1, TestBackend2)>>;

    pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
    pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
    pub type TestTensorBool<const D: usize> =
        burn_tensor::Tensor<TestBackend, D, burn_tensor::Bool>;

    burn_tensor::testgen_all!();
    // TODO: add support for quantization
    // burn_tensor::testgen_quantization!();

    #[cfg(feature = "std")]
    burn_autodiff::testgen_all!();
}
|
|
@ -0,0 +1,55 @@
|
|||
/// Applies a binary operation to two float tensors and registers the float result.
///
/// Both operands are fetched from `$handles` using `$desc.lhs` and `$desc.rhs`, `$ops` is called
/// with them in that order, and the result is registered under `$desc.out.id`. A backend type
/// named `B` must be in scope where the macro is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_float_tensor::<B>(&$desc.lhs);
        let right = $handles.get_float_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);

        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies a binary comparison to two float tensors and registers the boolean result.
///
/// Both operands are fetched from `$handles` using `$desc.lhs` and `$desc.rhs`, `$ops` is called
/// with them in that order, and the result is registered as a bool tensor under `$desc.out.id`.
/// A backend type named `B` must be in scope where the macro is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_float_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_float_tensor::<B>(&$desc.lhs);
        let right = $handles.get_float_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);

        $handles.register_bool_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies a binary operation to two int tensors and registers the int result.
///
/// Both operands are fetched from `$handles` using `$desc.lhs` and `$desc.rhs`, `$ops` is called
/// with them in that order, and the result is registered under `$desc.out.id`. A backend type
/// named `B` must be in scope where the macro is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_int_tensor::<B>(&$desc.lhs);
        let right = $handles.get_int_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies a binary comparison to two int tensors and registers the boolean result.
///
/// Both operands are fetched from `$handles` using `$desc.lhs` and `$desc.rhs`, `$ops` is called
/// with them in that order, and the result is registered as a bool tensor under `$desc.out.id`.
/// A backend type named `B` must be in scope where the macro is invoked.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! binary_int_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let left = $handles.get_int_tensor::<B>(&$desc.lhs);
        let right = $handles.get_int_tensor::<B>(&$desc.rhs);
        let result = $ops(left, right);

        $handles.register_bool_tensor::<B>(&$desc.out.id, result);
    }};
}
|
|
@ -0,0 +1,8 @@
|
|||
mod binary;
|
||||
mod op_activation;
|
||||
mod op_bool;
|
||||
mod op_float;
|
||||
mod op_int;
|
||||
mod op_module;
|
||||
mod op_qfloat;
|
||||
mod unary;
|
|
@ -0,0 +1,4 @@
|
|||
use crate::{BackendRouter, RunnerChannel};
|
||||
use burn_tensor::ops::ActivationOps;
|
||||
|
||||
// No method is overridden: the router relies entirely on the implementations provided by
// the `ActivationOps` trait itself.
impl<R> ActivationOps<Self> for BackendRouter<R> where R: RunnerChannel {}
|
|
@ -0,0 +1,302 @@
|
|||
use alloc::vec::Vec;
|
||||
|
||||
use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntElem, IntTensor};
|
||||
use burn_tensor::repr::{
|
||||
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
|
||||
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
|
||||
OperationDescription, PermuteOperationDescription, RepeatDimOperationDescription,
|
||||
ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription,
|
||||
SwapDimsDescription, UnaryOperationDescription,
|
||||
};
|
||||
use burn_tensor::{DType, Device, Element, Shape, TensorData};
|
||||
|
||||
use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient};
|
||||
|
||||
impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
|
||||
fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
// Get the runtime client on which to register the operation for execution.
|
||||
let client = get_client::<R>(device);
|
||||
let out = client.register_empty_tensor(shape.into(), DType::Bool);
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::Empty(out.to_description_out()),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_shape(tensor: &BoolTensor<Self>) -> Shape {
|
||||
Shape::from(tensor.shape.clone())
|
||||
}
|
||||
|
||||
/// Consumes `tensor` and asynchronously reads back its content as [`TensorData`].
async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
    let read = tensor.into_data();
    read.await
}
|
||||
|
||||
fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
let client = get_client::<R>(device);
|
||||
client.register_tensor_data(data.convert::<bool>())
|
||||
}
|
||||
|
||||
fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let out = client.register_empty_tensor(tensor.shape.clone(), IntElem::<Self>::dtype());
|
||||
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Bool(
|
||||
BoolOperationDescription::IntoInt(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let out = client.register_empty_tensor(tensor.shape.clone(), FloatElem::<Self>::dtype());
|
||||
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Bool(
|
||||
BoolOperationDescription::IntoFloat(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Returns the device reported by the client that owns `tensor`.
fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
    let owner = &tensor.client;
    owner.device()
}
|
||||
|
||||
fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
|
||||
if &tensor.client.device() == device {
|
||||
return tensor;
|
||||
}
|
||||
R::change_client_backend(tensor, device)
|
||||
}
|
||||
|
||||
fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let out = client.register_empty_tensor(shape.into(), tensor.dtype);
|
||||
|
||||
let desc = ReshapeDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::Reshape(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_slice(
|
||||
tensor: BoolTensor<Self>,
|
||||
ranges: &[core::ops::Range<usize>],
|
||||
) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let ndims = tensor.shape().num_dims();
|
||||
let mut shape: Vec<usize> = ranges.iter().map(|range| range.end - range.start).collect();
|
||||
|
||||
for i in shape.len()..ndims {
|
||||
shape.push(tensor.shape[i]);
|
||||
}
|
||||
|
||||
let out = client.register_empty_tensor(shape, tensor.dtype);
|
||||
|
||||
let desc = SliceOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
ranges: ranges.to_vec(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::Slice(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_slice_assign(
|
||||
tensor: BoolTensor<Self>,
|
||||
ranges: &[core::ops::Range<usize>],
|
||||
value: BoolTensor<Self>,
|
||||
) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype);
|
||||
|
||||
let desc = SliceAssignOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
ranges: ranges.to_vec(),
|
||||
value: value.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::SliceAssign(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
let client = lhs.client.clone();
|
||||
let out = client.register_empty_tensor(lhs.shape.clone(), DType::Bool);
|
||||
|
||||
let desc = BinaryOperationDescription {
|
||||
lhs: lhs.into_description(),
|
||||
rhs: rhs.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::Equal(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype);
|
||||
|
||||
let desc = UnaryOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Bool(BoolOperationDescription::Not(
|
||||
desc,
|
||||
)));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim1] = tensor.shape[dim2];
|
||||
shape[dim2] = tensor.shape[dim1];
|
||||
let out = client.register_empty_tensor(shape, tensor.dtype);
|
||||
|
||||
let desc = SwapDimsDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
dim1,
|
||||
dim2,
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::SwapDims(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
// Change the shape of the tensor to match the new axes
|
||||
let shape = axes.iter().map(|x| tensor.shape[*x]).collect();
|
||||
let out = client.register_empty_tensor(shape, tensor.dtype);
|
||||
|
||||
let desc = PermuteOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
axes: axes.to_vec(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::Permute(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let out = client.register_empty_tensor(tensor.shape.clone(), tensor.dtype);
|
||||
|
||||
let desc = FlipOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
axes: axes.to_vec(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::Flip(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let shape: Vec<_> = shape.into();
|
||||
let out = client.register_empty_tensor(shape.clone(), tensor.dtype);
|
||||
|
||||
let desc = ExpandOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
shape,
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::Expand(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
|
||||
let tensor_first = tensors.first().unwrap();
|
||||
let client = tensor_first.client.clone();
|
||||
let dtype = tensor_first.dtype;
|
||||
|
||||
// Calculate the output shape
|
||||
let mut shape = tensor_first.shape.clone();
|
||||
shape[dim] = 0;
|
||||
for tensor in tensors.iter() {
|
||||
shape[dim] += tensor.shape[dim];
|
||||
}
|
||||
let out = client.register_empty_tensor(shape, dtype);
|
||||
|
||||
let desc = CatOperationDescription {
|
||||
tensors: tensors.into_iter().map(|t| t.into_description()).collect(),
|
||||
dim,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::Cat(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
|
||||
let client = tensor.client.clone();
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] *= times;
|
||||
let out = client.register_empty_tensor(shape, tensor.dtype);
|
||||
|
||||
let desc = RepeatDimOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
dim,
|
||||
times,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::BaseBool(
|
||||
BaseOperationDescription::RepeatDim(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,821 @@
|
|||
use alloc::{boxed::Box, vec};
|
||||
|
||||
use burn_tensor::ops::conv::{
|
||||
calculate_conv_output_size, calculate_conv_transpose_output_size, calculate_pool_output_size,
|
||||
};
|
||||
use burn_tensor::ops::{
|
||||
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
|
||||
IntElem, ModuleOps,
|
||||
};
|
||||
use burn_tensor::ops::{
|
||||
IntTensor, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
|
||||
MaxPool2dWithIndices,
|
||||
};
|
||||
use burn_tensor::repr::{
|
||||
AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription,
|
||||
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
|
||||
AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription,
|
||||
AvgPool2dDescription, Conv1dDescription, Conv2dDescription, Conv3dDescription,
|
||||
ConvTranspose1dDescription, ConvTranspose2dDescription, ConvTranspose3dDescription,
|
||||
DeformConv2dBackwardDescription, DeformConv2dDescription, InterpolateBackwardDescription,
|
||||
InterpolateDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
|
||||
MaxPool1dWithIndicesDescription, MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
|
||||
MaxPool2dWithIndicesDescription, ModuleOperationDescription, OperationDescription,
|
||||
};
|
||||
use burn_tensor::Element;
|
||||
|
||||
use crate::{BackendRouter, RunnerChannel, RunnerClient};
|
||||
|
||||
impl<R: RunnerChannel> ModuleOps<Self> for BackendRouter<R> {
|
||||
fn conv1d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
let size = calculate_conv_output_size(
|
||||
weight.shape[2],
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
x.shape[2],
|
||||
);
|
||||
|
||||
let shape = vec![x.shape[0], weight.shape[0], size];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = Conv1dDescription {
|
||||
x: x.into_description(),
|
||||
weight: weight.into_description(),
|
||||
bias: bias.map(|bias| bias.into_description()),
|
||||
options: options.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::Conv1d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
let size_0 = calculate_conv_output_size(
|
||||
weight.shape[2],
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
x.shape[2],
|
||||
);
|
||||
let size_1 = calculate_conv_output_size(
|
||||
weight.shape[3],
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.dilation[1],
|
||||
x.shape[3],
|
||||
);
|
||||
|
||||
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = Conv2dDescription {
|
||||
x: x.into_description(),
|
||||
weight: weight.into_description(),
|
||||
bias: bias.map(|bias| bias.into_description()),
|
||||
options: options.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::Conv2d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn conv3d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
let size_0 = calculate_conv_output_size(
|
||||
weight.shape[2],
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
x.shape[2],
|
||||
);
|
||||
let size_1 = calculate_conv_output_size(
|
||||
weight.shape[3],
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.dilation[1],
|
||||
x.shape[3],
|
||||
);
|
||||
let size_2 = calculate_conv_output_size(
|
||||
weight.shape[4],
|
||||
options.stride[2],
|
||||
options.padding[2],
|
||||
options.dilation[2],
|
||||
x.shape[4],
|
||||
);
|
||||
|
||||
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1, size_2];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = Conv3dDescription {
|
||||
x: x.into_description(),
|
||||
weight: weight.into_description(),
|
||||
bias: bias.map(|bias| bias.into_description()),
|
||||
options: options.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::Conv3d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<1>,
|
||||
) -> FloatTensor<Self> {
|
||||
let size = calculate_conv_transpose_output_size(
|
||||
weight.shape[2],
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.padding_out[0],
|
||||
options.dilation[0],
|
||||
x.shape[2],
|
||||
);
|
||||
|
||||
let shape = vec![x.shape[0], weight.shape[1] * options.groups, size];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = ConvTranspose1dDescription {
|
||||
x: x.into_description(),
|
||||
weight: weight.into_description(),
|
||||
bias: bias.map(|bias| bias.into_description()),
|
||||
options: options.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::ConvTranspose1d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
let size_0 = calculate_conv_transpose_output_size(
|
||||
weight.shape[2],
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.padding_out[0],
|
||||
options.dilation[0],
|
||||
x.shape[2],
|
||||
);
|
||||
let size_1 = calculate_conv_transpose_output_size(
|
||||
weight.shape[3],
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.padding_out[1],
|
||||
options.dilation[1],
|
||||
x.shape[3],
|
||||
);
|
||||
|
||||
let shape = vec![x.shape[0], weight.shape[1] * options.groups, size_0, size_1];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = ConvTranspose2dDescription {
|
||||
x: x.into_description(),
|
||||
weight: weight.into_description(),
|
||||
bias: bias.map(|bias| bias.into_description()),
|
||||
options: options.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::ConvTranspose2d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn conv_transpose3d(
|
||||
x: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: ConvTransposeOptions<3>,
|
||||
) -> FloatTensor<Self> {
|
||||
let size_0 = calculate_conv_transpose_output_size(
|
||||
weight.shape[2],
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.padding_out[0],
|
||||
options.dilation[0],
|
||||
x.shape[2],
|
||||
);
|
||||
let size_1 = calculate_conv_transpose_output_size(
|
||||
weight.shape[3],
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.padding_out[1],
|
||||
options.dilation[1],
|
||||
x.shape[3],
|
||||
);
|
||||
let size_2 = calculate_conv_transpose_output_size(
|
||||
weight.shape[4],
|
||||
options.stride[2],
|
||||
options.padding[2],
|
||||
options.padding_out[2],
|
||||
options.dilation[2],
|
||||
x.shape[4],
|
||||
);
|
||||
|
||||
let shape = vec![
|
||||
x.shape[0],
|
||||
weight.shape[1] * options.groups,
|
||||
size_0,
|
||||
size_1,
|
||||
size_2,
|
||||
];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = ConvTranspose3dDescription {
|
||||
x: x.into_description(),
|
||||
weight: weight.into_description(),
|
||||
bias: bias.map(|bias| bias.into_description()),
|
||||
options: options.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::ConvTranspose3d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn avg_pool1d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
count_include_pad: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
let size = calculate_pool_output_size(kernel_size, stride, padding, 1, x.shape[2]);
|
||||
|
||||
let shape = vec![x.shape[0], x.shape[1], size];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = AvgPool1dDescription {
|
||||
x: x.into_description(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::AvgPool1d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn avg_pool2d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
let size_0 =
|
||||
calculate_pool_output_size(kernel_size[0], stride[0], padding[0], 1, x.shape[2]);
|
||||
let size_1 =
|
||||
calculate_pool_output_size(kernel_size[1], stride[1], padding[1], 1, x.shape[3]);
|
||||
|
||||
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = AvgPool2dDescription {
|
||||
x: x.into_description(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::AvgPool2d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn avg_pool1d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
count_include_pad: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
|
||||
|
||||
let desc = AvgPool1dBackwardDescription {
|
||||
x: x.into_description(),
|
||||
grad: grad.into_description(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::AvgPool1dBackward(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn avg_pool2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
count_include_pad: bool,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
|
||||
|
||||
let desc = AvgPool2dBackwardDescription {
|
||||
x: x.into_description(),
|
||||
grad: grad.into_description(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::AvgPool2dBackward(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn max_pool1d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> FloatTensor<Self> {
|
||||
let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
|
||||
|
||||
let shape = vec![x.shape[0], x.shape[1], size];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = MaxPool1dDescription {
|
||||
x: x.into_description(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::MaxPool1d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> FloatTensor<Self> {
|
||||
let size_0 = calculate_pool_output_size(
|
||||
kernel_size[0],
|
||||
stride[0],
|
||||
padding[0],
|
||||
dilation[0],
|
||||
x.shape[2],
|
||||
);
|
||||
let size_1 = calculate_pool_output_size(
|
||||
kernel_size[1],
|
||||
stride[1],
|
||||
padding[1],
|
||||
dilation[1],
|
||||
x.shape[3],
|
||||
);
|
||||
|
||||
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = MaxPool2dDescription {
|
||||
x: x.into_description(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::MaxPool2d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn max_pool1d_with_indices(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
) -> MaxPool1dWithIndices<Self> {
|
||||
let size = calculate_pool_output_size(kernel_size, stride, padding, dilation, x.shape[2]);
|
||||
|
||||
let shape = vec![x.shape[0], x.shape[1], size];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape.clone(), x.dtype);
|
||||
let out_indices = client.register_empty_tensor(shape, IntElem::<Self>::dtype());
|
||||
|
||||
let desc = MaxPool1dWithIndicesDescription {
|
||||
x: x.into_description(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
out: out.to_description_out(),
|
||||
out_indices: out_indices.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::MaxPool1dWithIndices(desc),
|
||||
));
|
||||
|
||||
MaxPool1dWithIndices::new(out, out_indices)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
) -> MaxPool2dWithIndices<Self> {
|
||||
let size_0 = calculate_pool_output_size(
|
||||
kernel_size[0],
|
||||
stride[0],
|
||||
padding[0],
|
||||
dilation[0],
|
||||
x.shape[2],
|
||||
);
|
||||
let size_1 = calculate_pool_output_size(
|
||||
kernel_size[1],
|
||||
stride[1],
|
||||
padding[1],
|
||||
dilation[1],
|
||||
x.shape[3],
|
||||
);
|
||||
|
||||
let shape = vec![x.shape[0], x.shape[1], size_0, size_1];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape.clone(), x.dtype);
|
||||
let out_indices = client.register_empty_tensor(shape, IntElem::<Self>::dtype());
|
||||
|
||||
let desc = MaxPool2dWithIndicesDescription {
|
||||
x: x.into_description(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
out: out.to_description_out(),
|
||||
out_indices: out_indices.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::MaxPool2dWithIndices(desc),
|
||||
));
|
||||
|
||||
MaxPool2dWithIndices::new(out, out_indices)
|
||||
}
|
||||
|
||||
fn max_pool1d_with_indices_backward(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
dilation: usize,
|
||||
output_grad: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> MaxPool1dBackward<Self> {
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
|
||||
|
||||
let desc = MaxPool1dWithIndicesBackwardDescription {
|
||||
x: x.into_description(),
|
||||
grad: output_grad.into_description(),
|
||||
indices: indices.into_description(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::MaxPool1dWithIndicesBackward(desc),
|
||||
));
|
||||
|
||||
MaxPool1dBackward::new(out)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: FloatTensor<Self>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
dilation: [usize; 2],
|
||||
output_grad: FloatTensor<Self>,
|
||||
indices: IntTensor<Self>,
|
||||
) -> MaxPool2dBackward<Self> {
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
|
||||
|
||||
let desc = MaxPool2dWithIndicesBackwardDescription {
|
||||
x: x.into_description(),
|
||||
grad: output_grad.into_description(),
|
||||
indices: indices.into_description(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::MaxPool2dWithIndicesBackward(desc),
|
||||
));
|
||||
|
||||
MaxPool2dBackward::new(out)
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {
|
||||
let shape = vec![x.shape[0], x.shape[1], output_size];
|
||||
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape.clone(), x.dtype);
|
||||
|
||||
let desc = AdaptiveAvgPool1dDescription {
|
||||
x: x.into_description(),
|
||||
output_size,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::AdaptiveAvgPool1d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
|
||||
let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
|
||||
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape.clone(), x.dtype);
|
||||
|
||||
let desc = AdaptiveAvgPool2dDescription {
|
||||
x: x.into_description(),
|
||||
output_size,
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::AdaptiveAvgPool2d(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool1d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
|
||||
|
||||
let desc = AdaptiveAvgPool1dBackwardDescription {
|
||||
x: x.into_description(),
|
||||
grad: grad.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn adaptive_avg_pool2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
|
||||
|
||||
let desc = AdaptiveAvgPool2dBackwardDescription {
|
||||
x: x.into_description(),
|
||||
grad: grad.into_description(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn interpolate(
|
||||
x: FloatTensor<Self>,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
|
||||
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape.clone(), x.dtype);
|
||||
|
||||
let desc = InterpolateDescription {
|
||||
x: x.into_description(),
|
||||
output_size,
|
||||
options: options.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::Interpolate(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn interpolate_backward(
|
||||
x: FloatTensor<Self>,
|
||||
grad: FloatTensor<Self>,
|
||||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> FloatTensor<Self> {
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(x.shape.clone(), x.dtype);
|
||||
|
||||
let desc = InterpolateBackwardDescription {
|
||||
x: x.into_description(),
|
||||
grad: grad.into_description(),
|
||||
output_size,
|
||||
options: options.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::InterpolateBackward(desc),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
x: FloatTensor<Self>,
|
||||
offset: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
mask: Option<FloatTensor<Self>>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> FloatTensor<Self> {
|
||||
let size_0 = calculate_conv_output_size(
|
||||
weight.shape[2],
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
x.shape[2],
|
||||
);
|
||||
let size_1 = calculate_conv_output_size(
|
||||
weight.shape[3],
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.dilation[1],
|
||||
x.shape[3],
|
||||
);
|
||||
|
||||
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
|
||||
let client = x.client.clone();
|
||||
let out = client.register_empty_tensor(shape, x.dtype);
|
||||
|
||||
let desc = DeformConv2dDescription {
|
||||
x: x.into_description(),
|
||||
offset: offset.into_description(),
|
||||
weight: weight.into_description(),
|
||||
mask: mask.map(|mask| mask.into_description()),
|
||||
bias: bias.map(|bias| bias.into_description()),
|
||||
options: options.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::DeformableConv2d(Box::new(desc)),
|
||||
));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn deform_conv2d_backward(
|
||||
x: FloatTensor<Self>,
|
||||
offset: FloatTensor<Self>,
|
||||
weight: FloatTensor<Self>,
|
||||
mask: Option<FloatTensor<Self>>,
|
||||
bias: Option<FloatTensor<Self>>,
|
||||
output_grad: FloatTensor<Self>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<Self> {
|
||||
let client = x.client.clone();
|
||||
|
||||
let input_grad = client.register_empty_tensor(x.shape.clone(), x.dtype);
|
||||
let offset_grad = client.register_empty_tensor(offset.shape.clone(), offset.dtype);
|
||||
let weight_grad = client.register_empty_tensor(weight.shape.clone(), weight.dtype);
|
||||
let mask_grad = mask
|
||||
.as_ref()
|
||||
.map(|mask| client.register_empty_tensor(mask.shape.clone(), mask.dtype));
|
||||
let bias_grad = bias
|
||||
.as_ref()
|
||||
.map(|bias| client.register_empty_tensor(bias.shape.clone(), bias.dtype));
|
||||
|
||||
let desc = DeformConv2dBackwardDescription {
|
||||
x: x.into_description(),
|
||||
offset: offset.into_description(),
|
||||
weight: weight.into_description(),
|
||||
mask: mask.map(|mask| mask.into_description()),
|
||||
bias: bias.map(|bias| bias.into_description()),
|
||||
options: options.into(),
|
||||
out_grad: output_grad.into_description(),
|
||||
input_grad: input_grad.to_description_out(),
|
||||
offset_grad: offset_grad.to_description_out(),
|
||||
weight_grad: weight_grad.to_description_out(),
|
||||
mask_grad: mask_grad
|
||||
.as_ref()
|
||||
.map(|mask_grad| mask_grad.to_description_out()),
|
||||
bias_grad: bias_grad
|
||||
.as_ref()
|
||||
.map(|bias_grad| bias_grad.to_description_out()),
|
||||
};
|
||||
|
||||
client.register(OperationDescription::Module(
|
||||
ModuleOperationDescription::DeformableConv2dBackward(Box::new(desc)),
|
||||
));
|
||||
|
||||
DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
use core::ops::Range;
|
||||
|
||||
use burn_tensor::{
|
||||
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
|
||||
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
|
||||
Device, Shape, TensorData,
|
||||
};
|
||||
|
||||
use crate::{BackendRouter, RunnerChannel};
|
||||
|
||||
// Quantized tensor operations for the router backend.
//
// Every method in this impl is a stub: each body is `unimplemented!()`, so calling any of
// them panics. Arguments are prefixed with `_` because none of them is read.
impl<R: RunnerChannel> QTensorOps<Self> for BackendRouter<R> {
    // Stub: panics with `unimplemented!()`.
    fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn quantize(
        _tensor: FloatTensor<Self>,
        _scheme: &QuantizationScheme,
        _qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn quantize_dynamic(
        _tensor: FloatTensor<Self>,
        _scheme: &QuantizationScheme,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_shape(_tensor: &QuantizedTensor<Self>) -> Shape {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_device(_tensor: &QuantizedTensor<Self>) -> Device<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_to_device(
        _tensor: QuantizedTensor<Self>,
        _device: &Device<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()` once the returned future is polled.
    async fn q_into_data(_tensor: QuantizedTensor<Self>) -> TensorData {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_swap_dims(
        _tensor: QuantizedTensor<Self>,
        _dim1: usize,
        _dim2: usize,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_gather(
        _dim: usize,
        _tensor: QuantizedTensor<Self>,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_select(
        _tensor: QuantizedTensor<Self>,
        _dim: usize,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_slice(_tensor: QuantizedTensor<Self>, _ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    // Stub: panics with `unimplemented!()`.
    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!()
    }
}
|
|
@ -0,0 +1,129 @@
|
|||
/// Applies a float tensor/scalar operation and registers the float result.
///
/// Reads `$desc.lhs` as a float tensor from `$handles`, converts `$desc.rhs` with
/// `ElementConversion::elem`, calls `$ops` and stores the output under `$desc.out.id`.
/// The expansion refers to a backend type `B` and to `ElementConversion`, neither of
/// which is defined by the macro itself.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, ElementConversion::elem($desc.rhs));

        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies a float tensor operation that takes `$desc.rhs` unchanged and registers the
/// float result.
///
/// Reads `$desc.lhs` as a float tensor from `$handles`, calls `$ops` with it and
/// `$desc.rhs`, then stores the output under `$desc.out.id`. The expansion refers to a
/// backend type `B` that the macro itself does not define.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs);

        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies a float tensor operation whose output is registered as an int tensor.
///
/// Reads `$desc.lhs` as a float tensor from `$handles`, calls `$ops` with it and
/// `$desc.rhs`, then stores the output as an int tensor under `$desc.out.id`. The
/// expansion refers to a backend type `B` that the macro itself does not define.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float2int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs);

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies a float tensor/scalar comparison and registers the bool result.
///
/// Reads `$desc.lhs` as a float tensor from `$handles`, converts `$desc.rhs` with
/// `ElementConversion::elem`, calls `$ops` and stores the output as a bool tensor under
/// `$desc.out.id`. The expansion refers to a backend type `B` and to `ElementConversion`,
/// neither of which is defined by the macro itself.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_float_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, ElementConversion::elem($desc.rhs));

        $handles.register_bool_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies a unary float tensor operation and registers the float result.
///
/// Reads `$desc.input` as a float tensor from `$handles`, calls `$ops` with it and stores
/// the output under `$desc.out.id`. The expansion refers to a backend type `B` that the
/// macro itself does not define.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! unary_float_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_float_tensor::<B>(&$desc.input);
        let result = $ops(tensor);

        $handles.register_float_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies an int tensor/scalar operation and registers the int result.
///
/// Reads `$desc.lhs` as an int tensor from `$handles`, converts `$desc.rhs` with
/// `ElementConversion::elem`, calls `$ops` and stores the output under `$desc.out.id`.
/// The expansion refers to a backend type `B` and to `ElementConversion`, neither of
/// which is defined by the macro itself.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, ElementConversion::elem($desc.rhs));

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies an operation on an int tensor and a converted scalar, registering an int result.
///
/// Reads `$desc.lhs` as an int tensor from `$handles`, converts `$desc.rhs` with
/// `ElementConversion::elem`, calls `$ops` and stores the output as an int tensor under
/// `$desc.out.id`. The expansion refers to a backend type `B` and to `ElementConversion`,
/// neither of which is defined by the macro itself.
// NOTE(review): the expansion matches `scalar_int_ops`; presumably the name reflects a
// float scalar operand at the call sites — confirm against callers.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! int_float_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, ElementConversion::elem($desc.rhs));

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies an int tensor operation that takes `$desc.rhs` unchanged and registers the int
/// result.
///
/// Reads `$desc.lhs` as an int tensor from `$handles`, calls `$ops` with it and
/// `$desc.rhs`, then stores the output under `$desc.out.id`. The expansion refers to a
/// backend type `B` that the macro itself does not define.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_dim_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, $desc.rhs);

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies an int tensor/scalar comparison and registers the bool result.
///
/// Reads `$desc.lhs` as an int tensor from `$handles`, converts `$desc.rhs` with
/// `ElementConversion::elem`, calls `$ops` and stores the output as a bool tensor under
/// `$desc.out.id`. The expansion refers to a backend type `B` and to `ElementConversion`,
/// neither of which is defined by the macro itself.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! scalar_int_cmp_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.lhs);
        let result = $ops(tensor, ElementConversion::elem($desc.rhs));

        $handles.register_bool_tensor::<B>(&$desc.out.id, result);
    }};
}
|
||||
|
||||
/// Applies a unary int tensor operation and registers the int result.
///
/// Reads `$desc.input` as an int tensor from `$handles`, calls `$ops` with it and stores
/// the output under `$desc.out.id`. The expansion refers to a backend type `B` that the
/// macro itself does not define.
#[allow(missing_docs)]
#[macro_export(local_inner_macros)]
macro_rules! unary_int_ops {
    ($handles:expr, $desc:expr, $ops:expr) => {{
        let tensor = $handles.get_int_tensor::<B>(&$desc.input);
        let result = $ops(tensor);

        $handles.register_int_tensor::<B>(&$desc.out.id, result);
    }};
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,122 @@
|
|||
use alloc::{sync::Arc, vec::Vec};
|
||||
|
||||
use super::RunnerClient;
|
||||
use burn_tensor::{
|
||||
repr::{TensorDescription, TensorId, TensorStatus},
|
||||
DType, Shape, TensorData,
|
||||
};
|
||||
|
||||
/// Tensor primitive for the [router backend](crate::BackendRouter).
///
/// It carries the tensor's identifier, shape and dtype together with the runner client
/// used to register operations and read data.
pub struct RouterTensor<C: RunnerClient> {
    // Shared identifier. The `Arc` strong count is what `status()` inspects to decide
    // between `ReadWrite` and `ReadOnly`.
    pub(crate) id: Arc<TensorId>,
    // Dimensions of the tensor.
    pub(crate) shape: Vec<usize>,
    // Element data type.
    pub(crate) dtype: DType,
    // Runner client this tensor is attached to.
    pub(crate) client: C,

    // Orphan means that a tensor is never converted into a description when it becomes `ReadWrite`.
    //
    // When a tensor is dropped and is still an orphan, we need to register it as such to avoid
    // memory leak.
    pub(crate) is_orphan: bool,
}
|
||||
|
||||
impl<C: RunnerClient> RouterTensor<C> {
|
||||
pub(crate) fn new(id: Arc<TensorId>, shape: Vec<usize>, dtype: DType, client: C) -> Self {
|
||||
Self {
|
||||
id,
|
||||
shape,
|
||||
dtype,
|
||||
client,
|
||||
is_orphan: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn into_data(self) -> TensorData {
|
||||
self.client
|
||||
.clone()
|
||||
.read_tensor(self.into_description())
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) fn into_description(mut self) -> TensorDescription {
|
||||
let status = self.status();
|
||||
let mut shape_out = Vec::new();
|
||||
core::mem::swap(&mut self.shape, &mut shape_out);
|
||||
|
||||
if let TensorStatus::ReadWrite = status {
|
||||
self.is_orphan = false;
|
||||
}
|
||||
|
||||
TensorDescription {
|
||||
status,
|
||||
shape: shape_out,
|
||||
id: *self.id.as_ref(),
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn to_description_out(&self) -> TensorDescription {
|
||||
TensorDescription {
|
||||
status: TensorStatus::NotInit,
|
||||
shape: self.shape.clone(),
|
||||
id: *self.id.as_ref(),
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn shape(&self) -> Shape {
|
||||
Shape::from(self.shape.clone())
|
||||
}
|
||||
|
||||
pub(crate) fn status(&self) -> TensorStatus {
|
||||
if Arc::strong_count(&self.id) <= 1 {
|
||||
TensorStatus::ReadWrite
|
||||
} else {
|
||||
TensorStatus::ReadOnly
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {
    /// Formats the id, shape, dtype, orphan flag and client device of the tensor.
    ///
    /// The output is written straight into the formatter instead of going through an
    /// intermediate `String`; the emitted text is unchanged.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The `should_drop` label prints the `is_orphan` flag.
        write!(
            f,
            "{{ id: {:?}, shape: {:?}, dtype: {:?}, should_drop: {:?}, device: {:?} }}",
            self.id,
            self.shape,
            self.dtype,
            self.is_orphan,
            self.client.device(),
        )
    }
}
|
||||
|
||||
impl<C: RunnerClient> Clone for RouterTensor<C> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
id: self.id.clone(),
|
||||
shape: self.shape.clone(),
|
||||
client: self.client.clone(),
|
||||
dtype: self.dtype,
|
||||
is_orphan: self.is_orphan,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: RunnerClient> Drop for RouterTensor<C> {
|
||||
fn drop(&mut self) {
|
||||
if !self.is_orphan {
|
||||
return;
|
||||
}
|
||||
|
||||
match self.status() {
|
||||
TensorStatus::ReadWrite => {
|
||||
self.client.register_orphan(&self.id);
|
||||
}
|
||||
TensorStatus::ReadOnly => {}
|
||||
TensorStatus::NotInit => {}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -49,6 +49,9 @@ hashbrown = { workspace = true } # no_std compatible
|
|||
serde = { workspace = true }
|
||||
serde_bytes = { workspace = true }
|
||||
|
||||
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
|
||||
portable-atomic-util = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std
|
||||
|
||||
|
|
|
@ -4,16 +4,27 @@ use crate::{
|
|||
quantization::QuantizationScheme,
|
||||
Shape,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// A tensor representation containing a reference to a tensor resource with a given shape.
|
||||
pub struct TensorHandle<H> {
|
||||
#[derive(Clone)]
|
||||
pub struct TensorHandle<H: Clone> {
|
||||
/// The type that can be used to point to a tensor of any kind.
|
||||
pub handle: H,
|
||||
/// The shape associated to the tensor.
|
||||
pub shape: Shape,
|
||||
}
|
||||
|
||||
/// A simple struct to encapsulate a quantized tensor kind.
///
/// The same type `T` is used for the quantized values and for the quantization parameters.
/// `Clone` is only required of `T` by the derived `Clone` implementation, not by the
/// struct itself.
#[derive(Clone)]
pub struct QuantizedKind<T> {
    /// The quantized tensor.
    pub tensor: T,
    /// The scaling factor.
    pub scale: T,
    /// The zero-point offset.
    pub offset: Option<T>,
}
|
||||
|
||||
/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation
|
||||
/// for compilation purpose or other...
|
||||
pub trait ReprBackend: Backend {
|
||||
|
@ -28,7 +39,7 @@ pub trait ReprBackend: Backend {
|
|||
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
|
||||
/// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive).
|
||||
fn quantized_tensor(
|
||||
handles: Vec<TensorHandle<Self::Handle>>,
|
||||
handle: QuantizedKind<TensorHandle<Self::Handle>>,
|
||||
scheme: QuantizationScheme,
|
||||
) -> QuantizedTensor<Self>;
|
||||
|
||||
|
@ -40,5 +51,33 @@ pub trait ReprBackend: Backend {
|
|||
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;
|
||||
/// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle).
|
||||
/// A quantized tensor has multiple handles for the tensor itself and the quantization parameters.
|
||||
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Vec<Self::Handle>;
|
||||
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle>;
|
||||
}
|
||||
|
||||
/// Handle which points to a backend tensor primitive kind.
///
/// Each non-empty variant wraps the corresponding tensor primitive of backend `B`.
#[derive(Clone, Debug)]
pub enum HandleKind<B: Backend> {
    /// Float tensor handle.
    Float(B::FloatTensorPrimitive),
    /// Int tensor handle.
    Int(B::IntTensorPrimitive),
    /// Bool tensor handle.
    Bool(B::BoolTensorPrimitive),
    /// Quantized tensor handle.
    Quantized(B::QuantizedTensorPrimitive),
    /// Empty handle (used as a dummy representation).
    Empty,
}
|
||||
|
||||
impl<B: Backend> HandleKind<B> {
    /// Returns the lowercase name of the handle kind: `"float"`, `"int"`, `"bool"` or
    /// `"quantized"`.
    ///
    /// # Panics
    ///
    /// Panics when called on [`HandleKind::Empty`].
    pub fn name(&self) -> &str {
        match self {
            Self::Float(_) => "float",
            Self::Int(_) => "int",
            Self::Bool(_) => "bool",
            Self::Quantized(_) => "quantized",
            // The empty handle is treated as a state that should never be named.
            Self::Empty => unreachable!(),
        }
    }
}
|
||||
|
|
|
@ -5,9 +5,16 @@ use crate::{
|
|||
},
|
||||
Shape,
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use alloc::vec::Vec;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
use super::{QuantizedTensorDescription, TensorHandle};
|
||||
#[cfg(target_has_atomic = "ptr")]
|
||||
use alloc::sync::Arc;
|
||||
|
||||
#[cfg(not(target_has_atomic = "ptr"))]
|
||||
use portable_atomic_util::Arc;
|
||||
|
||||
use super::{QuantizedKind, QuantizedTensorDescription, TensorHandle};
|
||||
|
||||
/// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources
|
||||
/// are used optimally.
|
||||
|
@ -19,6 +26,16 @@ pub struct HandleContainer<H> {
|
|||
pub handles_orphan: Vec<TensorId>,
|
||||
}
|
||||
|
||||
impl<H> core::fmt::Debug for HandleContainer<H> {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.debug_struct("HandleContainer")
|
||||
.field("handles", &self.handles.keys()) // only care about the IDs when debugging
|
||||
.field("counter", &self.counter)
|
||||
.field("handles_orphan", &self.handles_orphan)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Backend [tensor handle](ReprBackend::Handle) wrapper tracking their creation state
|
||||
pub enum Handle<H> {
|
||||
/// No [tensor handle](ReprBackend::Handle) has been created yet
|
||||
|
@ -69,7 +86,7 @@ impl<H: Clone> HandleContainer<H> {
|
|||
}
|
||||
|
||||
/// Get the tensor handle for the given [tensor description](TensorDescription).
|
||||
fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle<H> {
|
||||
pub fn get_tensor_handle(&mut self, tensor: &TensorDescription) -> TensorHandle<H> {
|
||||
TensorHandle {
|
||||
handle: self.get_handle(&tensor.id, &tensor.status),
|
||||
shape: Shape::from(&tensor.shape),
|
||||
|
@ -112,12 +129,14 @@ impl<H: Clone> HandleContainer<H> {
|
|||
where
|
||||
B: ReprBackend<Handle = H>,
|
||||
{
|
||||
let qtensor = self.get_tensor_handle(&tensor.tensor);
|
||||
let scale = self.get_tensor_handle(&tensor.qparams.scale);
|
||||
let handles = if let Some(offset) = &tensor.qparams.offset {
|
||||
vec![qtensor, scale, self.get_tensor_handle(offset)]
|
||||
} else {
|
||||
vec![qtensor, scale]
|
||||
let handles = QuantizedKind {
|
||||
tensor: self.get_tensor_handle(&tensor.tensor),
|
||||
scale: self.get_tensor_handle(&tensor.qparams.scale),
|
||||
offset: tensor
|
||||
.qparams
|
||||
.offset
|
||||
.as_ref()
|
||||
.map(|offset| self.get_tensor_handle(offset)),
|
||||
};
|
||||
B::quantized_tensor(handles, tensor.scheme.clone())
|
||||
}
|
||||
|
@ -134,20 +153,20 @@ impl<H: Clone> HandleContainer<H> {
|
|||
/// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).
|
||||
pub fn register_quantized_tensor<B>(
|
||||
&mut self,
|
||||
ids: &[&TensorId],
|
||||
id: &QuantizedKind<TensorId>,
|
||||
tensor: B::QuantizedTensorPrimitive,
|
||||
) where
|
||||
B: ReprBackend<Handle = H>,
|
||||
{
|
||||
let handles = B::quantized_tensor_handle(tensor);
|
||||
assert_eq!(
|
||||
ids.len(),
|
||||
handles.len(),
|
||||
"Number of tensor ids and handles must match"
|
||||
);
|
||||
|
||||
for (handle, id) in handles.into_iter().zip(ids) {
|
||||
self.handles.insert(**id, Handle::Existing(handle));
|
||||
self.handles
|
||||
.insert(id.tensor, Handle::Existing(handles.tensor));
|
||||
self.handles
|
||||
.insert(id.scale, Handle::Existing(handles.scale));
|
||||
|
||||
if let (Some(id), Some(handle)) = (id.offset, handles.offset) {
|
||||
self.handles.insert(id, Handle::Existing(handle));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
use core::ops::Range;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::Range;
|
||||
|
||||
use alloc::boxed::Box;
|
||||
use alloc::{vec, vec::Vec};
|
||||
|
||||
use crate::{
|
||||
ops::{
|
||||
|
@ -214,6 +217,13 @@ pub enum BaseOperationDescription {
|
|||
Cat(CatOperationDescription),
|
||||
/// Cast operation, no direct operation and should be supported by fusion backend.
|
||||
Cast(UnaryOperationDescription),
|
||||
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [empty](crate::ops::FloatTensorOps::float_empty).
|
||||
/// Int => [empty](crate::ops::IntTensorOps::int_empty).
|
||||
/// Bool => [empty](crate::ops::BoolTensorOps::bool_empty).
|
||||
Empty(TensorDescription),
|
||||
}
|
||||
|
||||
/// Numeric operations on int and float tensors.
|
||||
|
@ -1286,6 +1296,7 @@ impl BaseOperationDescription {
|
|||
}
|
||||
BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(),
|
||||
BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out],
|
||||
BaseOperationDescription::Empty(desc) => vec![desc],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1605,7 +1616,7 @@ impl ModuleOperationDescription {
|
|||
}
|
||||
|
||||
impl core::hash::Hash for RandomOperationDescription {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
self.out.hash(state);
|
||||
|
||||
match self.distribution {
|
||||
|
@ -1618,14 +1629,14 @@ impl core::hash::Hash for RandomOperationDescription {
|
|||
}
|
||||
|
||||
impl<E> core::hash::Hash for ScalarOperationDescription<E> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
self.lhs.hash(state);
|
||||
self.out.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> core::hash::Hash for MaskFillOperationDescription<E> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
self.tensor.hash(state);
|
||||
self.mask.hash(state);
|
||||
self.out.hash(state);
|
||||
|
@ -1633,14 +1644,14 @@ impl<E> core::hash::Hash for MaskFillOperationDescription<E> {
|
|||
}
|
||||
|
||||
impl<E> core::hash::Hash for ClampOperationDescription<E> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
self.tensor.hash(state);
|
||||
self.out.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> core::hash::Hash for NumericOperationDescription<E> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
NumericOperationDescription::Add(desc) => desc.hash(state),
|
||||
NumericOperationDescription::AddScalar(desc) => desc.hash(state),
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use crate::DType;
|
||||
|
||||
/// The tensor unique identifier.
|
||||
|
|
|
@ -59,8 +59,8 @@ pub trait Backend:
|
|||
+ ActivationOps<Self>
|
||||
+ QTensorOps<Self>
|
||||
+ Clone
|
||||
+ Sized
|
||||
+ Default
|
||||
+ Sized
|
||||
+ Send
|
||||
+ Sync
|
||||
+ core::fmt::Debug
|
||||
|
|
|
@ -258,3 +258,19 @@ pub enum DType {
|
|||
Bool,
|
||||
QFloat(QuantizationStrategy),
|
||||
}
|
||||
|
||||
impl DType {
|
||||
/// Returns true if the data type is a floating point type.
|
||||
pub fn is_float(&self) -> bool {
|
||||
matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)
|
||||
}
|
||||
/// Returns true if the data type is a signed integer type.
|
||||
pub fn is_int(&self) -> bool {
|
||||
matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
|
||||
}
|
||||
|
||||
/// Returns true if the data type is a boolean type
|
||||
pub fn is_bool(&self) -> bool {
|
||||
matches!(self, DType::Bool)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
use alloc::vec::Vec;
|
||||
|
||||
/// Computes the output shape for binary operations with broadcasting support.
///
/// Each output dimension is the larger of the two corresponding input dimensions.
/// Dimensions are paired positionally, so the result has as many entries as the shorter
/// of the two slices.
pub fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    lhs.iter()
        .zip(rhs)
        .map(|(&l, &r)| usize::max(l, r))
        .collect()
}
|
|
@ -1,5 +1,6 @@
|
|||
mod activation;
|
||||
mod alias;
|
||||
mod binary;
|
||||
mod bool_tensor;
|
||||
mod int_tensor;
|
||||
mod modules;
|
||||
|
@ -8,6 +9,7 @@ mod tensor;
|
|||
|
||||
pub use activation::*;
|
||||
pub use alias::*;
|
||||
pub use binary::*;
|
||||
pub use bool_tensor::*;
|
||||
pub use int_tensor::*;
|
||||
pub use modules::*;
|
||||
|
|
Loading…
Reference in New Issue