feat: implement TensorBase for ADTensor

This commit is contained in:
nathaniel 2022-07-18 19:31:00 -04:00
parent 902f431fc1
commit 6ee628a748
5 changed files with 26 additions and 21 deletions

View File

@ -66,10 +66,6 @@ impl<P: HasAfEnum + Default + Copy + std::fmt::Debug, const D: usize> TensorBase
} }
} }
fn from<O: TensorBase<P, D>>(other: O) -> Self {
Self::from_data(other.into_data(), Device::CPU)
}
fn shape(&self) -> &Shape<D> { fn shape(&self) -> &Shape<D> {
&self.shape &self.shape
} }

View File

@ -80,7 +80,6 @@ mod tests {
use super::*; use super::*;
use crate::{ use crate::{
backend::tch::TchTensor, backend::tch::TchTensor,
node_init,
tape::{Tape, TapeRef}, tape::{Tape, TapeRef},
Data, TensorBase, Data, TensorBase,
}; };
@ -95,8 +94,8 @@ mod tests {
let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu); let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu);
let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu); let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu);
let tensor_ad_1 = ADTensor::new(node_init!(root tensor_1), tape.clone()); let tensor_ad_1 = ADTensor::from_tensor(tensor_1, tape.clone());
let tensor_ad_2 = ADTensor::new(node_init!(root tensor_2), tape.clone()); let tensor_ad_2 = ADTensor::from_tensor(tensor_2, tape.clone());
let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2); let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2);
let data_ad_3 = tensor_ad_3.tensor().into_data(); let data_ad_3 = tensor_ad_3.tensor().into_data();

View File

@ -1,8 +1,9 @@
use crate::{ use crate::{
node::{NodeRef, Ones, Zeros}, node::{NodeRef, Ones, Zeros},
node_init,
ops::InitRecordedOps, ops::InitRecordedOps,
tape::TapeRef, tape::TapeRef,
FloatTensor, Shape, FloatTensor, Shape, TensorBase,
}; };
use num_traits::Float; use num_traits::Float;
@ -14,15 +15,29 @@ pub struct ADTensor<P, const D: usize, T> {
pub tape: TapeRef, pub tape: TapeRef,
} }
impl<T, P, const D: usize> TensorBase<P, D> for ADTensor<P, D, T>
where
P: Float + Zeros<P> + Default + 'static,
T: FloatTensor<P, D> + Clone + Zeros<T> + Ones<T> + 'static,
{
fn shape(&self) -> &Shape<D> {
&self.shape
}
fn into_data(self) -> crate::Data<P, D> {
self.tensor().into_data()
}
}
impl<T, P, const D: usize> ADTensor<P, D, T> impl<T, P, const D: usize> ADTensor<P, D, T>
where where
P: Float + Zeros<P> + Default + 'static, P: Float + Zeros<P> + Default + 'static,
T: FloatTensor<P, D> + Clone + Zeros<T> + Ones<T> + 'static, T: FloatTensor<P, D> + Clone + Zeros<T> + Ones<T> + 'static,
{ {
pub fn new(node: NodeRef<T>, tape: TapeRef) -> Self { pub fn from_tensor(tensor: T, tape: TapeRef) -> Self {
let tensor = node.borrow().value();
let shape = tensor.shape().clone(); let shape = tensor.shape().clone();
let kind = ADKind::new(); let kind = ADKind::new();
let node = node_init!(root tensor);
let ops = InitRecordedOps::new(node.clone()); let ops = InitRecordedOps::new(node.clone());
let ops = Box::new(ops); let ops = Box::new(ops);

View File

@ -79,11 +79,8 @@ impl<P: tch::kind::Element + Default, const D: usize> TchTensor<P, D> {
} }
} }
} }
impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<P, D> {
impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorBase<P, D> pub fn empty(shape: Shape<D>) -> Self {
for TchTensor<P, D>
{
fn empty(shape: Shape<D>) -> Self {
let shape_tch = TchShape::from(shape.clone()); let shape_tch = TchShape::from(shape.clone());
let device = tch::Device::Cpu; let device = tch::Device::Cpu;
let kind = TchKind::new(); let kind = TchKind::new();
@ -96,11 +93,11 @@ impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
shape, shape,
} }
} }
}
fn from<O: TensorBase<P, D>>(other: O) -> Self { impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorBase<P, D>
Self::from_data(other.into_data(), tch::Device::Cpu) for TchTensor<P, D>
} {
fn shape(&self) -> &Shape<D> { fn shape(&self) -> &Shape<D> {
&self.shape &self.shape
} }

View File

@ -19,8 +19,6 @@ pub trait FloatTensor<P: num_traits::Float, const D: usize>:
pub trait TensorBase<P, const D: usize> { pub trait TensorBase<P, const D: usize> {
fn shape(&self) -> &Shape<D>; fn shape(&self) -> &Shape<D>;
fn into_data(self) -> Data<P, D>; fn into_data(self) -> Data<P, D>;
fn from<O: TensorBase<P, D>>(other: O) -> Self;
fn empty(shape: Shape<D>) -> Self;
} }
pub trait TensorOpsAdd<P, const D: usize>: pub trait TensorOpsAdd<P, const D: usize>: