diff --git a/crates/burn-autodiff/src/backend.rs b/crates/burn-autodiff/src/backend.rs index 8c3e87b20..da95f463a 100644 --- a/crates/burn-autodiff/src/backend.rs +++ b/crates/burn-autodiff/src/backend.rs @@ -4,7 +4,10 @@ use crate::{ graph::backward::backward, tensor::AutodiffTensor, }; -use burn_tensor::backend::{AutodiffBackend, Backend}; +use burn_tensor::{ + backend::{AutodiffBackend, Backend, BackendMovement}, + Element, +}; use core::marker::PhantomData; /// Enable auto-differentiation on a backend. @@ -17,6 +20,22 @@ pub struct Autodiff { _checkpoint_strategy: PhantomData, } +impl BackendMovement for Autodiff +where + B: Backend + BackendMovement, + C: CheckpointStrategy, + TF: Element, + TI: Element, +{ + type TargetBackend = Autodiff<>::TargetBackend, C>; + + fn move_float( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + todo!() + } +} + impl Backend for Autodiff { type Device = B::Device; diff --git a/crates/burn-ndarray/src/backend.rs b/crates/burn-ndarray/src/backend.rs index 0bfdc766e..6c8e7ac88 100644 --- a/crates/burn-ndarray/src/backend.rs +++ b/crates/burn-ndarray/src/backend.rs @@ -2,7 +2,10 @@ use crate::element::FloatNdArrayElement; use crate::NdArrayTensor; use alloc::string::String; use burn_common::stub::Mutex; -use burn_tensor::backend::Backend; +use burn_tensor::{ + backend::{Backend, BackendMovement}, + Tensor, +}; use core::marker::PhantomData; use rand::{rngs::StdRng, SeedableRng}; @@ -21,6 +24,30 @@ impl Default for NdArrayDevice { } } +struct NdArraySettings { + _float: PhantomData, +} + +impl BackendMovement + for NdArray +{ + type TargetBackend = NdArray; + + fn move_float( + tensor: burn_tensor::ops::FloatTensor, + ) -> burn_tensor::ops::FloatTensor { + let array = tensor.array.mapv(|a| a.elem()).into_shared(); + NdArrayTensor { array } + } +} + +fn allo() { + let tensor: Tensor, 2> = Tensor::ones([32, 32], &Default::default()); + let tensor_full: Tensor, 2> = tensor.clone().cast(); + + tensor + tensor_full.cast(); +} + /// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations. /// /// This backend is compatible with CPUs and can be compiled for almost any platform, including @@ -39,7 +66,7 @@ impl Backend for NdArray { type FloatElem = E; type IntTensorPrimitive = NdArrayTensor; - type IntElem = i64; + type IntElem = i32; type BoolTensorPrimitive = NdArrayTensor; diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index a6a753f0b..4660f14f6 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -1,13 +1,14 @@ use alloc::vec::Vec; use core::convert::TryInto; -use crate::check; +use crate::backend::BackendMovement; use crate::check::TensorCheck; use crate::tensor::backend::Backend; use crate::tensor::stats; use crate::tensor::{Data, Distribution, Shape}; use crate::Int; use crate::Tensor; +use crate::{check, Element}; impl Tensor where @@ -31,6 +32,18 @@ where core::mem::swap(&mut tensor_new, self); } + /// Applies element wise exponential operation. + pub fn cast( + self, + ) -> Tensor<>::TargetBackend, D> + where + B: BackendMovement, + { + Tensor::from_primitive(>::move_float( + self.into_primitive(), + )) + } + /// Applies element wise exponential operation. /// /// `y = e^x` diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index d3c0b38f3..705271c5c 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -1,3 +1,5 @@ +use core::marker::PhantomData; + use alloc::string::String; use crate::ops::*; @@ -62,6 +64,7 @@ pub trait Backend: + Sync + core::fmt::Debug + 'static + + BackendMovement { /// Device type. type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync; @@ -97,6 +100,39 @@ pub trait Backend: /// Sync the backend, ensure that all computation are finished. fn sync(_device: &Self::Device) {} + + fn move_float( + tensor: FloatTensor, + ) -> FloatTensor<>::TargetBackend, D> + where + Self: BackendMovement, + { + >::move_float(tensor) + } +} + +pub struct Settings { + _f: PhantomData, + _i: PhantomData, +} + +pub trait BackendMovement +where + TF: Element, + TI: Element, +{ + type TargetBackend: Backend; + + fn move_float(tensor: FloatTensor) + -> FloatTensor; +} + +pub fn a_function(tensor: FloatTensor) +where + B: Backend, + B: BackendMovement, +{ + let tensor_f64 = ::move_float::<1, f64, i64>(tensor); } /// Trait that allows a backend to support autodiff.