This commit is contained in:
nathaniel 2024-03-02 12:27:43 -05:00
parent 613a98094e
commit 623948d8e3
5 changed files with 151 additions and 71 deletions

View File

@ -5,7 +5,7 @@ use crate::{
tensor::AutodiffTensor,
};
use burn_tensor::{
backend::{AutodiffBackend, Backend, BackendMovement},
backend::{AutodiffBackend, Backend, BackendBridge, BackendPrecisionSettings},
Element,
};
use core::marker::PhantomData;
@ -20,20 +20,29 @@ pub struct Autodiff<B, C = NoCheckpointing> {
_checkpoint_strategy: PhantomData<C>,
}
impl<B, C, TF, TI> BackendMovement<Self, TF, TI> for Autodiff<B, C>
impl<B, C, TF: Element, TI: Element> BackendBridge<BackendPrecisionSettings<TF, TI>>
for Autodiff<B, C>
where
B: Backend + BackendMovement<B, TF, TI>,
B: Backend + BackendBridge<BackendPrecisionSettings<TF, TI>>,
C: CheckpointStrategy,
TF: Element,
TI: Element,
{
type TargetBackend = Autodiff<<B as BackendMovement<B, TF, TI>>::TargetBackend, C>;
type InputBackend = Self;
type TargetBackend =
Autodiff<<B as BackendBridge<BackendPrecisionSettings<TF, TI>>>::TargetBackend, C>;
fn move_float<const D: usize>(
fn bridge_float<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Self, D>,
settings: BackendPrecisionSettings<TF, TI>,
) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
todo!()
}
fn bridge_int<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self, D>,
settings: BackendPrecisionSettings<TF, TI>,
) -> burn_tensor::ops::IntTensor<Self::TargetBackend, D> {
todo!()
}
}
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
@ -42,12 +51,11 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
type FullPrecisionElem = B::FullPrecisionElem;
type FullPrecisionBackend = Autodiff<B::FullPrecisionBackend>;
type FloatTensorPrimitive<const D: usize> = AutodiffTensor<B, D>;
type FloatElem = B::FloatElem;
type IntTensorPrimitive<const D: usize> = B::IntTensorPrimitive<D>;
type IntElem = B::IntElem;
type FloatTensorPrimitive<const D: usize> = AutodiffTensor<B, D>;
type IntTensorPrimitive<const D: usize> = B::IntTensorPrimitive<D>;
type BoolTensorPrimitive<const D: usize> = B::BoolTensorPrimitive<D>;
fn ad_enabled() -> bool {

View File

@ -1,6 +1,6 @@
use std::marker::PhantomData;
use burn_tensor::backend::Backend;
use burn_tensor::backend::{Backend, BackendBridge, BackendPrecisionSettings};
use candle_core::DeviceLocation;
use crate::{
@ -66,6 +66,26 @@ impl Default for CandleDevice {
}
}
impl<F: FloatCandleElement, I: IntCandleElement, TF: FloatCandleElement, TI: IntCandleElement>
BackendBridge<BackendPrecisionSettings<TF, TI>> for Candle<F, I>
{
type InputBackend = Self;
type TargetBackend = Candle<TF, TI>;
fn bridge_float<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Self::InputBackend, D>,
) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
todo!()
}
fn bridge_int<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self::InputBackend, D>,
) -> burn_tensor::ops::IntTensor<Self::TargetBackend, D> {
todo!()
}
}
impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
type Device = CandleDevice;

View File

