mirror of https://github.com/tracel-ai/burn.git
feat: implement TensorBase for ADTensor
This commit is contained in:
parent
902f431fc1
commit
6ee628a748
|
@ -66,10 +66,6 @@ impl<P: HasAfEnum + Default + Copy + std::fmt::Debug, const D: usize> TensorBase
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from<O: TensorBase<P, D>>(other: O) -> Self {
|
|
||||||
Self::from_data(other.into_data(), Device::CPU)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn shape(&self) -> &Shape<D> {
|
fn shape(&self) -> &Shape<D> {
|
||||||
&self.shape
|
&self.shape
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,7 +80,6 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{
|
use crate::{
|
||||||
backend::tch::TchTensor,
|
backend::tch::TchTensor,
|
||||||
node_init,
|
|
||||||
tape::{Tape, TapeRef},
|
tape::{Tape, TapeRef},
|
||||||
Data, TensorBase,
|
Data, TensorBase,
|
||||||
};
|
};
|
||||||
|
@ -95,8 +94,8 @@ mod tests {
|
||||||
let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu);
|
let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu);
|
||||||
let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu);
|
let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu);
|
||||||
|
|
||||||
let tensor_ad_1 = ADTensor::new(node_init!(root tensor_1), tape.clone());
|
let tensor_ad_1 = ADTensor::from_tensor(tensor_1, tape.clone());
|
||||||
let tensor_ad_2 = ADTensor::new(node_init!(root tensor_2), tape.clone());
|
let tensor_ad_2 = ADTensor::from_tensor(tensor_2, tape.clone());
|
||||||
|
|
||||||
let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2);
|
let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2);
|
||||||
let data_ad_3 = tensor_ad_3.tensor().into_data();
|
let data_ad_3 = tensor_ad_3.tensor().into_data();
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
node::{NodeRef, Ones, Zeros},
|
node::{NodeRef, Ones, Zeros},
|
||||||
|
node_init,
|
||||||
ops::InitRecordedOps,
|
ops::InitRecordedOps,
|
||||||
tape::TapeRef,
|
tape::TapeRef,
|
||||||
FloatTensor, Shape,
|
FloatTensor, Shape, TensorBase,
|
||||||
};
|
};
|
||||||
use num_traits::Float;
|
use num_traits::Float;
|
||||||
|
|
||||||
|
@ -14,15 +15,29 @@ pub struct ADTensor<P, const D: usize, T> {
|
||||||
pub tape: TapeRef,
|
pub tape: TapeRef,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T, P, const D: usize> TensorBase<P, D> for ADTensor<P, D, T>
|
||||||
|
where
|
||||||
|
P: Float + Zeros<P> + Default + 'static,
|
||||||
|
T: FloatTensor<P, D> + Clone + Zeros<T> + Ones<T> + 'static,
|
||||||
|
{
|
||||||
|
fn shape(&self) -> &Shape<D> {
|
||||||
|
&self.shape
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_data(self) -> crate::Data<P, D> {
|
||||||
|
self.tensor().into_data()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T, P, const D: usize> ADTensor<P, D, T>
|
impl<T, P, const D: usize> ADTensor<P, D, T>
|
||||||
where
|
where
|
||||||
P: Float + Zeros<P> + Default + 'static,
|
P: Float + Zeros<P> + Default + 'static,
|
||||||
T: FloatTensor<P, D> + Clone + Zeros<T> + Ones<T> + 'static,
|
T: FloatTensor<P, D> + Clone + Zeros<T> + Ones<T> + 'static,
|
||||||
{
|
{
|
||||||
pub fn new(node: NodeRef<T>, tape: TapeRef) -> Self {
|
pub fn from_tensor(tensor: T, tape: TapeRef) -> Self {
|
||||||
let tensor = node.borrow().value();
|
|
||||||
let shape = tensor.shape().clone();
|
let shape = tensor.shape().clone();
|
||||||
let kind = ADKind::new();
|
let kind = ADKind::new();
|
||||||
|
let node = node_init!(root tensor);
|
||||||
|
|
||||||
let ops = InitRecordedOps::new(node.clone());
|
let ops = InitRecordedOps::new(node.clone());
|
||||||
let ops = Box::new(ops);
|
let ops = Box::new(ops);
|
||||||
|
|
|
@ -79,11 +79,8 @@ impl<P: tch::kind::Element + Default, const D: usize> TchTensor<P, D> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<P, D> {
|
||||||
impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorBase<P, D>
|
pub fn empty(shape: Shape<D>) -> Self {
|
||||||
for TchTensor<P, D>
|
|
||||||
{
|
|
||||||
fn empty(shape: Shape<D>) -> Self {
|
|
||||||
let shape_tch = TchShape::from(shape.clone());
|
let shape_tch = TchShape::from(shape.clone());
|
||||||
let device = tch::Device::Cpu;
|
let device = tch::Device::Cpu;
|
||||||
let kind = TchKind::new();
|
let kind = TchKind::new();
|
||||||
|
@ -96,11 +93,11 @@ impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
|
||||||
shape,
|
shape,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn from<O: TensorBase<P, D>>(other: O) -> Self {
|
impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorBase<P, D>
|
||||||
Self::from_data(other.into_data(), tch::Device::Cpu)
|
for TchTensor<P, D>
|
||||||
}
|
{
|
||||||
|
|
||||||
fn shape(&self) -> &Shape<D> {
|
fn shape(&self) -> &Shape<D> {
|
||||||
&self.shape
|
&self.shape
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,8 +19,6 @@ pub trait FloatTensor<P: num_traits::Float, const D: usize>:
|
||||||
pub trait TensorBase<P, const D: usize> {
|
pub trait TensorBase<P, const D: usize> {
|
||||||
fn shape(&self) -> &Shape<D>;
|
fn shape(&self) -> &Shape<D>;
|
||||||
fn into_data(self) -> Data<P, D>;
|
fn into_data(self) -> Data<P, D>;
|
||||||
fn from<O: TensorBase<P, D>>(other: O) -> Self;
|
|
||||||
fn empty(shape: Shape<D>) -> Self;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait TensorOpsAdd<P, const D: usize>:
|
pub trait TensorOpsAdd<P, const D: usize>:
|
||||||
|
|
Loading…
Reference in New Issue