mirror of https://github.com/tracel-ai/burn.git
Experiments
This commit is contained in:
parent 4efc683df4
commit 613a98094e
Changes to the Autodiff backend:

@@ -4,7 +4,10 @@ use crate::{
     graph::backward::backward,
     tensor::AutodiffTensor,
 };
-use burn_tensor::backend::{AutodiffBackend, Backend};
+use burn_tensor::{
+    backend::{AutodiffBackend, Backend, BackendMovement},
+    Element,
+};
 use core::marker::PhantomData;
 
 /// Enable auto-differentiation on a backend.
@@ -17,6 +20,22 @@ pub struct Autodiff<B, C = NoCheckpointing> {
     _checkpoint_strategy: PhantomData<C>,
 }
 
+impl<B, C, TF, TI> BackendMovement<Self, TF, TI> for Autodiff<B, C>
+where
+    B: Backend + BackendMovement<B, TF, TI>,
+    C: CheckpointStrategy,
+    TF: Element,
+    TI: Element,
+{
+    type TargetBackend = Autodiff<<B as BackendMovement<B, TF, TI>>::TargetBackend, C>;
+
+    fn move_float<const D: usize>(
+        tensor: burn_tensor::ops::FloatTensor<Self, D>,
+    ) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
+        todo!()
+    }
+}
+
 impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
     type Device = B::Device;
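The autodiff implementation above leaves `move_float` as `todo!()`. A plausible completion, offered only as a hedged sketch: delegate the cast to the wrapped inner backend, then re-wrap the result. `AutodiffTensor::new` and the `primitive` field are assumptions about burn-autodiff internals, not part of this commit, and a cast done this way would detach the result from the autodiff graph.

    // Hedged sketch only: assumes AutodiffTensor exposes an inner `primitive`
    // field and a `new` constructor; the commit leaves this unimplemented.
    fn move_float<const D: usize>(
        tensor: burn_tensor::ops::FloatTensor<Self, D>,
    ) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
        // Cast the raw primitive on the inner backend B.
        let inner = <B as BackendMovement<B, TF, TI>>::move_float(tensor.primitive);
        // Re-wrap as a fresh leaf tensor; gradients would not flow through here.
        AutodiffTensor::new(inner)
    }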
Changes to the NdArray backend:

@@ -2,7 +2,10 @@ use crate::element::FloatNdArrayElement;
 use crate::NdArrayTensor;
 use alloc::string::String;
 use burn_common::stub::Mutex;
-use burn_tensor::backend::Backend;
+use burn_tensor::{
+    backend::{Backend, BackendMovement},
+    Tensor,
+};
 use core::marker::PhantomData;
 use rand::{rngs::StdRng, SeedableRng};
 
@@ -21,6 +24,30 @@ impl Default for NdArrayDevice {
     }
 }
 
+struct NdArraySettings<F: FloatNdArrayElement> {
+    _float: PhantomData<F>,
+}
+
+impl<F: FloatNdArrayElement, TF: FloatNdArrayElement> BackendMovement<Self, TF, i32>
+    for NdArray<F>
+{
+    type TargetBackend = NdArray<TF>;
+
+    fn move_float<const D: usize>(
+        tensor: burn_tensor::ops::FloatTensor<Self, D>,
+    ) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
+        let array = tensor.array.mapv(|a| a.elem()).into_shared();
+        NdArrayTensor { array }
+    }
+}
+
+fn allo() {
+    let tensor: Tensor<NdArray<f32>, 2> = Tensor::ones([32, 32], &Default::default());
+    let tensor_full: Tensor<NdArray<f64>, 2> = tensor.clone().cast();
+
+    tensor + tensor_full.cast();
+}
+
 /// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
 ///
 /// This backend is compatible with CPUs and can be compiled for almost any platform, including
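The concrete `NdArray` implementation performs the cast with `mapv(|a| a.elem())`, converting every scalar and allocating a fresh buffer. A standalone illustration of that pattern in plain ndarray, with no burn types (`widen` is a made-up helper):

    use ndarray::{ArcArray, Ix2};

    // mapv applies the scalar conversion to each element and allocates a new
    // array, so the cast is an O(n) copy, never a bitwise reinterpretation;
    // into_shared converts the owned result back to a shared ArcArray.
    fn widen(a: ArcArray<f32, Ix2>) -> ArcArray<f64, Ix2> {
        a.mapv(|v| v as f64).into_shared()
    }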
@@ -39,7 +66,7 @@ impl<E: FloatNdArrayElement> Backend for NdArray<E> {
     type FloatElem = E;
 
     type IntTensorPrimitive<const D: usize> = NdArrayTensor<i64, D>;
-    type IntElem = i64;
+    type IntElem = i32;
 
     type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;
 
Changes to the tensor float API (burn-tensor):

@@ -1,13 +1,14 @@
 use alloc::vec::Vec;
 use core::convert::TryInto;
 
-use crate::check;
+use crate::backend::BackendMovement;
 use crate::check::TensorCheck;
 use crate::tensor::backend::Backend;
 use crate::tensor::stats;
 use crate::tensor::{Data, Distribution, Shape};
 use crate::Int;
 use crate::Tensor;
+use crate::{check, Element};
 
 impl<const D: usize, B> Tensor<B, D>
 where
@@ -31,6 +32,18 @@ where
         core::mem::swap(&mut tensor_new, self);
     }
 
+    /// Casts the tensor to a backend with the given float and int element types.
+    pub fn cast<F: Element, I: Element>(
+        self,
+    ) -> Tensor<<B as BackendMovement<B, F, I>>::TargetBackend, D>
+    where
+        B: BackendMovement<B, F, I>,
+    {
+        Tensor::from_primitive(<B as BackendMovement<B, F, I>>::move_float(
+            self.into_primitive(),
+        ))
+    }
+
 /// Applies element wise exponential operation.
 ///
 /// `y = e^x`
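A minimal usage sketch of the new `cast` method, mirroring the `allo()` experiment earlier in this commit; it assumes the `NdArray` `BackendMovement` impl above, and `example` is a made-up name:

    use burn_ndarray::NdArray;
    use burn_tensor::Tensor;

    fn example() {
        let x: Tensor<NdArray<f32>, 2> = Tensor::ones([4, 4], &Default::default());
        // The target float element (f64 here) is fixed by the annotated
        // backend type; the int element comes from the BackendMovement impl.
        let y: Tensor<NdArray<f64>, 2> = x.cast();
        let _ = y;
    }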
Changes to the Backend trait (burn-tensor):

@@ -1,3 +1,5 @@
+use core::marker::PhantomData;
+
 use alloc::string::String;
 
 use crate::ops::*;
@@ -62,6 +64,7 @@ pub trait Backend:
     + Sync
     + core::fmt::Debug
     + 'static
+    + BackendMovement<Self, f32, i32>
 {
     /// Device type.
     type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync;
@@ -97,6 +100,39 @@ pub trait Backend:
 
     /// Sync the backend, ensure that all computation are finished.
     fn sync(_device: &Self::Device) {}
+
+    fn move_float<const D: usize, TF: Element, TI: Element>(
+        tensor: FloatTensor<Self, D>,
+    ) -> FloatTensor<<Self as BackendMovement<Self, TF, TI>>::TargetBackend, D>
+    where
+        Self: BackendMovement<Self, TF, TI>,
+    {
+        <Self as BackendMovement<Self, TF, TI>>::move_float(tensor)
+    }
+}
+
+pub struct Settings<F: Element, I: Element> {
+    _f: PhantomData<F>,
+    _i: PhantomData<I>,
+}
+
+pub trait BackendMovement<B: Backend, TF, TI>
+where
+    TF: Element,
+    TI: Element,
+{
+    type TargetBackend: Backend<FloatElem = TF, IntElem = TI>;
+
+    fn move_float<const D: usize>(tensor: FloatTensor<B, D>)
+        -> FloatTensor<Self::TargetBackend, D>;
+}
+
+pub fn a_function<B>(tensor: FloatTensor<B, 1>)
+where
+    B: Backend,
+    B: BackendMovement<B, f64, i64>,
+{
+    let tensor_f64 = <B as Backend>::move_float::<1, f64, i64>(tensor);
 }
 
 /// Trait that allows a backend to support autodiff.
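`BackendMovement` maps a source backend to a target backend whose `FloatElem` and `IntElem` match the requested element types, and the `Backend::move_float` default method simply forwards to it. A sketch of the generic calling pattern, mirroring `a_function` above (`to_f32` is a made-up helper; because this commit makes `BackendMovement<Self, f32, i32>` a supertrait of `Backend`, the f32/i32 move is available for every backend without an extra bound):

    use burn_tensor::backend::{Backend, BackendMovement};
    use burn_tensor::ops::FloatTensor;

    // Moves any backend's float tensor to that backend's f32/i32 counterpart,
    // as chosen by its BackendMovement impl.
    fn to_f32<B: Backend>(
        tensor: FloatTensor<B, 2>,
    ) -> FloatTensor<<B as BackendMovement<B, f32, i32>>::TargetBackend, 2> {
        <B as Backend>::move_float::<2, f32, i32>(tensor)
    }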