@ -2,10 +2,7 @@ use crate::element::FloatNdArrayElement;
use crate::NdArrayTensor;
use alloc::string::String;
use burn_common::stub::Mutex;
use burn_tensor::{
backend::{Backend, BackendMovement},
Tensor,
};
use burn_tensor::{backend::{Backend, BackendBridge, BackendPrecisionSettings, DoublePrecision}, Tensor};
use core::marker::PhantomData;
use rand::{rngs::StdRng, SeedableRng};
@ -24,28 +21,35 @@ impl Default for NdArrayDevice {
}
}
struct NdArraySettings<F: FloatNdArrayElement> {
_float: PhantomData<F>,
}
impl<F: FloatNdArrayElement, TF: FloatNdArrayElement> BackendMovement<Self, TF, i32>
for NdArray<F>
impl<F, TF> BackendBridge<BackendPrecisionSettings<TF, i64>> for NdArray<F>
where
F: FloatNdArrayElement,
TF: FloatNdArrayElement,
{
type InputBackend = Self;
type TargetBackend = NdArray<TF>;
fn move_float<const D: usize>(
fn bridge_float<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Self, D>,
_settings: BackendPrecisionSettings<TF, i64>,
) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
let array = tensor.array.mapv(|a| a.elem()).into_shared();
NdArrayTensor { array }
}
fn bridge_int<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self, D>,
_settings: BackendPrecisionSettings<TF, i64>,
) -> burn_tensor::ops::IntTensor<Self::TargetBackend, D> {
let array = tensor.array;
NdArrayTensor { array }
}
}
#[allow(dead_code)]
fn allo() {
let tensor: Tensor<NdArray<f32>, 2> = Tensor::ones([32, 32], &Default::default());
let tensor_full: Tensor<NdArray<f64>, 2> = tensor.clone().cast();
tensor + tensor_full.cast();
let tensor: Tensor<NdArray<f32>, 1> = Tensor::ones([32], &Default::default());
let tensor_full: Tensor<NdArray<f64>, 1> = tensor.bridge(DoublePrecision::default());
}
/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
@ -63,13 +67,12 @@ impl<E: FloatNdArrayElement> Backend for NdArray<E> {
type FullPrecisionBackend = NdArray<f32>;
type FloatTensorPrimitive<const D: usize> = NdArrayTensor<E, D>;
type FloatElem = E;
type IntTensorPrimitive<const D: usize> = NdArrayTensor<i64, D>;
type IntElem = i32;
type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;
type FloatElem = E;
type IntElem = i64;
fn ad_enabled() -> bool {
false
}

View File

@ -1,14 +1,14 @@
use alloc::vec::Vec;
use core::convert::TryInto;
use crate::backend::BackendMovement;
use crate::backend::{BackendBridge, BackendBridgeSettings};
use crate::check;
use crate::check::TensorCheck;
use crate::tensor::backend::Backend;
use crate::tensor::stats;
use crate::tensor::{Data, Distribution, Shape};
use crate::Int;
use crate::Tensor;
use crate::{check, Element};
impl<const D: usize, B> Tensor<B, D>
where
@ -32,16 +32,15 @@ where
core::mem::swap(&mut tensor_new, self);
}
/// Applies element wise exponential operation.
pub fn cast<F: Element, I: Element>(
/// TODO
pub fn bridge<S: BackendBridgeSettings>(
self,
) -> Tensor<<B as BackendMovement<B, F, I>>::TargetBackend, D>
settings: S,
) -> Tensor<<B as BackendBridge<S>>::TargetBackend, D>
where
B: BackendMovement<B, F, I>,
B: BackendBridge<S, InputBackend = B>,
{
Tensor::from_primitive(<B as BackendMovement<B, F, I>>::move_float(
self.into_primitive(),
))
Tensor::from_primitive(B::bridge_float(self.into_primitive(), settings))
}
/// Applies element wise exponential operation.

View File

