From 806d4a844baea13ca8947f36ce42945478e318b7 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Fri, 3 May 2024 13:35:27 -0400 Subject: [PATCH] Draft --- crates/burn-tensor/src/lib.rs | 1 + crates/burn-tensor/src/multi/backend.rs | 43 +++++++++++++++++++++++++ crates/burn-tensor/src/multi/mod.rs | 3 ++ 3 files changed, 47 insertions(+) create mode 100644 crates/burn-tensor/src/multi/backend.rs create mode 100644 crates/burn-tensor/src/multi/mod.rs diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs index 18156030b..1026d84af 100644 --- a/crates/burn-tensor/src/lib.rs +++ b/crates/burn-tensor/src/lib.rs @@ -9,6 +9,7 @@ extern crate derive_new; extern crate alloc; +mod multi; mod tensor; /// Burn Tensor representaton diff --git a/crates/burn-tensor/src/multi/backend.rs b/crates/burn-tensor/src/multi/backend.rs new file mode 100644 index 000000000..52535e00a --- /dev/null +++ b/crates/burn-tensor/src/multi/backend.rs @@ -0,0 +1,43 @@ +use crate::{ + backend::{Backend, BackendBridge}, + ops::FloatTensor, + DType, +}; + +pub enum MultiPrecisionFloatTensor { + F16(FloatTensor), + F32(FloatTensor), + F64(FloatTensor), +} + +pub trait MultiPrecisionBackend: Sized { + type Bridge: BackendBridge; + + type F32Backend: Backend; + type F64Backend: Backend; + type F16Backend: Backend; + + fn cast( + tensor: MultiPrecisionFloatTensor, + dtype: DType, + ) -> MultiPrecisionFloatTensor { + match tensor { + MultiPrecisionFloatTensor::F16(tensor) => match dtype { + DType::F16 => todo!(), + DType::F32 => { + MultiPrecisionFloatTensor::F32( + >::into_target( + tensor, None, + ), + ) + } + DType::F64 => todo!(), + _ => panic!("Unsupported."), + }, + MultiPrecisionFloatTensor::F32(_) => todo!(), + MultiPrecisionFloatTensor::F64(_) => todo!(), + } + } +} + +type Bridge = ::Bridge; diff --git a/crates/burn-tensor/src/multi/mod.rs b/crates/burn-tensor/src/multi/mod.rs new file mode 100644 index 000000000..4e63e7bb0 --- /dev/null +++ b/crates/burn-tensor/src/multi/mod.rs @@ -0,0 +1,3 @@ +mod backend; + +pub use backend::*;