This commit is contained in:
nathaniel 2024-05-03 13:35:27 -04:00
parent a8661a2f53
commit 806d4a844b
3 changed files with 47 additions and 0 deletions

View File

@ -9,6 +9,7 @@ extern crate derive_new;
extern crate alloc;
mod multi;
mod tensor;
/// Burn Tensor representaton

View File

@ -0,0 +1,43 @@
use crate::{
backend::{Backend, BackendBridge},
ops::FloatTensor,
DType,
};
pub enum MultiPrecisionFloatTensor<B: MultiPrecisionBackend, const D: usize> {
F16(FloatTensor<B::F16Backend, D>),
F32(FloatTensor<B::F32Backend, D>),
F64(FloatTensor<B::F64Backend, D>),
}
pub trait MultiPrecisionBackend: Sized {
type Bridge<O: Backend, T: Backend>: BackendBridge<O, Target = T>;
type F32Backend: Backend;
type F64Backend: Backend;
type F16Backend: Backend;
fn cast(
tensor: MultiPrecisionFloatTensor<Self, 2>,
dtype: DType,
) -> MultiPrecisionFloatTensor<Self, 2> {
match tensor {
MultiPrecisionFloatTensor::F16(tensor) => match dtype {
DType::F16 => todo!(),
DType::F32 => {
MultiPrecisionFloatTensor::F32(
<Bridge<Self, Self::F16Backend, Self::F32Backend>>::into_target(
tensor, None,
),
)
}
DType::F64 => todo!(),
_ => panic!("Unsupported."),
},
MultiPrecisionFloatTensor::F32(_) => todo!(),
MultiPrecisionFloatTensor::F64(_) => todo!(),
}
}
}
type Bridge<B, O, T> = <B as MultiPrecisionBackend>::Bridge<O, T>;

View File

@ -0,0 +1,3 @@
mod backend;
pub use backend::*;