@ -64,7 +64,10 @@ pub trait Backend:
+ Sync
+ core::fmt::Debug
+ 'static
+ BackendMovement<Self, f32, i32>
+ BackendBridge<
BackendPrecisionSettings<Self::FullPrecisionElem, IntElem<Self>>,
InputBackend = Self,
>
{
/// Device type.
type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync;
@ -76,17 +79,19 @@ pub trait Backend:
/// Tensor primitive to be used for all float operations.
type FloatTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
/// Float element type.
type FloatElem: Element;
/// Tensor primitive to be used for all int operations.
type IntTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
/// Int element type.
type IntElem: Element;
/// Tensor primitive to be used for all bool operations.
type BoolTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
/// Float element type.
type FloatElem: Element;
/// Int element type.
type IntElem: Element;
/// If autodiff is enabled.
fn ad_enabled() -> bool {
false
@ -100,39 +105,84 @@ pub trait Backend:
/// Sync the backend, ensure that all computation are finished.
fn sync(_device: &Self::Device) {}
}
fn move_float<const D: usize, TF: Element, TI: Element>(
tensor: FloatTensor<Self, D>,
) -> FloatTensor<<Self as BackendMovement<Self, TF, TI>>::TargetBackend, D>
where
Self: BackendMovement<Self, TF, TI>,
{
<Self as BackendMovement<Self, TF, TI>>::move_float(tensor)
pub type BackendPrecision<B: Backend> = BackendPrecisionSettings<FloatElem<B>, IntElem<B>>;
pub type FullPrecision = BackendPrecisionSettings<f32, i32>;
pub type DoublePrecision = BackendPrecisionSettings<f64, i64>;
#[derive(Default)]
pub struct BackendPrecisionSettings<F: Element, I: Element> {
_float: PhantomData<F>,
_int: PhantomData<I>,
}
pub trait BackendBridgeSettings {
type FloatElem: Element;
type IntElem: Element;
}
impl<F: Element, I: Element> BackendBridgeSettings for BackendPrecisionSettings<F, I> {
type FloatElem = F;
type IntElem = I;
}
pub trait BackendBridge<S: BackendBridgeSettings> {
type InputBackend: Backend;
type TargetBackend: Backend<FloatElem = S::FloatElem, IntElem = S::IntElem>;
fn bridge_float<const D: usize>(
tensor: FloatTensor<Self::InputBackend, D>,
settings: S,
) -> FloatTensor<Self::TargetBackend, D>;
fn bridge_int<const D: usize>(
tensor: IntTensor<Self::InputBackend, D>,
settings: S,
) -> IntTensor<Self::TargetBackend, D>;
}
impl<Target: Backend> BackendBridgeSettings for FromDataBackendBridge<Target> {
type FloatElem = FloatElem<Target>;
type IntElem = IntElem<Target>;
}
#[derive(new)]
pub struct FromDataBackendBridge<Target: Backend> {
device: Device<Target>,
_b: PhantomData<Target>,
}
impl<Input: Backend, Target: Backend> BackendBridge<FromDataBackendBridge<Target>> for Input {
type InputBackend = Input;
type TargetBackend = Target;
fn bridge_float<const D: usize>(
tensor: FloatTensor<Self::InputBackend, D>,
settings: FromDataBackendBridge<Target>,
) -> FloatTensor<Self::TargetBackend, D> {
let data = Input::float_into_data(tensor).read_sync().unwrap();
Target::float_from_data(data.convert(), &settings.device)
}
fn bridge_int<const D: usize>(
tensor: IntTensor<Self::InputBackend, D>,
settings: FromDataBackendBridge<Target>,
) -> IntTensor<Self::TargetBackend, D> {
let data = Input::int_into_data(tensor).read_sync().unwrap();
Target::int_from_data(data.convert(), &settings.device)
}
}
pub struct Settings<F: Element, I: Element> {
_f: PhantomData<F>,
_i: PhantomData<I>,
}
pub trait BackendMovement<B: Backend, TF, TI>
pub fn a_function<B: Backend>(
tensor: FloatTensor<B, 1>,
) -> FloatTensor<<B as BackendBridge<FullPrecision>>::TargetBackend, 1>
where
TF: Element,
TI: Element,
B: BackendBridge<FullPrecision, InputBackend = B>,
{
type TargetBackend: Backend<FloatElem = TF, IntElem = TI>;
fn move_float<const D: usize>(tensor: FloatTensor<B, D>)
-> FloatTensor<Self::TargetBackend, D>;
}
pub fn a_function<B>(tensor: FloatTensor<B, 1>)
where
B: Backend,
B: BackendMovement<B, f64, i64>,
{
let tensor_f64 = <B as Backend>::move_float::<1, f64, i64>(tensor);
B::bridge_float(tensor, FullPrecision::default())
}
/// Trait that allows a backend to support autodiff.