From 6ee628a748c6560b5962e40795c616c7ccbde96c Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 18 Jul 2022 19:31:00 -0400 Subject: [PATCH] feat: implement TensorBase for ADTensor --- .../src/tensor/backend/arrayfire/tensor.rs | 4 ---- .../src/tensor/backend/autodiff/ops/mul.rs | 5 ++--- .../src/tensor/backend/autodiff/tensor.rs | 21 ++++++++++++++++--- burn-tensor/src/tensor/backend/tch/tensor.rs | 15 ++++++------- burn-tensor/src/tensor/tensor.rs | 2 -- 5 files changed, 26 insertions(+), 21 deletions(-) diff --git a/burn-tensor/src/tensor/backend/arrayfire/tensor.rs b/burn-tensor/src/tensor/backend/arrayfire/tensor.rs index 26166a081..ee3269d72 100644 --- a/burn-tensor/src/tensor/backend/arrayfire/tensor.rs +++ b/burn-tensor/src/tensor/backend/arrayfire/tensor.rs @@ -66,10 +66,6 @@ impl TensorBase } } - fn from>(other: O) -> Self { - Self::from_data(other.into_data(), Device::CPU) - } - fn shape(&self) -> &Shape { &self.shape } diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs index df87d6884..c3a44d5ab 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs @@ -80,7 +80,6 @@ mod tests { use super::*; use crate::{ backend::tch::TchTensor, - node_init, tape::{Tape, TapeRef}, Data, TensorBase, }; @@ -95,8 +94,8 @@ mod tests { let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu); let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu); - let tensor_ad_1 = ADTensor::new(node_init!(root tensor_1), tape.clone()); - let tensor_ad_2 = ADTensor::new(node_init!(root tensor_2), tape.clone()); + let tensor_ad_1 = ADTensor::from_tensor(tensor_1, tape.clone()); + let tensor_ad_2 = ADTensor::from_tensor(tensor_2, tape.clone()); let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2); let data_ad_3 = tensor_ad_3.tensor().into_data(); diff --git a/burn-tensor/src/tensor/backend/autodiff/tensor.rs b/burn-tensor/src/tensor/backend/autodiff/tensor.rs index 01794388e..1d7750f75 100644 --- a/burn-tensor/src/tensor/backend/autodiff/tensor.rs +++ b/burn-tensor/src/tensor/backend/autodiff/tensor.rs @@ -1,8 +1,9 @@ use crate::{ node::{NodeRef, Ones, Zeros}, + node_init, ops::InitRecordedOps, tape::TapeRef, - FloatTensor, Shape, + FloatTensor, Shape, TensorBase, }; use num_traits::Float; @@ -14,15 +15,29 @@ pub struct ADTensor { pub tape: TapeRef, } +impl TensorBase for ADTensor +where + P: Float + Zeros

+ Default + 'static, + T: FloatTensor + Clone + Zeros + Ones + 'static, +{ + fn shape(&self) -> &Shape { + &self.shape + } + + fn into_data(self) -> crate::Data { + self.tensor().into_data() + } +} + impl ADTensor where P: Float + Zeros

+ Default + 'static, T: FloatTensor + Clone + Zeros + Ones + 'static, { - pub fn new(node: NodeRef, tape: TapeRef) -> Self { - let tensor = node.borrow().value(); + pub fn from_tensor(tensor: T, tape: TapeRef) -> Self { let shape = tensor.shape().clone(); let kind = ADKind::new(); + let node = node_init!(root tensor); let ops = InitRecordedOps::new(node.clone()); let ops = Box::new(ops); diff --git a/burn-tensor/src/tensor/backend/tch/tensor.rs b/burn-tensor/src/tensor/backend/tch/tensor.rs index c8bb00d2c..4cb1bda37 100644 --- a/burn-tensor/src/tensor/backend/tch/tensor.rs +++ b/burn-tensor/src/tensor/backend/tch/tensor.rs @@ -79,11 +79,8 @@ impl TchTensor { } } } - -impl TensorBase - for TchTensor -{ - fn empty(shape: Shape) -> Self { +impl TchTensor { + pub fn empty(shape: Shape) -> Self { let shape_tch = TchShape::from(shape.clone()); let device = tch::Device::Cpu; let kind = TchKind::new(); @@ -96,11 +93,11 @@ impl T shape, } } +} - fn from>(other: O) -> Self { - Self::from_data(other.into_data(), tch::Device::Cpu) - } - +impl TensorBase + for TchTensor +{ fn shape(&self) -> &Shape { &self.shape } diff --git a/burn-tensor/src/tensor/tensor.rs b/burn-tensor/src/tensor/tensor.rs index fdda66e4c..5747b61ca 100644 --- a/burn-tensor/src/tensor/tensor.rs +++ b/burn-tensor/src/tensor/tensor.rs @@ -19,8 +19,6 @@ pub trait FloatTensor: pub trait TensorBase { fn shape(&self) -> &Shape; fn into_data(self) -> Data; - fn from>(other: O) -> Self; - fn empty(shape: Shape) -> Self; } pub trait TensorOpsAdd: