mirror of https://github.com/tracel-ai/burn.git
feat: implement TensorBase for ADTensor
This commit is contained in:
parent
902f431fc1
commit
6ee628a748
|
@ -66,10 +66,6 @@ impl<P: HasAfEnum + Default + Copy + std::fmt::Debug, const D: usize> TensorBase
|
|||
}
|
||||
}
|
||||
|
||||
fn from<O: TensorBase<P, D>>(other: O) -> Self {
|
||||
Self::from_data(other.into_data(), Device::CPU)
|
||||
}
|
||||
|
||||
fn shape(&self) -> &Shape<D> {
|
||||
&self.shape
|
||||
}
|
||||
|
|
|
@ -80,7 +80,6 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::{
|
||||
backend::tch::TchTensor,
|
||||
node_init,
|
||||
tape::{Tape, TapeRef},
|
||||
Data, TensorBase,
|
||||
};
|
||||
|
@ -95,8 +94,8 @@ mod tests {
|
|||
let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu);
|
||||
let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu);
|
||||
|
||||
let tensor_ad_1 = ADTensor::new(node_init!(root tensor_1), tape.clone());
|
||||
let tensor_ad_2 = ADTensor::new(node_init!(root tensor_2), tape.clone());
|
||||
let tensor_ad_1 = ADTensor::from_tensor(tensor_1, tape.clone());
|
||||
let tensor_ad_2 = ADTensor::from_tensor(tensor_2, tape.clone());
|
||||
|
||||
let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2);
|
||||
let data_ad_3 = tensor_ad_3.tensor().into_data();
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
use crate::{
|
||||
node::{NodeRef, Ones, Zeros},
|
||||
node_init,
|
||||
ops::InitRecordedOps,
|
||||
tape::TapeRef,
|
||||
FloatTensor, Shape,
|
||||
FloatTensor, Shape, TensorBase,
|
||||
};
|
||||
use num_traits::Float;
|
||||
|
||||
|
@ -14,15 +15,29 @@ pub struct ADTensor<P, const D: usize, T> {
|
|||
pub tape: TapeRef,
|
||||
}
|
||||
|
||||
impl<T, P, const D: usize> TensorBase<P, D> for ADTensor<P, D, T>
|
||||
where
|
||||
P: Float + Zeros<P> + Default + 'static,
|
||||
T: FloatTensor<P, D> + Clone + Zeros<T> + Ones<T> + 'static,
|
||||
{
|
||||
fn shape(&self) -> &Shape<D> {
|
||||
&self.shape
|
||||
}
|
||||
|
||||
fn into_data(self) -> crate::Data<P, D> {
|
||||
self.tensor().into_data()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, P, const D: usize> ADTensor<P, D, T>
|
||||
where
|
||||
P: Float + Zeros<P> + Default + 'static,
|
||||
T: FloatTensor<P, D> + Clone + Zeros<T> + Ones<T> + 'static,
|
||||
{
|
||||
pub fn new(node: NodeRef<T>, tape: TapeRef) -> Self {
|
||||
let tensor = node.borrow().value();
|
||||
pub fn from_tensor(tensor: T, tape: TapeRef) -> Self {
|
||||
let shape = tensor.shape().clone();
|
||||
let kind = ADKind::new();
|
||||
let node = node_init!(root tensor);
|
||||
|
||||
let ops = InitRecordedOps::new(node.clone());
|
||||
let ops = Box::new(ops);
|
||||
|
|
|
@ -79,11 +79,8 @@ impl<P: tch::kind::Element + Default, const D: usize> TchTensor<P, D> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorBase<P, D>
|
||||
for TchTensor<P, D>
|
||||
{
|
||||
fn empty(shape: Shape<D>) -> Self {
|
||||
impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<P, D> {
|
||||
pub fn empty(shape: Shape<D>) -> Self {
|
||||
let shape_tch = TchShape::from(shape.clone());
|
||||
let device = tch::Device::Cpu;
|
||||
let kind = TchKind::new();
|
||||
|
@ -96,11 +93,11 @@ impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
|
|||
shape,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from<O: TensorBase<P, D>>(other: O) -> Self {
|
||||
Self::from_data(other.into_data(), tch::Device::Cpu)
|
||||
}
|
||||
|
||||
impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TensorBase<P, D>
|
||||
for TchTensor<P, D>
|
||||
{
|
||||
fn shape(&self) -> &Shape<D> {
|
||||
&self.shape
|
||||
}
|
||||
|
|
|
@ -19,8 +19,6 @@ pub trait FloatTensor<P: num_traits::Float, const D: usize>:
|
|||
pub trait TensorBase<P, const D: usize> {
|
||||
fn shape(&self) -> &Shape<D>;
|
||||
fn into_data(self) -> Data<P, D>;
|
||||
fn from<O: TensorBase<P, D>>(other: O) -> Self;
|
||||
fn empty(shape: Shape<D>) -> Self;
|
||||
}
|
||||
|
||||
pub trait TensorOpsAdd<P, const D: usize>:
|
||||
|
|
Loading…
Reference in New Issue