Experiments

This commit is contained in:
nathaniel 2024-02-29 14:34:44 -05:00
parent 4efc683df4
commit 613a98094e
4 changed files with 99 additions and 4 deletions

View File

@ -4,7 +4,10 @@ use crate::{
graph::backward::backward, graph::backward::backward,
tensor::AutodiffTensor, tensor::AutodiffTensor,
}; };
use burn_tensor::backend::{AutodiffBackend, Backend}; use burn_tensor::{
backend::{AutodiffBackend, Backend, BackendMovement},
Element,
};
use core::marker::PhantomData; use core::marker::PhantomData;
/// Enable auto-differentiation on a backend. /// Enable auto-differentiation on a backend.
@ -17,6 +20,22 @@ pub struct Autodiff<B, C = NoCheckpointing> {
_checkpoint_strategy: PhantomData<C>, _checkpoint_strategy: PhantomData<C>,
} }
impl<B, C, TF, TI> BackendMovement<Self, TF, TI> for Autodiff<B, C>
where
B: Backend + BackendMovement<B, TF, TI>,
C: CheckpointStrategy,
TF: Element,
TI: Element,
{
type TargetBackend = Autodiff<<B as BackendMovement<B, TF, TI>>::TargetBackend, C>;
fn move_float<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Self, D>,
) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
todo!()
}
}
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> { impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
type Device = B::Device; type Device = B::Device;

View File

@ -2,7 +2,10 @@ use crate::element::FloatNdArrayElement;
use crate::NdArrayTensor; use crate::NdArrayTensor;
use alloc::string::String; use alloc::string::String;
use burn_common::stub::Mutex; use burn_common::stub::Mutex;
use burn_tensor::backend::Backend; use burn_tensor::{
backend::{Backend, BackendMovement},
Tensor,
};
use core::marker::PhantomData; use core::marker::PhantomData;
use rand::{rngs::StdRng, SeedableRng}; use rand::{rngs::StdRng, SeedableRng};
@ -21,6 +24,30 @@ impl Default for NdArrayDevice {
} }
} }
struct NdArraySettings<F: FloatNdArrayElement> {
_float: PhantomData<F>,
}
impl<F: FloatNdArrayElement, TF: FloatNdArrayElement> BackendMovement<Self, TF, i32>
for NdArray<F>
{
type TargetBackend = NdArray<TF>;
fn move_float<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Self, D>,
) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
let array = tensor.array.mapv(|a| a.elem()).into_shared();
NdArrayTensor { array }
}
}
fn allo() {
let tensor: Tensor<NdArray<f32>, 2> = Tensor::ones([32, 32], &Default::default());
let tensor_full: Tensor<NdArray<f64>, 2> = tensor.clone().cast();
tensor + tensor_full.cast();
}
/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations. /// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
/// ///
/// This backend is compatible with CPUs and can be compiled for almost any platform, including /// This backend is compatible with CPUs and can be compiled for almost any platform, including
@ -39,7 +66,7 @@ impl<E: FloatNdArrayElement> Backend for NdArray<E> {
type FloatElem = E; type FloatElem = E;
type IntTensorPrimitive<const D: usize> = NdArrayTensor<i64, D>; type IntTensorPrimitive<const D: usize> = NdArrayTensor<i64, D>;
type IntElem = i64; type IntElem = i32;
type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>; type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;

View File

@ -1,13 +1,14 @@
use alloc::vec::Vec; use alloc::vec::Vec;
use core::convert::TryInto; use core::convert::TryInto;
use crate::check; use crate::backend::BackendMovement;
use crate::check::TensorCheck; use crate::check::TensorCheck;
use crate::tensor::backend::Backend; use crate::tensor::backend::Backend;
use crate::tensor::stats; use crate::tensor::stats;
use crate::tensor::{Data, Distribution, Shape}; use crate::tensor::{Data, Distribution, Shape};
use crate::Int; use crate::Int;
use crate::Tensor; use crate::Tensor;
use crate::{check, Element};
impl<const D: usize, B> Tensor<B, D> impl<const D: usize, B> Tensor<B, D>
where where
@ -31,6 +32,18 @@ where
core::mem::swap(&mut tensor_new, self); core::mem::swap(&mut tensor_new, self);
} }
/// Applies element wise exponential operation.
pub fn cast<F: Element, I: Element>(
self,
) -> Tensor<<B as BackendMovement<B, F, I>>::TargetBackend, D>
where
B: BackendMovement<B, F, I>,
{
Tensor::from_primitive(<B as BackendMovement<B, F, I>>::move_float(
self.into_primitive(),
))
}
/// Applies element wise exponential operation. /// Applies element wise exponential operation.
/// ///
/// `y = e^x` /// `y = e^x`

View File

@ -1,3 +1,5 @@
use core::marker::PhantomData;
use alloc::string::String; use alloc::string::String;
use crate::ops::*; use crate::ops::*;
@ -62,6 +64,7 @@ pub trait Backend:
+ Sync + Sync
+ core::fmt::Debug + core::fmt::Debug
+ 'static + 'static
+ BackendMovement<Self, f32, i32>
{ {
/// Device type. /// Device type.
type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync; type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync;
@ -97,6 +100,39 @@ pub trait Backend:
/// Sync the backend, ensure that all computation are finished. /// Sync the backend, ensure that all computation are finished.
fn sync(_device: &Self::Device) {} fn sync(_device: &Self::Device) {}
fn move_float<const D: usize, TF: Element, TI: Element>(
tensor: FloatTensor<Self, D>,
) -> FloatTensor<<Self as BackendMovement<Self, TF, TI>>::TargetBackend, D>
where
Self: BackendMovement<Self, TF, TI>,
{
<Self as BackendMovement<Self, TF, TI>>::move_float(tensor)
}
}
pub struct Settings<F: Element, I: Element> {
_f: PhantomData<F>,
_i: PhantomData<I>,
}
pub trait BackendMovement<B: Backend, TF, TI>
where
TF: Element,
TI: Element,
{
type TargetBackend: Backend<FloatElem = TF, IntElem = TI>;
fn move_float<const D: usize>(tensor: FloatTensor<B, D>)
-> FloatTensor<Self::TargetBackend, D>;
}
pub fn a_function<B>(tensor: FloatTensor<B, 1>)
where
B: Backend,
B: BackendMovement<B, f64, i64>,
{
let tensor_f64 = <B as Backend>::move_float::<1, f64, i64>(tensor);
} }
/// Trait that allows a backend to support autodiff. /// Trait that allows a backend to support autodiff.