mirror of https://github.com/tracel-ai/burn.git
Experiments
This commit is contained in:
parent
4efc683df4
commit
613a98094e
|
@ -4,7 +4,10 @@ use crate::{
|
|||
graph::backward::backward,
|
||||
tensor::AutodiffTensor,
|
||||
};
|
||||
use burn_tensor::backend::{AutodiffBackend, Backend};
|
||||
use burn_tensor::{
|
||||
backend::{AutodiffBackend, Backend, BackendMovement},
|
||||
Element,
|
||||
};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
/// Enable auto-differentiation on a backend.
|
||||
|
@ -17,6 +20,22 @@ pub struct Autodiff<B, C = NoCheckpointing> {
|
|||
_checkpoint_strategy: PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<B, C, TF, TI> BackendMovement<Self, TF, TI> for Autodiff<B, C>
|
||||
where
|
||||
B: Backend + BackendMovement<B, TF, TI>,
|
||||
C: CheckpointStrategy,
|
||||
TF: Element,
|
||||
TI: Element,
|
||||
{
|
||||
type TargetBackend = Autodiff<<B as BackendMovement<B, TF, TI>>::TargetBackend, C>;
|
||||
|
||||
fn move_float<const D: usize>(
|
||||
tensor: burn_tensor::ops::FloatTensor<Self, D>,
|
||||
) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
|
||||
type Device = B::Device;
|
||||
|
||||
|
|
|
@ -2,7 +2,10 @@ use crate::element::FloatNdArrayElement;
|
|||
use crate::NdArrayTensor;
|
||||
use alloc::string::String;
|
||||
use burn_common::stub::Mutex;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{
|
||||
backend::{Backend, BackendMovement},
|
||||
Tensor,
|
||||
};
|
||||
use core::marker::PhantomData;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
|
@ -21,6 +24,30 @@ impl Default for NdArrayDevice {
|
|||
}
|
||||
}
|
||||
|
||||
struct NdArraySettings<F: FloatNdArrayElement> {
|
||||
_float: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: FloatNdArrayElement, TF: FloatNdArrayElement> BackendMovement<Self, TF, i32>
|
||||
for NdArray<F>
|
||||
{
|
||||
type TargetBackend = NdArray<TF>;
|
||||
|
||||
fn move_float<const D: usize>(
|
||||
tensor: burn_tensor::ops::FloatTensor<Self, D>,
|
||||
) -> burn_tensor::ops::FloatTensor<Self::TargetBackend, D> {
|
||||
let array = tensor.array.mapv(|a| a.elem()).into_shared();
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
}
|
||||
|
||||
fn allo() {
|
||||
let tensor: Tensor<NdArray<f32>, 2> = Tensor::ones([32, 32], &Default::default());
|
||||
let tensor_full: Tensor<NdArray<f64>, 2> = tensor.clone().cast();
|
||||
|
||||
tensor + tensor_full.cast();
|
||||
}
|
||||
|
||||
/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
|
||||
///
|
||||
/// This backend is compatible with CPUs and can be compiled for almost any platform, including
|
||||
|
@ -39,7 +66,7 @@ impl<E: FloatNdArrayElement> Backend for NdArray<E> {
|
|||
type FloatElem = E;
|
||||
|
||||
type IntTensorPrimitive<const D: usize> = NdArrayTensor<i64, D>;
|
||||
type IntElem = i64;
|
||||
type IntElem = i32;
|
||||
|
||||
type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;
|
||||
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
use alloc::vec::Vec;
|
||||
use core::convert::TryInto;
|
||||
|
||||
use crate::check;
|
||||
use crate::backend::BackendMovement;
|
||||
use crate::check::TensorCheck;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::stats;
|
||||
use crate::tensor::{Data, Distribution, Shape};
|
||||
use crate::Int;
|
||||
use crate::Tensor;
|
||||
use crate::{check, Element};
|
||||
|
||||
impl<const D: usize, B> Tensor<B, D>
|
||||
where
|
||||
|
@ -31,6 +32,18 @@ where
|
|||
core::mem::swap(&mut tensor_new, self);
|
||||
}
|
||||
|
||||
/// Applies element wise exponential operation.
|
||||
pub fn cast<F: Element, I: Element>(
|
||||
self,
|
||||
) -> Tensor<<B as BackendMovement<B, F, I>>::TargetBackend, D>
|
||||
where
|
||||
B: BackendMovement<B, F, I>,
|
||||
{
|
||||
Tensor::from_primitive(<B as BackendMovement<B, F, I>>::move_float(
|
||||
self.into_primitive(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Applies element wise exponential operation.
|
||||
///
|
||||
/// `y = e^x`
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use alloc::string::String;
|
||||
|
||||
use crate::ops::*;
|
||||
|
@ -62,6 +64,7 @@ pub trait Backend:
|
|||
+ Sync
|
||||
+ core::fmt::Debug
|
||||
+ 'static
|
||||
+ BackendMovement<Self, f32, i32>
|
||||
{
|
||||
/// Device type.
|
||||
type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync;
|
||||
|
@ -97,6 +100,39 @@ pub trait Backend:
|
|||
|
||||
/// Sync the backend, ensure that all computation are finished.
|
||||
fn sync(_device: &Self::Device) {}
|
||||
|
||||
fn move_float<const D: usize, TF: Element, TI: Element>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
) -> FloatTensor<<Self as BackendMovement<Self, TF, TI>>::TargetBackend, D>
|
||||
where
|
||||
Self: BackendMovement<Self, TF, TI>,
|
||||
{
|
||||
<Self as BackendMovement<Self, TF, TI>>::move_float(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Settings<F: Element, I: Element> {
|
||||
_f: PhantomData<F>,
|
||||
_i: PhantomData<I>,
|
||||
}
|
||||
|
||||
pub trait BackendMovement<B: Backend, TF, TI>
|
||||
where
|
||||
TF: Element,
|
||||
TI: Element,
|
||||
{
|
||||
type TargetBackend: Backend<FloatElem = TF, IntElem = TI>;
|
||||
|
||||
fn move_float<const D: usize>(tensor: FloatTensor<B, D>)
|
||||
-> FloatTensor<Self::TargetBackend, D>;
|
||||
}
|
||||
|
||||
pub fn a_function<B>(tensor: FloatTensor<B, 1>)
|
||||
where
|
||||
B: Backend,
|
||||
B: BackendMovement<B, f64, i64>,
|
||||
{
|
||||
let tensor_f64 = <B as Backend>::move_float::<1, f64, i64>(tensor);
|
||||
}
|
||||
|
||||
/// Trait that allows a backend to support autodiff.
|
||||
|
|
Loading…
Reference in New Issue