mirror of https://github.com/tracel-ai/burn.git
WIP
This commit is contained in:
parent 613a98094e
commit 623948d8e3
@@ -5,7 +5,7 @@ use crate::{
     tensor::AutodiffTensor,
 };
 use burn_tensor::{
-    backend::{AutodiffBackend, Backend, BackendMovement},
+    backend::{AutodiffBackend, Backend, BackendBridge, BackendPrecisionSettings},
     Element,
 };
 use core::marker::PhantomData;
@@ -20,20 +20,29 @@ pub struct Autodiff<B, C = NoCheckpointing> {
     _checkpoint_strategy: PhantomData<C>,
 }
 
-impl<B, C, TF, TI> BackendMovement<Self, TF, TI> for Autodiff<B, C>
+impl<B, C, TF: Element, TI: Element> BackendBridge<BackendPrecisionSettings<TF, TI>>
+    for Autodiff<B, C>
 where
-    B: Backend + BackendMovement<B, TF, TI>,
+    B: Backend + BackendBridge<BackendPrecisionSettings<TF, TI>>,
     C: CheckpointStrategy,
-    TF: Element,
-    TI: Element,
 {
-    type TargetBackend = Autodiff<<B as BackendMovement<B, TF, TI>>::TargetBackend, C>;
+    type InputBackend = Self;
+    type TargetBackend =
+        Autodiff<<B as BackendBridge<BackendPrecisionSettings<TF, TI>>>::TargetBackend, C>;
 
-    fn move_float<const D: usize>(
+    fn bridge_float<const D: usize>(
         tensor: burn_tensor::ops::FloatTensor<Self, D>,
+        settings: BackendPrecisionSettings<TF, TI>,
     ) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
         todo!()
     }
+
+    fn bridge_int<const D: usize>(
+        tensor: burn_tensor::ops::IntTensor<Self, D>,
+        settings: BackendPrecisionSettings<TF, TI>,
+    ) -> burn_tensor::ops::IntTensor<Self::TargetBackend, D> {
+        todo!()
+    }
 }
 
 impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
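Both new bodies are still todo!(). A minimal, self-contained sketch of the delegation they would plausibly perform, namely bridging the wrapped backend's primitive and rewrapping the result; the types below are simplified stand-ins, not the real AutodiffTensor internals:

use std::marker::PhantomData;

// Stand-in for BackendBridge, with the tensor primitive flattened to a
// plain f64 so the example stays self-contained.
trait Bridge<S> {
    type Target;
    fn bridge_float(value: f64, settings: S) -> f64;
}

struct Settings;

// Stand-in for a concrete inner backend.
struct Inner;

impl Bridge<Settings> for Inner {
    type Target = Inner;
    fn bridge_float(value: f64, _settings: Settings) -> f64 {
        value // a real backend would convert precision here
    }
}

// Stand-in for Autodiff<B>: bridging delegates to the inner backend B,
// and the target backend is the wrapper around B's target.
struct Wrapper<B>(PhantomData<B>);

impl<B: Bridge<Settings>> Bridge<Settings> for Wrapper<B> {
    type Target = Wrapper<B::Target>;
    fn bridge_float(value: f64, settings: Settings) -> f64 {
        B::bridge_float(value, settings)
    }
}

fn main() {
    let out = <Wrapper<Inner> as Bridge<Settings>>::bridge_float(1.5, Settings);
    println!("{out}");
}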
@@ -42,12 +51,11 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
     type FullPrecisionElem = B::FullPrecisionElem;
     type FullPrecisionBackend = Autodiff<B::FullPrecisionBackend>;
 
-    type FloatTensorPrimitive<const D: usize> = AutodiffTensor<B, D>;
-    type FloatElem = B::FloatElem;
-
-    type IntTensorPrimitive<const D: usize> = B::IntTensorPrimitive<D>;
-    type IntElem = B::IntElem;
-
+    type FloatTensorPrimitive<const D: usize> = AutodiffTensor<B, D>;
+    type IntTensorPrimitive<const D: usize> = B::IntTensorPrimitive<D>;
     type BoolTensorPrimitive<const D: usize> = B::BoolTensorPrimitive<D>;
 
+    type FloatElem = B::FloatElem;
+    type IntElem = B::IntElem;
+
     fn ad_enabled() -> bool {
@@ -1,6 +1,6 @@
 use std::marker::PhantomData;
 
-use burn_tensor::backend::Backend;
+use burn_tensor::backend::{Backend, BackendBridge, BackendPrecisionSettings};
 use candle_core::DeviceLocation;
 
 use crate::{
@@ -66,6 +66,26 @@ impl Default for CandleDevice {
     }
 }
 
+impl<F: FloatCandleElement, I: IntCandleElement, TF: FloatCandleElement, TI: IntCandleElement>
+    BackendBridge<BackendPrecisionSettings<TF, TI>> for Candle<F, I>
+{
+    type InputBackend = Self;
+
+    type TargetBackend = Candle<TF, TI>;
+
+    fn bridge_float<const D: usize>(
+        tensor: burn_tensor::ops::FloatTensor<Self::InputBackend, D>,
+    ) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
+        todo!()
+    }
+
+    fn bridge_int<const D: usize>(
+        tensor: burn_tensor::ops::IntTensor<Self::InputBackend, D>,
+    ) -> burn_tensor::ops::IntTensor<Self::TargetBackend, D> {
+        todo!()
+    }
+}
+
 impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
     type Device = CandleDevice;
 
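The Candle bodies are likewise todo!(). The element conversion itself is something candle can already do; a standalone candle_core illustration of the cast a bridge_float implementation would perform (this uses candle_core directly, not burn's CandleTensor wrapper):

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::ones((2, 2), DType::F32, &Device::Cpu)?;
    // The dtype conversion a Candle bridge_float would perform internally.
    let t64 = t.to_dtype(DType::F64)?;
    println!("{:?}", t64.dtype());
    Ok(())
}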
@@ -2,10 +2,7 @@ use crate::element::FloatNdArrayElement;
 use crate::NdArrayTensor;
 use alloc::string::String;
 use burn_common::stub::Mutex;
-use burn_tensor::{
-    backend::{Backend, BackendMovement},
-    Tensor,
-};
+use burn_tensor::{backend::{Backend, BackendBridge, BackendPrecisionSettings, DoublePrecision}, Tensor};
 use core::marker::PhantomData;
 use rand::{rngs::StdRng, SeedableRng};
@@ -24,28 +21,35 @@ impl Default for NdArrayDevice {
     }
 }
 
-struct NdArraySettings<F: FloatNdArrayElement> {
-    _float: PhantomData<F>,
-}
-
-impl<F: FloatNdArrayElement, TF: FloatNdArrayElement> BackendMovement<Self, TF, i32>
-    for NdArray<F>
+impl<F, TF> BackendBridge<BackendPrecisionSettings<TF, i64>> for NdArray<F>
+where
+    F: FloatNdArrayElement,
+    TF: FloatNdArrayElement,
 {
+    type InputBackend = Self;
     type TargetBackend = NdArray<TF>;
 
-    fn move_float<const D: usize>(
+    fn bridge_float<const D: usize>(
         tensor: burn_tensor::ops::FloatTensor<Self, D>,
+        _settings: BackendPrecisionSettings<TF, i64>,
     ) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
         let array = tensor.array.mapv(|a| a.elem()).into_shared();
         NdArrayTensor { array }
     }
+
+    fn bridge_int<const D: usize>(
+        tensor: burn_tensor::ops::IntTensor<Self, D>,
+        _settings: BackendPrecisionSettings<TF, i64>,
+    ) -> burn_tensor::ops::IntTensor<Self::TargetBackend, D> {
+        let array = tensor.array;
+        NdArrayTensor { array }
+    }
 }
 
 #[allow(dead_code)]
 fn allo() {
-    let tensor: Tensor<NdArray<f32>, 2> = Tensor::ones([32, 32], &Default::default());
-    let tensor_full: Tensor<NdArray<f64>, 2> = tensor.clone().cast();
-
-    tensor + tensor_full.cast();
+    let tensor: Tensor<NdArray<f32>, 1> = Tensor::ones([32], &Default::default());
+    let tensor_full: Tensor<NdArray<f64>, 1> = tensor.bridge(DoublePrecision::default());
 }
 
 /// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
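Here the float path converts elementwise through mapv while the int path reuses the array untouched (the int element type stays i64). A standalone ndarray illustration of the same cast, without burn's NdArrayTensor wrapper:

use ndarray::Array2;

fn main() {
    let a: Array2<f32> = Array2::ones((2, 2));
    // Same shape as `tensor.array.mapv(|a| a.elem()).into_shared()` above.
    let b = a.mapv(f64::from).into_shared();
    println!("{b:?}");
}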
@@ -63,13 +67,12 @@ impl<E: FloatNdArrayElement> Backend for NdArray<E> {
     type FullPrecisionBackend = NdArray<f32>;
 
     type FloatTensorPrimitive<const D: usize> = NdArrayTensor<E, D>;
-    type FloatElem = E;
-
     type IntTensorPrimitive<const D: usize> = NdArrayTensor<i64, D>;
-    type IntElem = i32;
-
     type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;
 
+    type FloatElem = E;
+    type IntElem = i64;
+
     fn ad_enabled() -> bool {
         false
     }
@@ -1,14 +1,14 @@
 use alloc::vec::Vec;
 use core::convert::TryInto;
 
-use crate::backend::BackendMovement;
-use crate::check;
+use crate::backend::{BackendBridge, BackendBridgeSettings};
 use crate::check::TensorCheck;
 use crate::tensor::backend::Backend;
 use crate::tensor::stats;
 use crate::tensor::{Data, Distribution, Shape};
 use crate::Int;
 use crate::Tensor;
+use crate::{check, Element};
 
 impl<const D: usize, B> Tensor<B, D>
 where
@@ -32,16 +32,15 @@ where
         core::mem::swap(&mut tensor_new, self);
     }
 
-    /// Applies element wise exponential operation.
-    pub fn cast<F: Element, I: Element>(
+    /// TODO
+    pub fn bridge<S: BackendBridgeSettings>(
         self,
-    ) -> Tensor<<B as BackendMovement<B, F, I>>::TargetBackend, D>
+        settings: S,
+    ) -> Tensor<<B as BackendBridge<S>>::TargetBackend, D>
     where
-        B: BackendMovement<B, F, I>,
+        B: BackendBridge<S, InputBackend = B>,
     {
-        Tensor::from_primitive(<B as BackendMovement<B, F, I>>::move_float(
-            self.into_primitive(),
-        ))
+        Tensor::from_primitive(B::bridge_float(self.into_primitive(), settings))
     }
+
     /// Applies element wise exponential operation.
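A call-site for the new method, mirroring the allo test in the ndarray changes above. This is a fragment rather than a full program since it depends on this commit's in-progress APIs, and the burn_ndarray import path is an assumption:

use burn_ndarray::NdArray;
use burn_tensor::backend::DoublePrecision;
use burn_tensor::Tensor;

// Moves the tensor to the backend selected by the settings type.
fn widen(t: Tensor<NdArray<f32>, 1>) -> Tensor<NdArray<f64>, 1> {
    t.bridge(DoublePrecision::default())
}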
@@ -64,7 +64,10 @@ pub trait Backend:
     + Sync
     + core::fmt::Debug
     + 'static
-    + BackendMovement<Self, f32, i32>
+    + BackendBridge<
+        BackendPrecisionSettings<Self::FullPrecisionElem, IntElem<Self>>,
+        InputBackend = Self,
+    >
 {
     /// Device type.
     type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync;
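The new supertrait bound uses an associated-type binding (InputBackend = Self), so every backend is by construction bridgeable from itself into its full-precision form. A toy reduction of that pattern with stand-in names:

// Every implementor of Backend must be a bridge whose input is itself.
trait Bridge {
    type Input;
}

trait Backend: Sized + Bridge<Input = Self> {}

struct Cpu;

impl Bridge for Cpu {
    type Input = Cpu;
}

impl Backend for Cpu {}

fn main() {}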
@@ -76,17 +79,19 @@ pub trait Backend:
 
     /// Tensor primitive to be used for all float operations.
     type FloatTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
-    /// Float element type.
-    type FloatElem: Element;
-
     /// Tensor primitive to be used for all int operations.
     type IntTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
-    /// Int element type.
-    type IntElem: Element;
-
     /// Tensor primitive to be used for all bool operations.
     type BoolTensorPrimitive<const D: usize>: Clone + Send + Sync + 'static + core::fmt::Debug;
 
+    /// Float element type.
+    type FloatElem: Element;
+
+    /// Int element type.
+    type IntElem: Element;
+
     /// If autodiff is enabled.
     fn ad_enabled() -> bool {
         false
@@ -100,39 +105,84 @@ pub trait Backend:
 
     /// Sync the backend, ensure that all computation are finished.
     fn sync(_device: &Self::Device) {}
-
-    fn move_float<const D: usize, TF: Element, TI: Element>(
-        tensor: FloatTensor<Self, D>,
-    ) -> FloatTensor<<Self as BackendMovement<Self, TF, TI>>::TargetBackend, D>
-    where
-        Self: BackendMovement<Self, TF, TI>,
-    {
-        <Self as BackendMovement<Self, TF, TI>>::move_float(tensor)
-    }
 }
 
-pub struct Settings<F: Element, I: Element> {
-    _f: PhantomData<F>,
-    _i: PhantomData<I>,
-}
-
-pub trait BackendMovement<B: Backend, TF, TI>
-where
-    TF: Element,
-    TI: Element,
-{
-    type TargetBackend: Backend<FloatElem = TF, IntElem = TI>;
-
-    fn move_float<const D: usize>(tensor: FloatTensor<B, D>)
-        -> FloatTensor<Self::TargetBackend, D>;
-}
-
-pub fn a_function<B>(tensor: FloatTensor<B, 1>)
-where
-    B: Backend,
-    B: BackendMovement<B, f64, i64>,
-{
-    let tensor_f64 = <B as Backend>::move_float::<1, f64, i64>(tensor);
-}
+pub type BackendPrecision<B: Backend> = BackendPrecisionSettings<FloatElem<B>, IntElem<B>>;
+
+pub type FullPrecision = BackendPrecisionSettings<f32, i32>;
+pub type DoublePrecision = BackendPrecisionSettings<f64, i64>;
+
+#[derive(Default)]
+pub struct BackendPrecisionSettings<F: Element, I: Element> {
+    _float: PhantomData<F>,
+    _int: PhantomData<I>,
+}
+
+pub trait BackendBridgeSettings {
+    type FloatElem: Element;
+    type IntElem: Element;
+}
+
+impl<F: Element, I: Element> BackendBridgeSettings for BackendPrecisionSettings<F, I> {
+    type FloatElem = F;
+    type IntElem = I;
+}
+
+pub trait BackendBridge<S: BackendBridgeSettings> {
+    type InputBackend: Backend;
+    type TargetBackend: Backend<FloatElem = S::FloatElem, IntElem = S::IntElem>;
+
+    fn bridge_float<const D: usize>(
+        tensor: FloatTensor<Self::InputBackend, D>,
+        settings: S,
+    ) -> FloatTensor<Self::TargetBackend, D>;
+    fn bridge_int<const D: usize>(
+        tensor: IntTensor<Self::InputBackend, D>,
+        settings: S,
+    ) -> IntTensor<Self::TargetBackend, D>;
+}
+
+impl<Target: Backend> BackendBridgeSettings for FromDataBackendBridge<Target> {
+    type FloatElem = FloatElem<Target>;
+    type IntElem = IntElem<Target>;
+}
+
+#[derive(new)]
+pub struct FromDataBackendBridge<Target: Backend> {
+    device: Device<Target>,
+    _b: PhantomData<Target>,
+}
+
+impl<Input: Backend, Target: Backend> BackendBridge<FromDataBackendBridge<Target>> for Input {
+    type InputBackend = Input;
+    type TargetBackend = Target;
+
+    fn bridge_float<const D: usize>(
+        tensor: FloatTensor<Self::InputBackend, D>,
+        settings: FromDataBackendBridge<Target>,
+    ) -> FloatTensor<Self::TargetBackend, D> {
+        let data = Input::float_into_data(tensor).read_sync().unwrap();
+        Target::float_from_data(data.convert(), &settings.device)
+    }
+
+    fn bridge_int<const D: usize>(
+        tensor: IntTensor<Self::InputBackend, D>,
+        settings: FromDataBackendBridge<Target>,
+    ) -> IntTensor<Self::TargetBackend, D> {
+        let data = Input::int_into_data(tensor).read_sync().unwrap();
+        Target::int_from_data(data.convert(), &settings.device)
+    }
+}
+
+pub fn a_function<B: Backend>(
+    tensor: FloatTensor<B, 1>,
+) -> FloatTensor<<B as BackendBridge<FullPrecision>>::TargetBackend, 1>
+where
+    B: BackendBridge<FullPrecision, InputBackend = B>,
+{
+    B::bridge_float(tensor, FullPrecision::default())
+}
 
 /// Trait that allows a backend to support autodiff.
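Taken together, the new items define a small type-level protocol: a settings type fixes the target element types, and BackendBridge maps an input backend to a target backend whose elements agree with those settings. A self-contained reduction of that design, with a toy Vec-backed backend standing in for the real ones:

use std::marker::PhantomData;

// Settings pick the target float element (cf. BackendPrecisionSettings).
#[derive(Default)]
struct Precision<F>(PhantomData<F>);

trait Settings {
    type FloatElem;
}

impl<F> Settings for Precision<F> {
    type FloatElem = F;
}

// Cut-down Backend: a single float primitive, a plain Vec here.
trait Backend {
    type FloatElem;
    type FloatTensor;
}

struct Cpu<E>(PhantomData<E>);

impl<E> Backend for Cpu<E> {
    type FloatElem = E;
    type FloatTensor = Vec<E>;
}

// Cut-down BackendBridge: the target must agree with the settings.
trait Bridge<S: Settings> {
    type Input: Backend;
    type Target: Backend<FloatElem = S::FloatElem>;

    fn bridge_float(
        tensor: <Self::Input as Backend>::FloatTensor,
        settings: S,
    ) -> <Self::Target as Backend>::FloatTensor;
}

// f32 -> f64 bridge for the toy backend, like NdArray's mapv-based impl.
impl Bridge<Precision<f64>> for Cpu<f32> {
    type Input = Self;
    type Target = Cpu<f64>;

    fn bridge_float(tensor: Vec<f32>, _settings: Precision<f64>) -> Vec<f64> {
        tensor.into_iter().map(f64::from).collect()
    }
}

fn main() {
    let widened =
        <Cpu<f32> as Bridge<Precision<f64>>>::bridge_float(vec![1.0, 2.0], Precision::default());
    println!("{widened:?}");
}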